LUCENE-10592 Build HNSW Graph on indexing (#992)

Currently, when indexing knn vectors, we buffer them in memory and
only build an HNSW graph on flush, during segment construction.
Because building an HNSW graph is very expensive, this makes the
flush operation take a long time. It also makes overall indexing
performance quite unpredictable: some indexing operations return
almost instantly, while others that trigger a flush take a long time.
This happens because flushes are hard to predict, as they are
triggered by memory usage, the presence of concurrent searches, etc.

Building the HNSW graph as we index documents avoids these problems,
as the cost of HNSW graph construction is spread evenly across indexing.
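
For context, here is a minimal sketch (not part of this change) of an application-level write path that exercises the incremental graph building; the field name, dimension, directory, and similarity function are arbitrary choices for illustration:

import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;

public class IncrementalHnswIndexingSketch {
  public static void main(String[] args) throws Exception {
    try (Directory dir = new ByteBuffersDirectory();
        IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig())) {
      for (int doc = 0; doc < 1000; doc++) {
        Document d = new Document();
        // With this change, each addDocument call hands the vector to the codec's
        // KnnFieldVectorsWriter.addValue, so the HNSW graph grows incrementally
        // instead of being built from scratch when the segment flushes.
        d.add(new KnnVectorField("vector", randomVector(16), VectorSimilarityFunction.EUCLIDEAN));
        writer.addDocument(d);
      }
      // Flush/commit now mostly serializes the graph that was already built.
      writer.commit();
    }
  }

  private static float[] randomVector(int dim) {
    float[] v = new float[dim];
    for (int i = 0; i < dim; i++) {
      v[i] = (float) Math.random();
    }
    return v;
  }
}

Each indexing call does a bounded amount of graph work, so indexing latency stays roughly even instead of spiking at flush time.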

Co-authored-by: Adrien Grand <jpountz@gmail.com>
Mayya Sharipova 2022-07-22 11:29:28 -04:00 committed by GitHub
parent bd360f9b3e
commit ba4bc04271
21 changed files with 1176 additions and 614 deletions

Lucene90HnswVectorsWriter.java

@ -23,7 +23,7 @@ import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
@ -41,7 +41,7 @@ import org.apache.lucene.util.IOUtils;
*
* @lucene.experimental
*/
public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
private final SegmentWriteState segmentWriteState;
private final IndexOutput meta, vectorData, vectorIndex;
@ -55,7 +55,6 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
this.maxConn = maxConn;
this.beamWidth = beamWidth;
assert state.fieldInfos.hasVectorValues();
segmentWriteState = state;
String metaFileName =
@ -107,7 +106,7 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
}
@Override
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc)
throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);

Lucene91HnswVectorsWriter.java

@ -23,7 +23,7 @@ import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
@ -43,11 +43,10 @@ import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
*
* @lucene.experimental
*/
public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
private final SegmentWriteState segmentWriteState;
private final IndexOutput meta, vectorData, vectorIndex;
private final int maxDoc;
private final int maxConn;
private final int beamWidth;
@ -58,7 +57,6 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
this.maxConn = maxConn;
this.beamWidth = beamWidth;
assert state.fieldInfos.hasVectorValues();
segmentWriteState = state;
String metaFileName =
@ -101,7 +99,6 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
Lucene91HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
maxDoc = state.segmentInfo.maxDoc();
success = true;
} finally {
if (success == false) {
@ -111,7 +108,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
}
@Override
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc)
throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
@ -149,6 +146,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(
fieldInfo,
maxDoc,
vectorDataOffset,
vectorDataLength,
vectorIndexOffset,
@ -186,6 +184,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
private void writeMeta(
FieldInfo field,
int maxDoc,
long vectorDataOffset,
long vectorDataLength,
long vectorIndexOffset,

Lucene92HnswVectorsWriter.java

@ -24,8 +24,8 @@ import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
@ -49,11 +49,10 @@ import org.apache.lucene.util.packed.DirectMonotonicWriter;
*
* @lucene.experimental
*/
public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
private final SegmentWriteState segmentWriteState;
private final IndexOutput meta, vectorData, vectorIndex;
private final int maxDoc;
private final int M;
private final int beamWidth;
@ -63,7 +62,6 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
this.M = M;
this.beamWidth = beamWidth;
assert state.fieldInfos.hasVectorValues();
segmentWriteState = state;
String metaFileName =
@ -106,7 +104,6 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
Lucene92HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
maxDoc = state.segmentInfo.maxDoc();
success = true;
} finally {
if (success == false) {
@ -116,7 +113,7 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
}
@Override
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc)
throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
@ -155,6 +152,7 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(
fieldInfo,
maxDoc,
vectorDataOffset,
vectorDataLength,
vectorIndexOffset,
@ -192,6 +190,7 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
private void writeMeta(
FieldInfo field,
int maxDoc,
long vectorDataOffset,
long vectorDataLength,
long vectorIndexOffset,

SimpleTextKnnVectorsWriter.java

@ -24,7 +24,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentWriteState;
@ -35,7 +35,7 @@ import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.IOUtils;
/** Writes vector-valued fields in a plain text format */
public class SimpleTextKnnVectorsWriter extends KnnVectorsWriter {
public class SimpleTextKnnVectorsWriter extends BufferingKnnVectorsWriter {
static final BytesRef FIELD_NUMBER = new BytesRef("field-number ");
static final BytesRef FIELD_NAME = new BytesRef("field-name ");
@ -48,8 +48,6 @@ public class SimpleTextKnnVectorsWriter extends KnnVectorsWriter {
private final BytesRefBuilder scratch = new BytesRefBuilder();
SimpleTextKnnVectorsWriter(SegmentWriteState state) throws IOException {
assert state.fieldInfos.hasVectorValues();
boolean success = false;
// exception handling to pass TestSimpleTextKnnVectorsFormat#testRandomExceptions
try {
@ -75,7 +73,7 @@ public class SimpleTextKnnVectorsWriter extends KnnVectorsWriter {
}
@Override
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc)
throws IOException {
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
long vectorDataOffset = vectorData.getFilePointer();

KnnFieldVectorsWriter.java (new file)

@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs;
import java.io.IOException;
import org.apache.lucene.util.Accountable;
/** Vectors' writer for a field */
public abstract class KnnFieldVectorsWriter implements Accountable {
/** Sole constructor */
protected KnnFieldVectorsWriter() {}
/**
* Add new docID with its vector value to the given field for indexing. Doc IDs must be added in
* increasing order.
*/
public abstract void addValue(int docID, float[] vectorValue) throws IOException;
}

KnnVectorsWriter.java

@ -24,30 +24,44 @@ import java.util.List;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.BytesRef;
/** Writes vectors to an index. */
public abstract class KnnVectorsWriter implements Closeable {
public abstract class KnnVectorsWriter implements Accountable, Closeable {
/** Sole constructor */
protected KnnVectorsWriter() {}
/** Write all values contained in the provided reader */
public abstract void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
throws IOException;
/** Add new field for indexing */
public abstract KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException;
/** Flush all buffered data on disk */
public abstract void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException;
/** Write field for merging */
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
KnnFieldVectorsWriter writer = addField(fieldInfo);
VectorValues mergedValues = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
for (int doc = mergedValues.nextDoc();
doc != DocIdSetIterator.NO_MORE_DOCS;
doc = mergedValues.nextDoc()) {
writer.addValue(doc, mergedValues.vectorValue());
}
}
/** Called once at the end before close */
public abstract void finish() throws IOException;
/**
* Merges the segment vectors for all fields. This default implementation delegates to {@link
* #writeField}, passing a {@link KnnVectorsReader} that combines the vector values and ignores
* #mergeOneField}, passing a {@link KnnVectorsReader} that combines the vector values and ignores
* deleted documents.
*/
public void merge(MergeState mergeState) throws IOException {
public final void merge(MergeState mergeState) throws IOException {
for (int i = 0; i < mergeState.fieldInfos.length; i++) {
KnnVectorsReader reader = mergeState.knnVectorsReaders[i];
assert reader != null || mergeState.fieldInfos[i].hasVectorValues() == false;
@ -62,35 +76,7 @@ public abstract class KnnVectorsWriter implements Closeable {
mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
}
writeField(
fieldInfo,
new KnnVectorsReader() {
@Override
public long ramBytesUsed() {
return 0;
}
@Override
public void close() {
throw new UnsupportedOperationException();
}
@Override
public void checkIntegrity() {
throw new UnsupportedOperationException();
}
@Override
public VectorValues getVectorValues(String field) throws IOException {
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
}
@Override
public TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}
});
mergeOneField(fieldInfo, mergeState);
if (mergeState.infoStream.isEnabled("VV")) {
mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
@ -118,7 +104,7 @@ public abstract class KnnVectorsWriter implements Closeable {
}
/** View over multiple VectorValues supporting iterator-style access via DocIdMerger. */
private static class MergedVectorValues extends VectorValues {
protected static class MergedVectorValues extends VectorValues {
private final List<VectorValuesSub> subs;
private final DocIDMerger<VectorValuesSub> docIdMerger;
private final int cost;
@ -128,7 +114,7 @@ public abstract class KnnVectorsWriter implements Closeable {
private VectorValuesSub current;
/** Returns a merged view over all the segment's {@link VectorValues}. */
static MergedVectorValues mergeVectorValues(FieldInfo fieldInfo, MergeState mergeState)
public static MergedVectorValues mergeVectorValues(FieldInfo fieldInfo, MergeState mergeState)
throws IOException {
assert fieldInfo != null && fieldInfo.hasVectorValues();

Lucene94HnswVectorsWriter.java

@ -21,23 +21,32 @@ import static org.apache.lucene.codecs.lucene94.Lucene94HnswVectorsFormat.DIRECT
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.NeighborArray;
@ -53,19 +62,16 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
private final SegmentWriteState segmentWriteState;
private final IndexOutput meta, vectorData, vectorIndex;
private final int maxDoc;
private final int M;
private final int beamWidth;
private final List<FieldWriter> fields = new ArrayList<>();
private boolean finished;
Lucene94HnswVectorsWriter(SegmentWriteState state, int M, int beamWidth) throws IOException {
this.M = M;
this.beamWidth = beamWidth;
assert state.fieldInfos.hasVectorValues();
segmentWriteState = state;
String metaFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name, state.segmentSuffix, Lucene94HnswVectorsFormat.META_EXTENSION);
@ -106,7 +112,6 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
Lucene94HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
maxDoc = state.segmentInfo.maxDoc();
success = true;
} finally {
if (success == false) {
@ -116,10 +121,231 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
}
@Override
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
throws IOException {
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
FieldWriter newField = new FieldWriter(fieldInfo, M, beamWidth, segmentWriteState.infoStream);
fields.add(newField);
return newField;
}
@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
for (FieldWriter field : fields) {
if (sortMap == null) {
writeField(field, maxDoc);
} else {
writeSortingField(field, maxDoc, sortMap);
}
}
}
@Override
public void finish() throws IOException {
if (finished) {
throw new IllegalStateException("already finished");
}
finished = true;
if (meta != null) {
// write end of fields marker
meta.writeInt(-1);
CodecUtil.writeFooter(meta);
}
if (vectorData != null) {
CodecUtil.writeFooter(vectorData);
CodecUtil.writeFooter(vectorIndex);
}
}
@Override
public long ramBytesUsed() {
long total = 0;
for (FieldWriter field : fields) {
total += field.ramBytesUsed();
}
return total;
}
private void writeField(FieldWriter fieldData, int maxDoc) throws IOException {
// write vector values
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
final BytesRef binaryValue = new BytesRef(buffer.array());
for (float[] vector : fieldData.vectors) {
buffer.asFloatBuffer().put(vector);
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
}
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
// write graph
long vectorIndexOffset = vectorIndex.getFilePointer();
OnHeapHnswGraph graph = fieldData.getGraph();
writeGraph(graph);
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(
fieldData.fieldInfo,
maxDoc,
vectorDataOffset,
vectorDataLength,
vectorIndexOffset,
vectorIndexLength,
fieldData.docsWithField,
graph);
}
private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap)
throws IOException {
final int[] docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
for (int docID = iterator.nextDoc();
docID != DocIdSetIterator.NO_MORE_DOCS;
docID = iterator.nextDoc()) {
int newDocID = sortMap.oldToNew(docID);
docIdOffsets[newDocID] = offset++;
}
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
final int[] ordMap = new int[offset - 1]; // new ord to old ord
final int[] oldOrdMap = new int[offset - 1]; // old ord to new ord
int ord = 0;
int doc = 0;
for (int docIdOffset : docIdOffsets) {
if (docIdOffset != 0) {
ordMap[ord] = docIdOffset - 1;
oldOrdMap[docIdOffset - 1] = ord;
newDocsWithField.add(doc);
ord++;
}
doc++;
}
// write vector values
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
final BytesRef binaryValue = new BytesRef(buffer.array());
for (int ordinal : ordMap) {
float[] vector = fieldData.vectors.get(ordinal);
buffer.asFloatBuffer().put(vector);
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
}
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
// write graph
long vectorIndexOffset = vectorIndex.getFilePointer();
OnHeapHnswGraph graph = fieldData.getGraph();
HnswGraph mockGraph = reconstructAndWriteGraph(graph, ordMap, oldOrdMap);
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(
fieldData.fieldInfo,
maxDoc,
vectorDataOffset,
vectorDataLength,
vectorIndexOffset,
vectorIndexLength,
newDocsWithField,
mockGraph);
}
// reconstruct graph substituting old ordinals with new ordinals
private HnswGraph reconstructAndWriteGraph(
OnHeapHnswGraph graph, int[] newToOldMap, int[] oldToNewMap) throws IOException {
if (graph == null) return null;
List<int[]> nodesByLevel = new ArrayList<>(graph.numLevels());
nodesByLevel.add(null);
int maxOrd = graph.size();
int maxConnOnLevel = M * 2;
NodesIterator nodesOnLevel0 = graph.getNodesOnLevel(0);
while (nodesOnLevel0.hasNext()) {
int node = nodesOnLevel0.nextInt();
NeighborArray neighbors = graph.getNeighbors(0, newToOldMap[node]);
reconstructAndWriteNeigbours(neighbors, oldToNewMap, maxConnOnLevel, maxOrd);
}
maxConnOnLevel = M;
for (int level = 1; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
int[] newNodes = new int[nodesOnLevel.size()];
int n = 0;
while (nodesOnLevel.hasNext()) {
newNodes[n++] = oldToNewMap[nodesOnLevel.nextInt()];
}
Arrays.sort(newNodes);
nodesByLevel.add(newNodes);
for (int node : newNodes) {
NeighborArray neighbors = graph.getNeighbors(level, newToOldMap[node]);
reconstructAndWriteNeigbours(neighbors, oldToNewMap, maxConnOnLevel, maxOrd);
}
}
return new HnswGraph() {
@Override
public int nextNeighbor() {
throw new UnsupportedOperationException("Not supported on a mock graph");
}
@Override
public void seek(int level, int target) {
throw new UnsupportedOperationException("Not supported on a mock graph");
}
@Override
public int size() {
return graph.size();
}
@Override
public int numLevels() {
return graph.numLevels();
}
@Override
public int entryNode() {
throw new UnsupportedOperationException("Not supported on a mock graph");
}
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return graph.getNodesOnLevel(0);
} else {
return new NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
}
}
};
}
private void reconstructAndWriteNeigbours(
NeighborArray neighbors, int[] oldToNewMap, int maxConnOnLevel, int maxOrd)
throws IOException {
int size = neighbors.size();
vectorIndex.writeInt(size);
// Destructively modify; it's ok we are discarding it after this
int[] nnodes = neighbors.node();
for (int i = 0; i < size; i++) {
nnodes[i] = oldToNewMap[nnodes[i]];
}
Arrays.sort(nnodes, 0, size);
for (int i = 0; i < size; i++) {
int nnode = nnodes[i];
assert nnode < maxOrd : "node too large: " + nnode + ">=" + maxOrd;
vectorIndex.writeInt(nnode);
}
// if number of connections < maxConn,
// add bogus values up to maxConn to have predictable offsets
for (int i = size; i < maxConnOnLevel; i++) {
vectorIndex.writeInt(0);
}
}
@Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
VectorValues vectors = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
IndexOutput tempVectorData =
segmentWriteState.directory.createTempOutput(
@ -148,13 +374,24 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
OffHeapVectorValues offHeapVectors =
new OffHeapVectorValues.DenseOffHeapVectorValues(
vectors.dimension(), docsWithField.cardinality(), vectorDataInput);
OnHeapHnswGraph graph =
offHeapVectors.size() == 0
? null
: writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction());
OnHeapHnswGraph graph = null;
if (offHeapVectors.size() != 0) {
// build graph
HnswGraphBuilder hnswGraphBuilder =
new HnswGraphBuilder(
offHeapVectors,
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
graph = hnswGraphBuilder.build(offHeapVectors.randomAccess());
writeGraph(graph);
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(
fieldInfo,
segmentWriteState.segmentInfo.maxDoc(),
vectorDataOffset,
vectorDataLength,
vectorIndexOffset,
@ -174,30 +411,44 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
}
}
/**
* Writes the vector values to the output and returns a set of documents that contains vectors.
*/
private static DocsWithFieldSet writeVectorData(IndexOutput output, VectorValues vectors)
throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
// write vector
BytesRef binaryValue = vectors.binaryValue();
assert binaryValue.length == vectors.dimension() * Float.BYTES;
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
docsWithField.add(docV);
private void writeGraph(OnHeapHnswGraph graph) throws IOException {
if (graph == null) return;
// write vectors' neighbours on each level into the vectorIndex file
int countOnLevel0 = graph.size();
for (int level = 0; level < graph.numLevels(); level++) {
int maxConnOnLevel = level == 0 ? (M * 2) : M;
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
NeighborArray neighbors = graph.getNeighbors(level, node);
int size = neighbors.size();
vectorIndex.writeInt(size);
// Destructively modify; it's ok we are discarding it after this
int[] nnodes = neighbors.node();
Arrays.sort(nnodes, 0, size);
for (int i = 0; i < size; i++) {
int nnode = nnodes[i];
assert nnode < countOnLevel0 : "node too large: " + nnode + ">=" + countOnLevel0;
vectorIndex.writeInt(nnode);
}
// if number of connections < maxConn, add bogus values up to maxConn to have predictable
// offsets
for (int i = size; i < maxConnOnLevel; i++) {
vectorIndex.writeInt(0);
}
}
}
return docsWithField;
}
private void writeMeta(
FieldInfo field,
int maxDoc,
long vectorDataOffset,
long vectorDataLength,
long vectorIndexOffset,
long vectorIndexLength,
DocsWithFieldSet docsWithField,
OnHeapHnswGraph graph)
HnswGraph graph)
throws IOException {
meta.writeInt(field.number);
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
@ -266,65 +517,129 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
}
}
private OnHeapHnswGraph writeGraph(
RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
/**
* Writes the vector values to the output and returns a set of documents that contains vectors.
*/
private static DocsWithFieldSet writeVectorData(IndexOutput output, VectorValues vectors)
throws IOException {
// build graph
HnswGraphBuilder hnswGraphBuilder =
new HnswGraphBuilder(
vectorValues, similarityFunction, M, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
// write vectors' neighbours on each level into the vectorIndex file
int countOnLevel0 = graph.size();
for (int level = 0; level < graph.numLevels(); level++) {
int maxConnOnLevel = level == 0 ? (M * 2) : M;
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
NeighborArray neighbors = graph.getNeighbors(level, node);
int size = neighbors.size();
vectorIndex.writeInt(size);
// Destructively modify; it's ok we are discarding it after this
int[] nnodes = neighbors.node();
Arrays.sort(nnodes, 0, size);
for (int i = 0; i < size; i++) {
int nnode = nnodes[i];
assert nnode < countOnLevel0 : "node too large: " + nnode + ">=" + countOnLevel0;
vectorIndex.writeInt(nnode);
}
// if number of connections < maxConn, add bogus values up to maxConn to have predictable
// offsets
for (int i = size; i < maxConnOnLevel; i++) {
vectorIndex.writeInt(0);
}
}
}
return graph;
}
@Override
public void finish() throws IOException {
if (finished) {
throw new IllegalStateException("already finished");
}
finished = true;
if (meta != null) {
// write end of fields marker
meta.writeInt(-1);
CodecUtil.writeFooter(meta);
}
if (vectorData != null) {
CodecUtil.writeFooter(vectorData);
CodecUtil.writeFooter(vectorIndex);
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
// write vector
BytesRef binaryValue = vectors.binaryValue();
assert binaryValue.length == vectors.dimension() * Float.BYTES;
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
docsWithField.add(docV);
}
return docsWithField;
}
@Override
public void close() throws IOException {
IOUtils.close(meta, vectorData, vectorIndex);
}
private static class FieldWriter extends KnnFieldVectorsWriter {
private final FieldInfo fieldInfo;
private final int dim;
private final DocsWithFieldSet docsWithField;
private final List<float[]> vectors;
private final RAVectorValues raVectorValues;
private final HnswGraphBuilder hnswGraphBuilder;
private int lastDocID = -1;
private int node = 0;
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
throws IOException {
this.fieldInfo = fieldInfo;
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
raVectorValues = new RAVectorValues(vectors, dim);
hnswGraphBuilder =
new HnswGraphBuilder(
() -> raVectorValues,
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(infoStream);
}
@Override
public void addValue(int docID, float[] vectorValue) throws IOException {
if (docID == lastDocID) {
throw new IllegalArgumentException(
"VectorValuesField \""
+ fieldInfo.name
+ "\" appears more than once in this document (only one value is allowed per field)");
}
if (vectorValue.length != dim) {
throw new IllegalArgumentException(
"Attempt to index a vector of dimension "
+ vectorValue.length
+ " but \""
+ fieldInfo.name
+ "\" has dimension "
+ dim);
}
assert docID > lastDocID;
docsWithField.add(docID);
vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length));
if (node > 0) {
// start at node 1! node 0 is added implicitly, in the constructor
hnswGraphBuilder.addGraphNode(node, vectorValue);
}
node++;
lastDocID = docID;
}
OnHeapHnswGraph getGraph() {
if (vectors.size() > 0) {
return hnswGraphBuilder.getGraph();
} else {
return null;
}
}
@Override
public long ramBytesUsed() {
if (vectors.size() == 0) return 0;
return docsWithField.ramBytesUsed()
+ vectors.size()
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ vectors.size() * vectors.get(0).length * Float.BYTES
+ hnswGraphBuilder.getGraph().ramBytesUsed();
}
}
private static class RAVectorValues implements RandomAccessVectorValues {
private final List<float[]> vectors;
private final int dim;
RAVectorValues(List<float[]> vectors, int dim) {
this.vectors = vectors;
this.dim = dim;
}
@Override
public int size() {
return vectors.size();
}
@Override
public int dimension() {
return dim;
}
@Override
public float[] vectorValue(int targetOrd) throws IOException {
return vectors.get(targetOrd);
}
@Override
public BytesRef binaryValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException();
}
}
}

PerFieldKnnVectorsFormat.java

@ -19,13 +19,11 @@ package org.apache.lucene.codecs.perfield;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.TreeMap;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
@ -33,6 +31,7 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
@ -102,34 +101,21 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
}
@Override
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
throws IOException {
getInstance(fieldInfo).writeField(fieldInfo, knnVectorsReader);
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
KnnVectorsWriter writer = getInstance(fieldInfo);
return writer.addField(fieldInfo);
}
@Override
public final void merge(MergeState mergeState) throws IOException {
Map<KnnVectorsWriter, Collection<String>> writersToFields = new IdentityHashMap<>();
// Group each writer by the fields it handles
for (FieldInfo fi : mergeState.mergeFieldInfos) {
if (fi.hasVectorValues() == false) {
continue;
}
KnnVectorsWriter writer = getInstance(fi);
Collection<String> fields = writersToFields.computeIfAbsent(writer, k -> new ArrayList<>());
fields.add(fi.name);
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
for (WriterAndSuffix was : formats.values()) {
was.writer.flush(maxDoc, sortMap);
}
}
// Delegate the merge to the appropriate writer
PerFieldMergeState pfMergeState = new PerFieldMergeState(mergeState);
try {
for (Map.Entry<KnnVectorsWriter, Collection<String>> e : writersToFields.entrySet()) {
e.getKey().merge(pfMergeState.apply(e.getValue()));
}
} finally {
pfMergeState.reset();
}
@Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
getInstance(fieldInfo).mergeOneField(fieldInfo, mergeState);
}
@Override
@ -180,10 +166,18 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
assert suffixes.containsKey(formatName);
suffix = writerAndSuffix.suffix;
}
field.putAttribute(PER_FIELD_SUFFIX_KEY, Integer.toString(suffix));
return writerAndSuffix.writer;
}
@Override
public long ramBytesUsed() {
long total = 0;
for (WriterAndSuffix was : formats.values()) {
total += was.writer.ramBytesUsed();
}
return total;
}
}
/** VectorReader that can wrap multiple delegate readers, selected by field. */

BufferingKnnVectorsWriter.java (new file)

@ -0,0 +1,277 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.RamUsageEstimator;
/**
* Buffers up pending vector value(s) per doc, then flushes when segment flushes. Used for {@code
* SimpleTextKnnVectorsWriter} and for vectors writers before v 9.3 .
*
* @lucene.experimental
*/
public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
private final List<FieldWriter> fields = new ArrayList<>();
/** Sole constructor */
protected BufferingKnnVectorsWriter() {}
@Override
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
FieldWriter newField = new FieldWriter(fieldInfo);
fields.add(newField);
return newField;
}
@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
for (FieldWriter fieldData : fields) {
KnnVectorsReader knnVectorsReader =
new KnnVectorsReader() {
@Override
public long ramBytesUsed() {
return 0;
}
@Override
public void close() {
throw new UnsupportedOperationException();
}
@Override
public void checkIntegrity() {
throw new UnsupportedOperationException();
}
@Override
public VectorValues getVectorValues(String field) throws IOException {
VectorValues vectorValues =
new BufferedVectorValues(
fieldData.docsWithField,
fieldData.vectors,
fieldData.fieldInfo.getVectorDimension());
return sortMap != null
? new VectorValues.SortingVectorValues(vectorValues, sortMap)
: vectorValues;
}
@Override
public TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}
};
writeField(fieldData.fieldInfo, knnVectorsReader, maxDoc);
}
}
@Override
public long ramBytesUsed() {
long total = 0;
for (FieldWriter field : fields) {
total += field.ramBytesUsed();
}
return total;
}
@Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
KnnVectorsReader knnVectorsReader =
new KnnVectorsReader() {
@Override
public long ramBytesUsed() {
return 0;
}
@Override
public void close() {}
@Override
public TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}
@Override
public VectorValues getVectorValues(String field) throws IOException {
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
}
@Override
public void checkIntegrity() {}
};
writeField(fieldInfo, knnVectorsReader, mergeState.segmentInfo.maxDoc());
}
/** Write the provided field */
protected abstract void writeField(
FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc) throws IOException;
private static class FieldWriter extends KnnFieldVectorsWriter {
private final FieldInfo fieldInfo;
private final int dim;
private final DocsWithFieldSet docsWithField;
private final List<float[]> vectors;
private int lastDocID = -1;
public FieldWriter(FieldInfo fieldInfo) {
this.fieldInfo = fieldInfo;
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
}
@Override
public void addValue(int docID, float[] vectorValue) {
if (docID == lastDocID) {
throw new IllegalArgumentException(
"VectorValuesField \""
+ fieldInfo.name
+ "\" appears more than once in this document (only one value is allowed per field)");
}
if (vectorValue.length != dim) {
throw new IllegalArgumentException(
"Attempt to index a vector of dimension "
+ vectorValue.length
+ " but \""
+ fieldInfo.name
+ "\" has dimension "
+ dim);
}
assert docID > lastDocID;
docsWithField.add(docID);
vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length));
lastDocID = docID;
}
@Override
public long ramBytesUsed() {
if (vectors.size() == 0) return 0;
return docsWithField.ramBytesUsed()
+ vectors.size()
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ vectors.size() * vectors.get(0).length * Float.BYTES;
}
}
private static class BufferedVectorValues extends VectorValues
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
final DocsWithFieldSet docsWithField;
// These are always the vectors of a VectorValuesWriter, which are copied when added to it
final List<float[]> vectors;
final int dimension;
final ByteBuffer buffer;
final BytesRef binaryValue;
final ByteBuffer raBuffer;
final BytesRef raBinaryValue;
DocIdSetIterator docsWithFieldIter;
int ord = -1;
BufferedVectorValues(DocsWithFieldSet docsWithField, List<float[]> vectors, int dimension) {
this.docsWithField = docsWithField;
this.vectors = vectors;
this.dimension = dimension;
buffer = ByteBuffer.allocate(dimension * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
binaryValue = new BytesRef(buffer.array());
raBuffer = ByteBuffer.allocate(dimension * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
raBinaryValue = new BytesRef(raBuffer.array());
docsWithFieldIter = docsWithField.iterator();
}
@Override
public RandomAccessVectorValues randomAccess() {
return new BufferedVectorValues(docsWithField, vectors, dimension);
}
@Override
public int dimension() {
return dimension;
}
@Override
public int size() {
return vectors.size();
}
@Override
public BytesRef binaryValue() {
buffer.asFloatBuffer().put(vectorValue());
return binaryValue;
}
@Override
public BytesRef binaryValue(int targetOrd) {
raBuffer.asFloatBuffer().put(vectors.get(targetOrd));
return raBinaryValue;
}
@Override
public float[] vectorValue() {
return vectors.get(ord);
}
@Override
public float[] vectorValue(int targetOrd) {
return vectors.get(targetOrd);
}
@Override
public int docID() {
return docsWithFieldIter.docID();
}
@Override
public int nextDoc() throws IOException {
int docID = docsWithFieldIter.nextDoc();
if (docID != NO_MORE_DOCS) {
++ord;
}
return docID;
}
@Override
public int advance(int target) {
throw new UnsupportedOperationException();
}
@Override
public long cost() {
return docsWithFieldIter.cost();
}
}
}

IndexingChain.java

@ -30,8 +30,7 @@ import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.codecs.DocValuesConsumer;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.NormsConsumer;
import org.apache.lucene.codecs.NormsFormat;
import org.apache.lucene.codecs.NormsProducer;
@ -68,6 +67,7 @@ final class IndexingChain implements Accountable {
final ByteBlockPool docValuesBytePool;
// Writes stored fields
final StoredFieldsConsumer storedFieldsConsumer;
final VectorValuesConsumer vectorValuesConsumer;
final TermVectorsConsumer termVectorsWriter;
// NOTE: I tried using Hash Map<String,PerField>
@ -104,6 +104,8 @@ final class IndexingChain implements Accountable {
this.fieldInfos = fieldInfos;
this.infoStream = indexWriterConfig.getInfoStream();
this.abortingExceptionConsumer = abortingExceptionConsumer;
this.vectorValuesConsumer =
new VectorValuesConsumer(indexWriterConfig.getCodec(), directory, segmentInfo, infoStream);
if (segmentInfo.getIndexSort() == null) {
storedFieldsConsumer =
@ -262,7 +264,7 @@ final class IndexingChain implements Accountable {
}
t0 = System.nanoTime();
writeVectors(state, sortMap);
vectorValuesConsumer.flush(state, sortMap);
if (infoStream.isEnabled("IW")) {
infoStream.message("IW", ((System.nanoTime() - t0) / 1000000) + " msec to write vectors");
}
@ -428,63 +430,6 @@ final class IndexingChain implements Accountable {
}
}
/** Writes all buffered vectors. */
private void writeVectors(SegmentWriteState state, Sorter.DocMap sortMap) throws IOException {
KnnVectorsWriter knnVectorsWriter = null;
boolean success = false;
try {
for (int i = 0; i < fieldHash.length; i++) {
PerField perField = fieldHash[i];
while (perField != null) {
if (perField.vectorValuesWriter != null) {
if (perField.fieldInfo.getVectorDimension() == 0) {
// BUG
throw new AssertionError(
"segment="
+ state.segmentInfo
+ ": field=\""
+ perField.fieldInfo.name
+ "\" has no vectors but wrote them");
}
if (knnVectorsWriter == null) {
// lazy init
KnnVectorsFormat fmt = state.segmentInfo.getCodec().knnVectorsFormat();
if (fmt == null) {
throw new IllegalStateException(
"field=\""
+ perField.fieldInfo.name
+ "\" was indexed as vectors but codec does not support vectors");
}
knnVectorsWriter = fmt.fieldsWriter(state);
}
perField.vectorValuesWriter.flush(sortMap, knnVectorsWriter);
perField.vectorValuesWriter = null;
} else if (perField.fieldInfo != null && perField.fieldInfo.getVectorDimension() != 0) {
// BUG
throw new AssertionError(
"segment="
+ state.segmentInfo
+ ": field=\""
+ perField.fieldInfo.name
+ "\" has vectors but did not write them");
}
perField = perField.next;
}
}
if (knnVectorsWriter != null) {
knnVectorsWriter.finish();
}
success = true;
} finally {
if (success) {
IOUtils.close(knnVectorsWriter);
} else {
IOUtils.closeWhileHandlingException(knnVectorsWriter);
}
}
}
private void writeNorms(SegmentWriteState state, Sorter.DocMap sortMap) throws IOException {
boolean success = false;
NormsConsumer normsConsumer = null;
@ -522,6 +467,7 @@ final class IndexingChain implements Accountable {
// finalizer will e.g. close any open files in the term vectors writer:
try (Closeable finalizer = termsHash::abort) {
storedFieldsConsumer.abort();
vectorValuesConsumer.abort();
} finally {
Arrays.fill(fieldHash, null);
}
@ -714,7 +660,12 @@ final class IndexingChain implements Accountable {
pf.pointValuesWriter = new PointValuesWriter(bytesUsed, fi);
}
if (fi.getVectorDimension() != 0) {
pf.vectorValuesWriter = new VectorValuesWriter(fi, bytesUsed);
try {
pf.knnFieldVectorsWriter = vectorValuesConsumer.addField(fi);
} catch (Throwable th) {
onAbortingException(th);
throw th;
}
}
}
@ -761,7 +712,7 @@ final class IndexingChain implements Accountable {
pf.pointValuesWriter.addPackedValue(docID, field.binaryValue());
}
if (fieldType.vectorDimension() != 0) {
pf.vectorValuesWriter.addValue(docID, ((KnnVectorField) field).vectorValue());
pf.knnFieldVectorsWriter.addValue(docID, ((KnnVectorField) field).vectorValue());
}
return indexedField;
}
@ -1006,12 +957,16 @@ final class IndexingChain implements Accountable {
public long ramBytesUsed() {
return bytesUsed.get()
+ storedFieldsConsumer.accountable.ramBytesUsed()
+ termVectorsWriter.accountable.ramBytesUsed();
+ termVectorsWriter.accountable.ramBytesUsed()
+ vectorValuesConsumer.getAccountable().ramBytesUsed();
}
@Override
public Collection<Accountable> getChildResources() {
return List.of(storedFieldsConsumer.accountable, termVectorsWriter.accountable);
return List.of(
storedFieldsConsumer.accountable,
termVectorsWriter.accountable,
vectorValuesConsumer.getAccountable());
}
/** NOTE: not static: accesses at least docState, termsHash. */
@ -1032,8 +987,8 @@ final class IndexingChain implements Accountable {
// Non-null if this field ever had points in this segment:
PointValuesWriter pointValuesWriter;
// Non-null if this field ever had vector values in this segment:
VectorValuesWriter vectorValuesWriter;
// Non-null if this field had vectors in this segment
KnnFieldVectorsWriter knnFieldVectorsWriter;
/** We use this to know when a PerField is seen for the first time in the current document. */
long fieldGen = -1;

Sorter.java

@ -28,7 +28,7 @@ import org.apache.lucene.util.packed.PackedLongValues;
*
* @lucene.experimental
*/
final class Sorter {
public final class Sorter {
final Sort sort;
/** Creates a new Sorter to sort the index with {@code sort} */
@ -44,20 +44,23 @@ final class Sorter {
* A permutation of doc IDs. For every document ID between <code>0</code> and {@link
* IndexReader#maxDoc()}, <code>oldToNew(newToOld(docID))</code> must return <code>docID</code>.
*/
abstract static class DocMap {
public abstract static class DocMap {
/** Sole constructor. */
protected DocMap() {}
/** Given a doc ID from the original index, return its ordinal in the sorted index. */
abstract int oldToNew(int docID);
public abstract int oldToNew(int docID);
/** Given the ordinal of a doc ID, return its doc ID in the original index. */
abstract int newToOld(int docID);
public abstract int newToOld(int docID);
/**
* Return the number of documents in this map. This must be equal to the {@link
* org.apache.lucene.index.LeafReader#maxDoc() number of documents} of the {@link
* org.apache.lucene.index.LeafReader} which is sorted.
*/
abstract int size();
public abstract int size();
}
/** Check consistency of a {@link DocMap}, useful for assertions. */

SortingCodecReader.java

@ -380,7 +380,7 @@ public final class SortingCodecReader extends FilterCodecReader {
@Override
public VectorValues getVectorValues(String field) throws IOException {
return new VectorValuesWriter.SortingVectorValues(delegate.getVectorValues(field), docMap);
return new VectorValues.SortingVectorValues(delegate.getVectorValues(field), docMap);
}
@Override

VectorValues.java

@ -111,4 +111,117 @@ public abstract class VectorValues extends DocIdSetIterator {
return 0;
}
};
/** Sorting VectorValues that iterate over documents in the order of the provided sortMap */
public static class SortingVectorValues extends VectorValues
implements RandomAccessVectorValuesProducer {
private final VectorValues delegate;
private final RandomAccessVectorValues randomAccess;
private final int[] docIdOffsets;
private final int[] ordMap;
private int docId = -1;
/** Sorting VectorValues */
public SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap) throws IOException {
this.delegate = delegate;
randomAccess = ((RandomAccessVectorValuesProducer) delegate).randomAccess();
docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
int docID;
while ((docID = delegate.nextDoc()) != NO_MORE_DOCS) {
int newDocID = sortMap.oldToNew(docID);
docIdOffsets[newDocID] = offset++;
}
// set up ordMap to map from new dense ordinal to old dense ordinal
ordMap = new int[offset - 1];
int ord = 0;
for (int docIdOffset : docIdOffsets) {
if (docIdOffset != 0) {
ordMap[ord++] = docIdOffset - 1;
}
}
assert ord == ordMap.length;
}
@Override
public int docID() {
return docId;
}
@Override
public int nextDoc() throws IOException {
while (docId < docIdOffsets.length - 1) {
++docId;
if (docIdOffsets[docId] != 0) {
return docId;
}
}
docId = NO_MORE_DOCS;
return docId;
}
@Override
public BytesRef binaryValue() throws IOException {
return randomAccess.binaryValue(docIdOffsets[docId] - 1);
}
@Override
public float[] vectorValue() throws IOException {
return randomAccess.vectorValue(docIdOffsets[docId] - 1);
}
@Override
public int dimension() {
return delegate.dimension();
}
@Override
public int size() {
return delegate.size();
}
@Override
public int advance(int target) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public long cost() {
return size();
}
@Override
public RandomAccessVectorValues randomAccess() throws IOException {
// Must make a new delegate randomAccess so that we have our own distinct float[]
final RandomAccessVectorValues delegateRA =
((RandomAccessVectorValuesProducer) SortingVectorValues.this.delegate).randomAccess();
return new RandomAccessVectorValues() {
@Override
public int size() {
return delegateRA.size();
}
@Override
public int dimension() {
return delegateRA.dimension();
}
@Override
public float[] vectorValue(int targetOrd) throws IOException {
return delegateRA.vectorValue(ordMap[targetOrd]);
}
@Override
public BytesRef binaryValue(int targetOrd) {
throw new UnsupportedOperationException();
}
};
}
}
}

VectorValuesConsumer.java (new file)

@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;
import java.io.IOException;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
/**
* Streams vector values for indexing to the given codec's vectors writer. The codec's vectors
* writer is responsible for buffering and processing vectors.
*/
class VectorValuesConsumer {
private final Codec codec;
private final Directory directory;
private final SegmentInfo segmentInfo;
private final InfoStream infoStream;
private Accountable accountable = Accountable.NULL_ACCOUNTABLE;
private KnnVectorsWriter writer;
VectorValuesConsumer(
Codec codec, Directory directory, SegmentInfo segmentInfo, InfoStream infoStream) {
this.codec = codec;
this.directory = directory;
this.segmentInfo = segmentInfo;
this.infoStream = infoStream;
}
private void initKnnVectorsWriter(String fieldName) throws IOException {
if (writer == null) {
KnnVectorsFormat fmt = codec.knnVectorsFormat();
if (fmt == null) {
throw new IllegalStateException(
"field=\""
+ fieldName
+ "\" was indexed as vectors but codec does not support vectors");
}
SegmentWriteState initialWriteState =
new SegmentWriteState(infoStream, directory, segmentInfo, null, null, IOContext.DEFAULT);
writer = fmt.fieldsWriter(initialWriteState);
accountable = writer;
}
}
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
initKnnVectorsWriter(fieldInfo.name);
return writer.addField(fieldInfo);
}
void flush(SegmentWriteState state, Sorter.DocMap sortMap) throws IOException {
if (writer == null) return;
try {
writer.flush(state.segmentInfo.maxDoc(), sortMap);
writer.finish();
} finally {
IOUtils.close(writer);
}
}
void abort() {
IOUtils.closeWhileHandlingException(writer);
}
public Accountable getAccountable() {
return accountable;
}
}

VectorValuesWriter.java (deleted file)

@ -1,348 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Counter;
import org.apache.lucene.util.RamUsageEstimator;
/**
* Buffers up pending vector value(s) per doc, then flushes when segment flushes.
*
* @lucene.experimental
*/
class VectorValuesWriter {
private final FieldInfo fieldInfo;
private final Counter iwBytesUsed;
private final List<float[]> vectors = new ArrayList<>();
private final DocsWithFieldSet docsWithField;
private int lastDocID = -1;
private long bytesUsed;
VectorValuesWriter(FieldInfo fieldInfo, Counter iwBytesUsed) {
this.fieldInfo = fieldInfo;
this.iwBytesUsed = iwBytesUsed;
this.docsWithField = new DocsWithFieldSet();
this.bytesUsed = docsWithField.ramBytesUsed();
if (iwBytesUsed != null) {
iwBytesUsed.addAndGet(bytesUsed);
}
}
/**
* Adds a value for the given document. Only a single value may be added.
*
* @param docID the value is added to this document
* @param vectorValue the value to add
* @throws IllegalArgumentException if a value has already been added to the given document
*/
public void addValue(int docID, float[] vectorValue) {
if (docID == lastDocID) {
throw new IllegalArgumentException(
"VectorValuesField \""
+ fieldInfo.name
+ "\" appears more than once in this document (only one value is allowed per field)");
}
if (vectorValue.length != fieldInfo.getVectorDimension()) {
throw new IllegalArgumentException(
"Attempt to index a vector of dimension "
+ vectorValue.length
+ " but \""
+ fieldInfo.name
+ "\" has dimension "
+ fieldInfo.getVectorDimension());
}
assert docID > lastDocID;
docsWithField.add(docID);
vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length));
updateBytesUsed();
lastDocID = docID;
}
private void updateBytesUsed() {
final long newBytesUsed =
docsWithField.ramBytesUsed()
+ vectors.size()
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ vectors.size() * vectors.get(0).length * Float.BYTES;
if (iwBytesUsed != null) {
iwBytesUsed.addAndGet(newBytesUsed - bytesUsed);
}
bytesUsed = newBytesUsed;
}
/**
* Flush this field's values to storage, sorting the values in accordance with sortMap
*
* @param sortMap specifies the order of documents being flushed, or null if they are to be
* flushed in docid order
* @param knnVectorsWriter the Codec's vector writer that handles the actual encoding and I/O
* @throws IOException if there is an error writing the field and its values
*/
public void flush(Sorter.DocMap sortMap, KnnVectorsWriter knnVectorsWriter) throws IOException {
KnnVectorsReader knnVectorsReader =
new KnnVectorsReader() {
@Override
public long ramBytesUsed() {
return 0;
}
@Override
public void close() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public void checkIntegrity() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public VectorValues getVectorValues(String field) throws IOException {
VectorValues vectorValues =
new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension());
return sortMap != null ? new SortingVectorValues(vectorValues, sortMap) : vectorValues;
}
@Override
public TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
throw new UnsupportedOperationException();
}
};
knnVectorsWriter.writeField(fieldInfo, knnVectorsReader);
}
static class SortingVectorValues extends VectorValues
implements RandomAccessVectorValuesProducer {
private final VectorValues delegate;
private final RandomAccessVectorValues randomAccess;
private final int[] docIdOffsets;
private final int[] ordMap;
private int docId = -1;
SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap) throws IOException {
this.delegate = delegate;
randomAccess = ((RandomAccessVectorValuesProducer) delegate).randomAccess();
docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
int docID;
while ((docID = delegate.nextDoc()) != NO_MORE_DOCS) {
int newDocID = sortMap.oldToNew(docID);
docIdOffsets[newDocID] = offset++;
}
// set up ordMap to map from new dense ordinal to old dense ordinal
ordMap = new int[offset - 1];
int ord = 0;
for (int docIdOffset : docIdOffsets) {
if (docIdOffset != 0) {
ordMap[ord++] = docIdOffset - 1;
}
}
assert ord == ordMap.length;
}
@Override
public int docID() {
return docId;
}
@Override
public int nextDoc() throws IOException {
while (docId < docIdOffsets.length - 1) {
++docId;
if (docIdOffsets[docId] != 0) {
return docId;
}
}
docId = NO_MORE_DOCS;
return docId;
}
@Override
public BytesRef binaryValue() throws IOException {
return randomAccess.binaryValue(docIdOffsets[docId] - 1);
}
@Override
public float[] vectorValue() throws IOException {
return randomAccess.vectorValue(docIdOffsets[docId] - 1);
}
@Override
public int dimension() {
return delegate.dimension();
}
@Override
public int size() {
return delegate.size();
}
@Override
public int advance(int target) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public long cost() {
return size();
}
@Override
public RandomAccessVectorValues randomAccess() throws IOException {
// Must make a new delegate randomAccess so that we have our own distinct float[]
final RandomAccessVectorValues delegateRA =
((RandomAccessVectorValuesProducer) SortingVectorValues.this.delegate).randomAccess();
return new RandomAccessVectorValues() {
@Override
public int size() {
return delegateRA.size();
}
@Override
public int dimension() {
return delegateRA.dimension();
}
@Override
public float[] vectorValue(int targetOrd) throws IOException {
return delegateRA.vectorValue(ordMap[targetOrd]);
}
@Override
public BytesRef binaryValue(int targetOrd) {
throw new UnsupportedOperationException();
}
};
}
}
private static class BufferedVectorValues extends VectorValues
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
final DocsWithFieldSet docsWithField;
// These are always the vectors of a VectorValuesWriter, which are copied when added to it
final List<float[]> vectors;
final int dimension;
final ByteBuffer buffer;
final BytesRef binaryValue;
final ByteBuffer raBuffer;
final BytesRef raBinaryValue;
DocIdSetIterator docsWithFieldIter;
int ord = -1;
BufferedVectorValues(DocsWithFieldSet docsWithField, List<float[]> vectors, int dimension) {
this.docsWithField = docsWithField;
this.vectors = vectors;
this.dimension = dimension;
buffer = ByteBuffer.allocate(dimension * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
binaryValue = new BytesRef(buffer.array());
raBuffer = ByteBuffer.allocate(dimension * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
raBinaryValue = new BytesRef(raBuffer.array());
docsWithFieldIter = docsWithField.iterator();
}
@Override
public RandomAccessVectorValues randomAccess() {
return new BufferedVectorValues(docsWithField, vectors, dimension);
}
@Override
public int dimension() {
return dimension;
}
@Override
public int size() {
return vectors.size();
}
@Override
public BytesRef binaryValue() {
buffer.asFloatBuffer().put(vectorValue());
return binaryValue;
}
@Override
public BytesRef binaryValue(int targetOrd) {
raBuffer.asFloatBuffer().put(vectors.get(targetOrd));
return raBinaryValue;
}
@Override
public float[] vectorValue() {
return vectors.get(ord);
}
@Override
public float[] vectorValue(int targetOrd) {
return vectors.get(targetOrd);
}
@Override
public int docID() {
return docsWithFieldIter.docID();
}
@Override
public int nextDoc() throws IOException {
int docID = docsWithFieldIter.nextDoc();
if (docID != NO_MORE_DOCS) {
++ord;
}
return docID;
}
@Override
public int advance(int target) {
throw new UnsupportedOperationException();
}
@Override
public long cost() {
return docsWithFieldIter.cost();
}
}
}

View File

@ -137,8 +137,12 @@ public final class HnswGraphBuilder {
this.infoStream = infoStream;
}
public OnHeapHnswGraph getGraph() {
return hnsw;
}
/** Inserts a doc with vector value to the graph */
void addGraphNode(int node, float[] value) throws IOException {
public void addGraphNode(int node, float[] value) throws IOException {
NeighborQueue candidates;
final int nodeLevel = getRandomGraphLevel(ml, random);
int curMaxLevel = hnsw.numLevels() - 1;
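With addGraphNode now public and getGraph exposed, the graph can be grown one vector at a time as documents arrive rather than in a single pass at flush. A rough sketch, under the assumption that an HnswGraphBuilder has already been constructed for the field (its constructor arguments are not shown here):

// Sketch: incremental insertion during indexing; builder construction is assumed.
void onVectorIndexed(HnswGraphBuilder builder, int node, float[] vector) throws IOException {
  builder.addGraphNode(node, vector);  // graph work happens while the document is indexed
}

OnHeapHnswGraph graphForFlush(HnswGraphBuilder builder) {
  return builder.getGraph();           // graph built so far, serialized at segment flush
}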

View File

@ -24,6 +24,7 @@ import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.SparseFixedBitSet;
/**
@ -38,7 +39,7 @@ public final class HnswGraphSearcher {
*/
private final NeighborQueue candidates;
private final BitSet visited;
private BitSet visited;
/**
* Creates a new graph searcher.
@ -140,7 +141,7 @@ public final class HnswGraphSearcher {
throws IOException {
int size = graph.size();
NeighborQueue results = new NeighborQueue(topK, false);
clearScratchState();
prepareScratchState(vectors.size());
int numVisited = 0;
for (int ep : eps) {
@ -203,8 +204,11 @@ public final class HnswGraphSearcher {
return results;
}
private void clearScratchState() {
private void prepareScratchState(int capacity) {
candidates.clear();
if (visited.length() < capacity) {
visited = FixedBitSet.ensureCapacity((FixedBitSet) visited, capacity);
}
visited.clear(0, visited.length());
}
}
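Because the graph keeps growing while it is searched during construction, the scratch visited set can no longer be sized once up front; prepareScratchState grows it on demand. The idiom in isolation, with illustrative capacities:

// Standalone illustration of the grow-on-demand idiom above; sizes are made up.
FixedBitSet visited = new FixedBitSet(16);
int capacity = 1024;                          // e.g. vectors.size(), which keeps growing
if (visited.length() < capacity) {
  visited = FixedBitSet.ensureCapacity(visited, capacity); // reallocate only when needed
}
visited.clear(0, visited.length());           // reuse the same scratch bits across searches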

View File

@ -22,13 +22,15 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.RamUsageEstimator;
/**
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
* construct the HNSW graph before it's written to the index.
*/
public final class OnHeapHnswGraph extends HnswGraph {
public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
private int numLevels; // the current number of levels in the graph
private int entryNode; // the current graph entry node on the top level
@ -167,4 +169,28 @@ public final class OnHeapHnswGraph extends HnswGraph {
return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
}
}
@Override
public long ramBytesUsed() {
long neighborArrayBytes0 =
nsize0 * (Integer.BYTES + Float.BYTES)
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF;
long neighborArrayBytes =
nsize * (Integer.BYTES + Float.BYTES)
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF;
long total = 0;
for (int l = 0; l < numLevels; l++) {
int numNodesOnLevel = graph.get(l).size();
if (l == 0) {
total += numNodesOnLevel * neighborArrayBytes0; // for graph;
} else {
total += numNodesOnLevel * Integer.BYTES; // for nodesByLevel
total += numNodesOnLevel * neighborArrayBytes; // for graph;
}
}
return total;
}
}
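Implementing Accountable here lets indexing-time RAM accounting include the graph itself, so flush decisions reflect graph growth as well as buffered vectors. A hedged sketch of how a caller might combine the two estimates; the method and its arguments are illustrative, not Lucene internals:

// Sketch only: buffered-vector bytes plus the graph's own estimate.
long vectorFieldBytesUsed(List<float[]> bufferedVectors, OnHeapHnswGraph graph) {
  long vectorBytes = 0;
  for (float[] v : bufferedVectors) {
    vectorBytes += RamUsageEstimator.NUM_BYTES_OBJECT_REF
        + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
        + (long) v.length * Float.BYTES;
  }
  return vectorBytes + graph.ramBytesUsed();  // graph memory now counts toward flush triggers
}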

View File

@ -25,6 +25,7 @@ import java.util.HashSet;
import java.util.Random;
import java.util.Set;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
@ -41,6 +42,7 @@ import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
@ -174,19 +176,22 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
KnnVectorsWriter writer = delegate.fieldsWriter(state);
return new KnnVectorsWriter() {
@Override
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
throws IOException {
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
fieldsWritten.add(fieldInfo.name);
writer.writeField(fieldInfo, knnVectorsReader);
return writer.addField(fieldInfo);
}
@Override
public void merge(MergeState mergeState) throws IOException {
for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
fieldsWritten.add(fieldInfo.name);
}
writer.merge(mergeState);
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
writer.flush(maxDoc, sortMap);
}
@Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
fieldsWritten.add(fieldInfo.name);
writer.mergeOneField(fieldInfo, mergeState);
}
@Override
@ -198,6 +203,11 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
public void close() throws IOException {
writer.close();
}
@Override
public long ramBytesUsed() {
return writer.ramBytesUsed();
}
};
}

View File

@ -20,8 +20,10 @@ package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import org.apache.lucene.codecs.KnnVectorsFormat;
@ -31,6 +33,7 @@ import org.apache.lucene.codecs.lucene94.Lucene94HnswVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
@ -42,6 +45,12 @@ import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.ArrayUtil;
@ -122,6 +131,94 @@ public class TestHnswGraph extends LuceneTestCase {
}
}
// test that sorted index returns the same search results as unsorted
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
int dim = random().nextInt(10) + 3;
int nDoc = random().nextInt(500) + 1;
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random());
int M = random().nextInt(10) + 5;
int beamWidth = random().nextInt(10) + 5;
VectorSimilarityFunction similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
long seed = random().nextLong();
HnswGraphBuilder.randSeed = seed;
IndexWriterConfig iwc =
new IndexWriterConfig()
.setCodec(
new Lucene94Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene94HnswVectorsFormat(M, beamWidth);
}
});
IndexWriterConfig iwc2 =
new IndexWriterConfig()
.setCodec(
new Lucene94Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene94HnswVectorsFormat(M, beamWidth);
}
})
.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.LONG)));
try (Directory dir = newDirectory();
Directory dir2 = newDirectory()) {
int indexedDoc = 0;
try (IndexWriter iw = new IndexWriter(dir, iwc);
IndexWriter iw2 = new IndexWriter(dir2, iwc2)) {
while (vectors.nextDoc() != NO_MORE_DOCS) {
while (indexedDoc < vectors.docID()) {
// increment docId in the index by adding empty documents
iw.addDocument(new Document());
indexedDoc++;
}
Document doc = new Document();
doc.add(new KnnVectorField("vector", vectors.vectorValue(), similarityFunction));
doc.add(new StoredField("id", vectors.docID()));
doc.add(new NumericDocValuesField("sortkey", random().nextLong()));
iw.addDocument(doc);
iw2.addDocument(doc);
indexedDoc++;
}
}
try (IndexReader reader = DirectoryReader.open(dir);
IndexReader reader2 = DirectoryReader.open(dir2)) {
IndexSearcher searcher = new IndexSearcher(reader);
IndexSearcher searcher2 = new IndexSearcher(reader2);
for (int i = 0; i < 10; i++) {
// ask to explore a lot of candidates to ensure the same returned hits,
// as graphs of 2 indices are organized differently
KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(random(), dim), 50);
List<String> ids1 = new ArrayList<>();
List<Integer> docs1 = new ArrayList<>();
List<String> ids2 = new ArrayList<>();
List<Integer> docs2 = new ArrayList<>();
TopDocs topDocs = searcher.search(query, 5);
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
Document doc = reader.document(scoreDoc.doc, Set.of("id"));
ids1.add(doc.get("id"));
docs1.add(scoreDoc.doc);
}
TopDocs topDocs2 = searcher2.search(query, 5);
for (ScoreDoc scoreDoc : topDocs2.scoreDocs) {
Document doc = reader2.document(scoreDoc.doc, Set.of("id"));
ids2.add(doc.get("id"));
docs2.add(scoreDoc.doc);
}
assertEquals(ids1, ids2);
// doc IDs are not equal, as documents are organized differently in the second (sorted) index
assertNotEquals(docs1, docs2);
}
}
}
}
private void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException {
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
assertEquals("the number of nodes in the graphs are different!", g.size(), h.size());

View File

@ -18,6 +18,7 @@
package org.apache.lucene.tests.codecs.asserting;
import java.io.IOException;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
@ -26,6 +27,7 @@ import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.tests.util.TestUtil;
@ -59,20 +61,20 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
}
@Override
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
throws IOException {
assert fieldInfo != null;
assert knnVectorsReader != null;
// assert that knnVectorsReader#getVectorValues returns different instances upon repeated
// calls
assert knnVectorsReader.getVectorValues(fieldInfo.name)
!= knnVectorsReader.getVectorValues(fieldInfo.name);
delegate.writeField(fieldInfo, knnVectorsReader);
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
return delegate.addField(fieldInfo);
}
@Override
public void merge(MergeState mergeState) throws IOException {
delegate.merge(mergeState);
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
delegate.flush(maxDoc, sortMap);
}
@Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
assert fieldInfo != null;
assert mergeState != null;
delegate.mergeOneField(fieldInfo, mergeState);
}
@Override
@ -84,6 +86,11 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
public void close() throws IOException {
delegate.close();
}
@Override
public long ramBytesUsed() {
return delegate.ramBytesUsed();
}
}
static class AssertingKnnVectorsReader extends KnnVectorsReader {