From b36b4af22bb76dc42b466b818b417bcbc0deb006 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Fri, 13 Nov 2020 08:53:51 -0500 Subject: [PATCH] LUCENE-9004: KNN vector search using NSW graphs (#2022) --- .../apache/lucene/codecs/VectorWriter.java | 1 - .../codecs/lucene90/Lucene90VectorFormat.java | 6 +- .../codecs/lucene90/Lucene90VectorReader.java | 212 ++++++-- .../codecs/lucene90/Lucene90VectorWriter.java | 92 +++- .../lucene/index/IndexableFieldType.java | 30 +- .../apache/lucene/index/KnnGraphValues.java | 58 ++ .../lucene/index/SlowCodecReaderWrapper.java | 2 +- .../org/apache/lucene/index/VectorValues.java | 51 +- .../lucene/index/VectorValuesWriter.java | 25 +- .../lucene/util/hnsw/BoundsChecker.java | 74 +++ .../apache/lucene/util/hnsw/HnswGraph.java | 223 ++++++++ .../lucene/util/hnsw/HnswGraphBuilder.java | 188 +++++++ .../org/apache/lucene/util/hnsw/Neighbor.java | 70 +++ .../apache/lucene/util/hnsw/Neighbors.java | 93 ++++ .../apache/lucene/util/hnsw/package-info.java | 22 + .../org/apache/lucene/index/TestKnnGraph.java | 352 +++++++++++++ .../apache/lucene/index/TestVectorValues.java | 45 ++ .../lucene/util/hnsw/KnnGraphTester.java | 494 ++++++++++++++++++ .../org/apache/lucene/util/hnsw/TestHnsw.java | 459 ++++++++++++++++ .../queries/intervals/IntervalQuery.java | 2 +- 20 files changed, 2400 insertions(+), 99 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java create mode 100644 lucene/core/src/java/org/apache/lucene/util/hnsw/BoundsChecker.java create mode 100644 lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java create mode 100644 lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java create mode 100644 lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbor.java create mode 100644 lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbors.java create mode 100644 lucene/core/src/java/org/apache/lucene/util/hnsw/package-info.java create mode 100644 
lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java create mode 100644 lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java create mode 100644 lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java diff --git a/lucene/core/src/java/org/apache/lucene/codecs/VectorWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/VectorWriter.java index 9b05cc67904..7b13310cff8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/VectorWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/VectorWriter.java @@ -283,7 +283,6 @@ public abstract class VectorWriter implements Closeable { public BytesRef binaryValue(int targetOrd) throws IOException { throw new UnsupportedOperationException(); } - } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorFormat.java index 632bc8154d8..5363c652f81 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorFormat.java @@ -27,15 +27,17 @@ import org.apache.lucene.index.SegmentWriteState; /** * Lucene 9.0 vector format, which encodes dense numeric vector values. - * TODO: add support for approximate KNN search. 
+ * + * @lucene.experimental */ public final class Lucene90VectorFormat extends VectorFormat { static final String META_CODEC_NAME = "Lucene90VectorFormatMeta"; static final String VECTOR_DATA_CODEC_NAME = "Lucene90VectorFormatData"; - + static final String VECTOR_INDEX_CODEC_NAME = "Lucene90VectorFormatIndex"; static final String META_EXTENSION = "vem"; static final String VECTOR_DATA_EXTENSION = "vec"; + static final String VECTOR_INDEX_EXTENSION = "vex"; static final int VERSION_START = 0; static final int VERSION_CURRENT = VERSION_START; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorReader.java index 140310859e4..674959f1f33 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorReader.java @@ -22,6 +22,7 @@ import java.nio.ByteBuffer; import java.nio.FloatBuffer; import java.util.HashMap; import java.util.Map; +import java.util.Random; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.VectorReader; @@ -29,19 +30,28 @@ import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnGraphValues; import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValuesProducer; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorValues; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IOUtils; import 
org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.hnsw.Neighbor; +import org.apache.lucene.util.hnsw.Neighbors; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; /** - * Reads vectors from the index segments. + * Reads vectors from the index segments along with index data structures supporting KNN search. * @lucene.experimental */ public final class Lucene90VectorReader extends VectorReader { @@ -49,13 +59,21 @@ public final class Lucene90VectorReader extends VectorReader { private final FieldInfos fieldInfos; private final Map fields = new HashMap<>(); private final IndexInput vectorData; - private final int maxDoc; + private final IndexInput vectorIndex; + private final long checksumSeed; Lucene90VectorReader(SegmentReadState state) throws IOException { this.fieldInfos = state.fieldInfos; - this.maxDoc = state.segmentInfo.maxDoc(); - String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.META_EXTENSION); + int versionMeta = readMetadata(state, Lucene90VectorFormat.META_EXTENSION); + long[] checksumRef = new long[1]; + vectorData = openDataInput(state, versionMeta, Lucene90VectorFormat.VECTOR_DATA_EXTENSION, Lucene90VectorFormat.VECTOR_DATA_CODEC_NAME, checksumRef); + vectorIndex = openDataInput(state, versionMeta, Lucene90VectorFormat.VECTOR_INDEX_EXTENSION, Lucene90VectorFormat.VECTOR_INDEX_CODEC_NAME, checksumRef); + checksumSeed = checksumRef[0]; + } + + private int readMetadata(SegmentReadState state, String fileExtension) throws IOException { + String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); int versionMeta = -1; try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName, state.context)) { Throwable priorE = null; @@ -73,29 +91,32 @@ public final class Lucene90VectorReader extends VectorReader { 
CodecUtil.checkFooter(meta, priorE); } } + return versionMeta; + } + private static IndexInput openDataInput(SegmentReadState state, int versionMeta, String fileExtension, String codecName, long[] checksumRef) throws IOException { boolean success = false; - String vectorDataFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.VECTOR_DATA_EXTENSION); - this.vectorData = state.directory.openInput(vectorDataFileName, state.context); + String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, state.context); try { - int versionVectorData = CodecUtil.checkIndexHeader(vectorData, - Lucene90VectorFormat.VECTOR_DATA_CODEC_NAME, + int versionVectorData = CodecUtil.checkIndexHeader(in, + codecName, Lucene90VectorFormat.VERSION_START, Lucene90VectorFormat.VERSION_CURRENT, state.segmentInfo.getId(), state.segmentSuffix); if (versionMeta != versionVectorData) { - throw new CorruptIndexException("Format versions mismatch: meta=" + versionMeta + ", vector data=" + versionVectorData, vectorData); + throw new CorruptIndexException("Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, in); } - CodecUtil.retrieveChecksum(vectorData); - + checksumRef[0] = CodecUtil.retrieveChecksum(in); success = true; } finally { if (!success) { - IOUtils.closeWhileHandlingException(this.vectorData); + IOUtils.closeWhileHandlingException(in); } } + return in; } private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { @@ -104,23 +125,28 @@ public final class Lucene90VectorReader extends VectorReader { if (info == null) { throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); } - int searchStrategyId = meta.readInt(); - if (searchStrategyId < 0 || searchStrategyId >= VectorValues.SearchStrategy.values().length) { - throw new 
CorruptIndexException("Invalid search strategy id: " + searchStrategyId, meta); - } - VectorValues.SearchStrategy searchStrategy = VectorValues.SearchStrategy.values()[searchStrategyId]; - long vectorDataOffset = meta.readVLong(); - long vectorDataLength = meta.readVLong(); - int dimension = meta.readInt(); - int size = meta.readInt(); - int[] ordToDoc = new int[size]; - for (int i = 0; i < size; i++) { - int doc = meta.readVInt(); - ordToDoc[i] = doc; - } - FieldEntry fieldEntry = new FieldEntry(dimension, searchStrategy, maxDoc, vectorDataOffset, vectorDataLength, - ordToDoc); - fields.put(info.name, fieldEntry); + fields.put(info.name, readField(meta)); + } + } + + private VectorValues.SearchStrategy readSearchStrategy(DataInput input) throws IOException { + int searchStrategyId = input.readInt(); + if (searchStrategyId < 0 || searchStrategyId >= VectorValues.SearchStrategy.values().length) { + throw new CorruptIndexException("Invalid search strategy id: " + searchStrategyId, input); + } + return VectorValues.SearchStrategy.values()[searchStrategyId]; + } + + private FieldEntry readField(DataInput input) throws IOException { + VectorValues.SearchStrategy searchStrategy = readSearchStrategy(input); + switch(searchStrategy) { + case NONE: + return new FieldEntry(input, searchStrategy); + case DOT_PRODUCT_HNSW: + case EUCLIDEAN_HNSW: + return new HnswGraphFieldEntry(input, searchStrategy); + default: + throw new CorruptIndexException("Unknown vector search strategy: " + searchStrategy, input); } } @@ -137,6 +163,7 @@ public final class Lucene90VectorReader extends VectorReader { @Override public void checkIntegrity() throws IOException { CodecUtil.checksumEntireFile(vectorData); + CodecUtil.checksumEntireFile(vectorIndex); } @Override @@ -167,29 +194,58 @@ public final class Lucene90VectorReader extends VectorReader { return new OffHeapVectorValues(fieldEntry, bytesSlice); } + public KnnGraphValues getGraphValues(String field) throws IOException { + FieldInfo info 
= fieldInfos.fieldInfo(field); + if (info == null) { + throw new IllegalArgumentException("No such field '" + field + "'"); + } + FieldEntry entry = fields.get(field); + if (entry != null && entry.indexDataLength > 0) { + return getGraphValues(entry); + } else { + return KnnGraphValues.EMPTY; + } + } + + private KnnGraphValues getGraphValues(FieldEntry entry) throws IOException { + if (entry.searchStrategy.isHnsw()) { + HnswGraphFieldEntry graphEntry = (HnswGraphFieldEntry) entry; + IndexInput bytesSlice = vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength); + return new IndexedKnnGraphReader(graphEntry, bytesSlice); + } else { + return KnnGraphValues.EMPTY; + } + } + @Override public void close() throws IOException { - vectorData.close(); + IOUtils.close(vectorData, vectorIndex); } private static class FieldEntry { final int dimension; final VectorValues.SearchStrategy searchStrategy; - final int maxDoc; final long vectorDataOffset; final long vectorDataLength; + final long indexDataOffset; + final long indexDataLength; final int[] ordToDoc; - FieldEntry(int dimension, VectorValues.SearchStrategy searchStrategy, int maxDoc, - long vectorDataOffset, long vectorDataLength, int[] ordToDoc) { - this.dimension = dimension; + FieldEntry(DataInput input, VectorValues.SearchStrategy searchStrategy) throws IOException { this.searchStrategy = searchStrategy; - this.maxDoc = maxDoc; - this.vectorDataOffset = vectorDataOffset; - this.vectorDataLength = vectorDataLength; - this.ordToDoc = ordToDoc; + vectorDataOffset = input.readVLong(); + vectorDataLength = input.readVLong(); + indexDataOffset = input.readVLong(); + indexDataLength = input.readVLong(); + dimension = input.readInt(); + int size = input.readInt(); + ordToDoc = new int[size]; + for (int i = 0; i < size; i++) { + int doc = input.readVInt(); + ordToDoc[i] = doc; + } } int size() { @@ -197,6 +253,21 @@ public final class Lucene90VectorReader extends VectorReader { } } + private static 
class HnswGraphFieldEntry extends FieldEntry { + + final long[] ordOffsets; + + HnswGraphFieldEntry(DataInput input, VectorValues.SearchStrategy searchStrategy) throws IOException { + super(input, searchStrategy); + ordOffsets = new long[size()]; + long offset = 0; + for (int i = 0; i < ordOffsets.length; i++) { + offset += input.readVLong(); + ordOffsets[i] = offset; + } + } + } + /** Read the vector values from the index input. This supports both iterated and random access. */ private final class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValuesProducer { @@ -252,11 +323,6 @@ public final class Lucene90VectorReader extends VectorReader { return binaryValue; } - @Override - public TopDocs search(float[] target, int k, int fanout) { - throw new UnsupportedOperationException(); - } - @Override public int docID() { return doc; @@ -288,6 +354,30 @@ public final class Lucene90VectorReader extends VectorReader { return new OffHeapRandomAccess(dataIn.clone()); } + @Override + public TopDocs search(float[] vector, int topK, int fanout) throws IOException { + // use a seed that is fixed for the index so we get reproducible results for the same query + final Random random = new Random(checksumSeed); + Neighbors results = HnswGraph.search(vector, topK + fanout, topK + fanout, randomAccess(), getGraphValues(fieldEntry), random); + while (results.size() > topK) { + results.pop(); + } + int i = 0; + ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), topK)]; + boolean reversed = searchStrategy().reversed; + while (results.size() > 0) { + Neighbor n = results.pop(); + float score; + if (reversed) { + score = (float) Math.exp(- n.score() / vector.length); + } else { + score = n.score(); + } + scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[n.node()], score); + } + // always return >= the case where we can assert == is only when there are fewer than topK vectors in the index + return new TopDocs(new 
TotalHits(results.visitedCount(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), scoreDocs); + } class OffHeapRandomAccess implements RandomAccessVectorValues { @@ -296,12 +386,10 @@ public final class Lucene90VectorReader extends VectorReader { final BytesRef binaryValue; final ByteBuffer byteBuffer; final FloatBuffer floatBuffer; - final int byteSize; final float[] value; OffHeapRandomAccess(IndexInput dataIn) { this.dataIn = dataIn; - byteSize = Float.BYTES * dimension(); byteBuffer = ByteBuffer.allocate(byteSize); floatBuffer = byteBuffer.asFloatBuffer(); value = new float[dimension()]; @@ -342,7 +430,41 @@ public final class Lucene90VectorReader extends VectorReader { dataIn.seek(offset); dataIn.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize); } + } + } + /** Read the nearest-neighbors graph from the index input */ + private final class IndexedKnnGraphReader extends KnnGraphValues { + + final HnswGraphFieldEntry entry; + final IndexInput dataIn; + + int arcCount; + int arcUpTo; + int arc; + + IndexedKnnGraphReader(HnswGraphFieldEntry entry, IndexInput dataIn) { + this.entry = entry; + this.dataIn = dataIn; + } + + @Override + public void seek(int targetOrd) throws IOException { + // unsafe; no bounds checking + dataIn.seek(entry.ordOffsets[targetOrd]); + arcCount = dataIn.readInt(); + arc = -1; + arcUpTo = 0; + } + + @Override + public int nextNeighbor() throws IOException { + if (arcUpTo >= arcCount) { + return NO_MORE_DOCS; + } + ++arcUpTo; + arc += dataIn.readVInt(); + return arc; } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorWriter.java index e64e061f464..71d103b4b7f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorWriter.java @@ -18,18 +18,20 @@ package org.apache.lucene.codecs.lucene90; import 
java.io.IOException; -import java.util.ArrayList; -import java.util.List; +import java.util.Arrays; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.VectorWriter; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.RandomAccessVectorValuesProducer; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorValues; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.hnsw.HnswGraphBuilder; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; @@ -39,7 +41,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; */ public final class Lucene90VectorWriter extends VectorWriter { - private final IndexOutput meta, vectorData; + private final IndexOutput meta, vectorData, vectorIndex; private boolean finished; @@ -52,6 +54,9 @@ public final class Lucene90VectorWriter extends VectorWriter { String vectorDataFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.VECTOR_DATA_EXTENSION); vectorData = state.directory.createOutput(vectorDataFileName, state.context); + String indexDataFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.VECTOR_INDEX_EXTENSION); + vectorIndex = state.directory.createOutput(indexDataFileName, state.context); + try { CodecUtil.writeIndexHeader(meta, Lucene90VectorFormat.META_CODEC_NAME, @@ -61,6 +66,10 @@ public final class Lucene90VectorWriter extends VectorWriter { Lucene90VectorFormat.VECTOR_DATA_CODEC_NAME, Lucene90VectorFormat.VERSION_CURRENT, state.segmentInfo.getId(), state.segmentSuffix); + CodecUtil.writeIndexHeader(vectorIndex, + Lucene90VectorFormat.VECTOR_INDEX_CODEC_NAME, + Lucene90VectorFormat.VERSION_CURRENT, + 
state.segmentInfo.getId(), state.segmentSuffix); } catch (IOException e) { IOUtils.closeWhileHandlingException(this); } @@ -69,17 +78,47 @@ public final class Lucene90VectorWriter extends VectorWriter { @Override public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOException { long vectorDataOffset = vectorData.getFilePointer(); + // TODO - use a better data structure; a bitset? DocsWithFieldSet is p.p. in o.a.l.index - List docIds = new ArrayList<>(); - int docV, ord = 0; - for (docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc(), ord++) { + int[] docIds = new int[vectors.size()]; + int count = 0; + for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc(), count++) { + // write vector writeVectorValue(vectors); - docIds.add(docV); - // TODO: write knn graph value + docIds[count] = docV; } + // count may be < vectors.size() e,g, if some documents were deleted + long[] offsets = new long[count]; long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + long vectorIndexOffset = vectorIndex.getFilePointer(); + if (vectors.searchStrategy().isHnsw()) { + if (vectors instanceof RandomAccessVectorValuesProducer) { + writeGraph(vectorIndex, (RandomAccessVectorValuesProducer) vectors, vectorIndexOffset, offsets, count); + } else { + throw new IllegalArgumentException("Indexing an HNSW graph requires a random access vector values, got " + vectors); + } + } + long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset; if (vectorDataLength > 0) { - writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, docIds); + writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, vectorIndexOffset, vectorIndexLength, count, docIds); + if (vectors.searchStrategy().isHnsw()) { + writeGraphOffsets(meta, offsets); + } + } + } + + private void writeMeta(FieldInfo field, long vectorDataOffset, long vectorDataLength, long indexDataOffset, long indexDataLength, int size, int[] docIds) throws 
IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorSearchStrategy().ordinal()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + meta.writeVLong(indexDataOffset); + meta.writeVLong(indexDataLength); + meta.writeInt(field.getVectorDimension()); + meta.writeInt(size); + for (int i = 0; i < size; i ++) { + // TODO: delta-encode, or write as bitset + meta.writeVInt(docIds[i]); } } @@ -90,16 +129,28 @@ public final class Lucene90VectorWriter extends VectorWriter { vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length); } - private void writeMeta(FieldInfo field, long vectorDataOffset, long vectorDataLength, List docIds) throws IOException { - meta.writeInt(field.number); - meta.writeInt(field.getVectorSearchStrategy().ordinal()); - meta.writeVLong(vectorDataOffset); - meta.writeVLong(vectorDataLength); - meta.writeInt(field.getVectorDimension()); - meta.writeInt(docIds.size()); - for (Integer docId : docIds) { - // TODO: delta-encode, or write as bitset - meta.writeVInt(docId); + private void writeGraphOffsets(IndexOutput out, long[] offsets) throws IOException { + long last = 0; + for (long offset : offsets) { + out.writeVLong(offset - last); + last = offset; + } + } + + private void writeGraph(IndexOutput graphData, RandomAccessVectorValuesProducer vectorValues, long graphDataOffset, long[] offsets, int count) throws IOException { + HnswGraph graph = HnswGraphBuilder.build(vectorValues); + for (int ord = 0; ord < count; ord++) { + // write graph + offsets[ord] = graphData.getFilePointer() - graphDataOffset; + int[] arcs = graph.getNeighbors(ord); + Arrays.sort(arcs); + graphData.writeInt(arcs.length); + int lastArc = -1; // to make the assertion work? 
+ for (int arc : arcs) { + assert arc > lastArc : "arcs out of order: " + lastArc + "," + arc; + graphData.writeVInt(arc - lastArc); + lastArc = arc; + } } } @@ -117,11 +168,12 @@ public final class Lucene90VectorWriter extends VectorWriter { } if (vectorData != null) { CodecUtil.writeFooter(vectorData); + CodecUtil.writeFooter(vectorIndex); } } @Override public void close() throws IOException { - IOUtils.close(meta, vectorData); + IOUtils.close(meta, vectorData, vectorIndex); } } diff --git a/lucene/core/src/java/org/apache/lucene/index/IndexableFieldType.java b/lucene/core/src/java/org/apache/lucene/index/IndexableFieldType.java index ce19472caca..38c1c2ae858 100644 --- a/lucene/core/src/java/org/apache/lucene/index/IndexableFieldType.java +++ b/lucene/core/src/java/org/apache/lucene/index/IndexableFieldType.java @@ -28,7 +28,7 @@ import org.apache.lucene.analysis.Analyzer; // javadocs public interface IndexableFieldType { /** True if the field's value should be stored */ - public boolean stored(); + boolean stored(); /** * True if this field's value should be analyzed by the @@ -39,7 +39,7 @@ public interface IndexableFieldType { */ // TODO: shouldn't we remove this? Whether/how a field is // tokenized is an impl detail under Field? - public boolean tokenized(); + boolean tokenized(); /** * True if this field's indexed form should be also stored @@ -52,7 +52,7 @@ public interface IndexableFieldType { * This option is illegal if {@link #indexOptions()} returns * IndexOptions.NONE. 
*/ - public boolean storeTermVectors(); + boolean storeTermVectors(); /** * True if this field's token character offsets should also @@ -61,7 +61,7 @@ public interface IndexableFieldType { * This option is illegal if term vectors are not enabled for the field * ({@link #storeTermVectors()} is false) */ - public boolean storeTermVectorOffsets(); + boolean storeTermVectorOffsets(); /** * True if this field's token positions should also be stored @@ -70,7 +70,7 @@ public interface IndexableFieldType { * This option is illegal if term vectors are not enabled for the field * ({@link #storeTermVectors()} is false). */ - public boolean storeTermVectorPositions(); + boolean storeTermVectorPositions(); /** * True if this field's token payloads should also be stored @@ -79,7 +79,7 @@ public interface IndexableFieldType { * This option is illegal if term vector positions are not enabled * for the field ({@link #storeTermVectors()} is false). */ - public boolean storeTermVectorPayloads(); + boolean storeTermVectorPayloads(); /** * True if normalization values should be omitted for the field. @@ -87,42 +87,42 @@ public interface IndexableFieldType { * This saves memory, but at the expense of scoring quality (length normalization * will be disabled), and if you omit norms, you cannot use index-time boosts. */ - public boolean omitNorms(); + boolean omitNorms(); /** {@link IndexOptions}, describing what should be * recorded into the inverted index */ - public IndexOptions indexOptions(); + IndexOptions indexOptions(); /** * DocValues {@link DocValuesType}: how the field's value will be indexed * into docValues. */ - public DocValuesType docValuesType(); + DocValuesType docValuesType(); /** * If this is positive (representing the number of point dimensions), the field is indexed as a point. 
*/ - public int pointDimensionCount(); + int pointDimensionCount(); /** * The number of dimensions used for the index key */ - public int pointIndexDimensionCount(); + int pointIndexDimensionCount(); /** * The number of bytes in each dimension's values. */ - public int pointNumBytes(); + int pointNumBytes(); /** * The number of dimensions of the field's vector value */ - public int vectorDimension(); + int vectorDimension(); /** * The {@link VectorValues.SearchStrategy} of the field's vector value */ - public VectorValues.SearchStrategy vectorSearchStrategy(); + VectorValues.SearchStrategy vectorSearchStrategy(); /** * Attributes for the field type. @@ -132,5 +132,5 @@ public interface IndexableFieldType { * * @return Map */ - public Map getAttributes(); + Map getAttributes(); } diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java new file mode 100644 index 00000000000..d3ee0dc9027 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.index; + +import java.io.IOException; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * Access to per-document neighbor lists in a (hierarchical) knn search graph. + * @lucene.experimental + */ +public abstract class KnnGraphValues { + + /** Sole constructor */ + protected KnnGraphValues() {} + + /** Move the pointer to exactly {@code target}, the id of a node in the graph. + * After this method returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals. + * @param target must be a valid node in the graph, ie. ≥ 0 and < {@link VectorValues#size()}. + */ + public abstract void seek(int target) throws IOException; + + /** + * Iterates over the neighbor list. It is illegal to call this method after it returns + * NO_MORE_DOCS without calling {@link #seek(int)}, which resets the iterator. + * @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete. + */ + public abstract int nextNeighbor() throws IOException; + + /** Empty graph value */ + public static KnnGraphValues EMPTY = new KnnGraphValues() { + + @Override + public int nextNeighbor() { + return NO_MORE_DOCS; + } + + @Override + public void seek(int target) { + } + }; +} diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java index b2ce9aa4d80..2cb8ee3c468 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java @@ -190,7 +190,7 @@ public final class SlowCodecReaderWrapper { } }; } - + private static NormsProducer readerToNormsProducer(final LeafReader reader) { return new NormsProducer() { diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorValues.java b/lucene/core/src/java/org/apache/lucene/index/VectorValues.java index bd3b69c6881..7ede15bd182 100644 --- 
a/lucene/core/src/java/org/apache/lucene/index/VectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/VectorValues.java @@ -23,6 +23,9 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.TopDocs; import org.apache.lucene.util.BytesRef; +import static org.apache.lucene.util.VectorUtil.dotProduct; +import static org.apache.lucene.util.VectorUtil.squareDistance; + /** * This class provides access to per-document floating point vector values indexed as {@link * org.apache.lucene.document.VectorField}. @@ -91,15 +94,59 @@ public abstract class VectorValues extends DocIdSetIterator { * determine the nearest neighbors. */ public enum SearchStrategy { + /** No search strategy is provided. Note: {@link VectorValues#search(float[], int, int)} * is not supported for fields specifying this strategy. */ NONE, /** HNSW graph built using Euclidean distance */ - EUCLIDEAN_HNSW, + EUCLIDEAN_HNSW(true), /** HNSW graph buit using dot product */ - DOT_PRODUCT_HNSW + DOT_PRODUCT_HNSW; + + /** If true, the scores associated with vector comparisons in this strategy are in reverse order; that is, + * lower scores represent more similar vectors. Otherwise, if false, higher scores represent more similar vectors. + */ + public final boolean reversed; + + SearchStrategy(boolean reversed) { + this.reversed = reversed; + } + + SearchStrategy() { + reversed = false; + } + + /** + * Calculates a similarity score between the two vectors with a specified function. 
+ * @param v1 a vector + * @param v2 another vector, of the same dimension + * @return the value of the strategy's score function applied to the two vectors + */ + public float compare(float[] v1, float[] v2) { + switch (this) { + case EUCLIDEAN_HNSW: + return squareDistance(v1, v2); + case DOT_PRODUCT_HNSW: + return dotProduct(v1, v2); + default: + throw new IllegalStateException("Incomparable search strategy: " + this); + } + } + + /** + * Return true if vectors indexed using this strategy will be indexed using an HNSW graph + */ + public boolean isHnsw() { + switch (this) { + case EUCLIDEAN_HNSW: + case DOT_PRODUCT_HNSW: + return true; + default: + return false; + } + } } /** diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java b/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java index ee6f07449cf..aa78d071def 100644 --- a/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java +++ b/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java @@ -176,47 +176,48 @@ class VectorValuesWriter { throw new UnsupportedOperationException(); } - @Override - public long cost() { - return size(); - } - - @Override public TopDocs search(float[] target, int k, int fanout) { throw new UnsupportedOperationException(); } + @Override + public long cost() { + return size(); + } + @Override public RandomAccessVectorValues randomAccess() { + // Must make a new delegate randomAccess so that we have our own distinct float[] + final RandomAccessVectorValues delegateRA = ((RandomAccessVectorValuesProducer) SortingVectorValues.this.delegate).randomAccess(); + return new RandomAccessVectorValues() { @Override public int size() { - return delegate.size(); + return delegateRA.size(); } @Override public int dimension() { - return delegate.dimension(); + return delegateRA.dimension(); } @Override public SearchStrategy searchStrategy() { - return delegate.searchStrategy(); + return delegateRA.searchStrategy(); } @Override 
public float[] vectorValue(int targetOrd) throws IOException { - return randomAccess.vectorValue(ordMap[targetOrd]); + return delegateRA.vectorValue(ordMap[targetOrd]); } @Override public BytesRef binaryValue(int targetOrd) { throw new UnsupportedOperationException(); } - }; } } @@ -252,7 +253,7 @@ class VectorValuesWriter { @Override public RandomAccessVectorValues randomAccess() { - return this; + return new BufferedVectorValues(docsWithField, vectors, dimension, searchStrategy); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/BoundsChecker.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/BoundsChecker.java new file mode 100644 index 00000000000..e02cc40409f --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/BoundsChecker.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.util.hnsw; + +abstract class BoundsChecker { + + float bound; + + /** + * Update the bound if sample is better + */ + abstract void update(float sample); + + /** + * Return whether the sample exceeds (is worse than) the bound + */ + abstract boolean check(float sample); + + static BoundsChecker create(boolean reversed) { + if (reversed) { + return new Min(); + } else { + return new Max(); + } + } + + static class Max extends BoundsChecker { + Max() { + bound = Float.NEGATIVE_INFINITY; + } + + void update(float sample) { + if (sample > bound) { + bound = sample; + } + } + + boolean check(float sample) { + return sample < bound; + } + } + + static class Min extends BoundsChecker { + + Min() { + bound = Float.POSITIVE_INFINITY; + } + + void update(float sample) { + if (sample < bound) { + bound = sample; + } + } + + boolean check(float sample) { + return sample > bound; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java new file mode 100644 index 00000000000..ed7be7d1b5a --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.TreeSet; + +import org.apache.lucene.index.KnnGraphValues; +import org.apache.lucene.index.RandomAccessVectorValues; +import org.apache.lucene.index.VectorValues; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * Navigable Small-world graph. Provides efficient approximate nearest neighbor + * search for high dimensional vectors. See Approximate nearest + * neighbor algorithm based on navigable small world graphs [2014] and this paper [2018] for details. + * + * The nomenclature is a bit different here from what's used in those papers: + * + *

Hyperparameters

+ *
    + *
  • numSeed is the equivalent of m in the 2014 paper; it controls the number of random entry points to sample.
  • + *
  • beamWidth in {@link HnswGraphBuilder} has the same meaning as efConst in the 2018 paper. It is the number of + * nearest neighbor candidates to track while searching the graph for each newly inserted node.
  • + *
  • maxConn has the same meaning as M in the later paper; it controls how many of the efConst neighbors are + * connected to the new node
  • + *
  • fanout the fanout parameter of {@link VectorValues#search(float[], int, int)} + * is used to control the values of numSeed and topK that are passed to this API. + * Thus fanout is like a combination of ef (search beam width) from the 2018 paper and m from the 2014 paper.
  • + *
+ * + *

Note: The graph may be searched by multiple threads concurrently, but updates are not thread-safe. Also note: there is no notion of + * deletions. Document searching built on top of this must do its own deletion-filtering.

+ */ +public final class HnswGraph { + + private final int maxConn; + private final VectorValues.SearchStrategy searchStrategy; + + // Each entry lists the top maxConn neighbors of a node. The nodes correspond to vectors added to HnswBuilder, and the + // node values are the ordinals of those vectors. + private final List graph; + + HnswGraph(int maxConn, VectorValues.SearchStrategy searchStrategy) { + graph = new ArrayList<>(); + graph.add(Neighbors.create(maxConn, searchStrategy.reversed)); + this.maxConn = maxConn; + this.searchStrategy = searchStrategy; + } + + /** + * Searches for the nearest neighbors of a query vector. + * @param query search query vector + * @param topK the number of nodes to be returned + * @param numSeed the number of random entry points to sample + * @param vectors vector values + * @param graphValues the graph values. May represent the entire graph, or a level in a hierarchical graph. + * @param random a source of randomness, used for generating entry points to the graph + * @return a priority queue holding the neighbors found + */ + public static Neighbors search(float[] query, int topK, int numSeed, RandomAccessVectorValues vectors, KnnGraphValues graphValues, + Random random) throws IOException { + VectorValues.SearchStrategy searchStrategy = vectors.searchStrategy(); + // TODO: use unbounded priority queue + TreeSet candidates; + if (searchStrategy.reversed) { + candidates = new TreeSet<>(Comparator.reverseOrder()); + } else { + candidates = new TreeSet<>(); + } + int size = vectors.size(); + for (int i = 0; i < numSeed && i < size; i++) { + int entryPoint = random.nextInt(size); + candidates.add(new Neighbor(entryPoint, searchStrategy.compare(query, vectors.vectorValue(entryPoint)))); + } + // set of ordinals that have been visited by search on this layer, used to avoid backtracking + Set visited = new HashSet<>(); + // TODO: use PriorityQueue's sentinel optimization? 
+ Neighbors results = Neighbors.create(topK, searchStrategy.reversed); + for (Neighbor c : candidates) { + visited.add(c.node()); + results.insertWithOverflow(c); + } + // Set the bound to the worst current result and below reject any newly-generated candidates failing + // to exceed this bound + BoundsChecker bound = BoundsChecker.create(searchStrategy.reversed); + bound.bound = results.top().score(); + while (candidates.size() > 0) { + // get the best candidate (closest or best scoring) + Neighbor c = candidates.pollLast(); + if (results.size() >= topK) { + if (bound.check(c.score())) { + break; + } + } + graphValues.seek(c.node()); + int friendOrd; + while ((friendOrd = graphValues.nextNeighbor()) != NO_MORE_DOCS) { + if (visited.contains(friendOrd)) { + continue; + } + visited.add(friendOrd); + float score = searchStrategy.compare(query, vectors.vectorValue(friendOrd)); + if (results.size() < topK || bound.check(score) == false) { + Neighbor n = new Neighbor(friendOrd, score); + candidates.add(n); + results.insertWithOverflow(n); + bound.bound = results.top().score(); + } + } + } + results.setVisitedCount(visited.size()); + return results; + } + + /** + * Returns the nodes connected to the given node by its outgoing neighborNodes in an unpredictable order. Each node inserted + * by HnswGraphBuilder corresponds to a vector, and the node is the vector's ordinal. + * @param node the node whose friends are returned + */ + public int[] getNeighbors(int node) { + Neighbors neighbors = graph.get(node); + int[] result = new int[neighbors.size()]; + int i = 0; + for (Neighbor n : neighbors) { + result[i++] = n.node(); + } + return result; + } + + /** Connects two nodes symmetrically, limiting the maximum number of connections from either node. 
+ * node1 must be less than node2 and must already have been inserted to the graph */ + void connectNodes(int node1, int node2, float score) { + connect(node1, node2, score); + if (node2 == graph.size()) { + addNode(); + } + connect(node2, node1, score); + } + + KnnGraphValues getGraphValues() { + return new HnswGraphValues(); + } + + /** + * Makes a connection from the node to a neighbor, dropping the worst connection when maxConn is exceeded + * @param node1 node to connect *from* + * @param node2 node to connect *to* + * @param score searchStrategy.score() of the vectors associated with the two nodes + */ + boolean connect(int node1, int node2, float score) { + //System.out.println(" HnswGraph.connect " + node1 + " -> " + node2); + assert node1 >= 0 && node2 >= 0; + Neighbors nn = graph.get(node1); + assert nn != null; + if (nn.size() == maxConn) { + Neighbor top = nn.top(); + if (score < top.score() == nn.reversed()) { + top.update(node2, score); + nn.updateTop(); + return true; + } + } else { + nn.add(new Neighbor(node2, score)); + return true; + } + return false; + } + + int addNode() { + graph.add(Neighbors.create(maxConn, searchStrategy.reversed)); + return graph.size() - 1; + } + + /** + * Present this graph as KnnGraphValues, used for searching while inserting new nodes. 
+ */ + private class HnswGraphValues extends KnnGraphValues { + + private int arcUpTo; + private int[] neighborNodes; + + @Override + public void seek(int targetNode) { + arcUpTo = 0; + neighborNodes = HnswGraph.this.getNeighbors(targetNode); + } + + @Override + public int nextNeighbor() { + if (arcUpTo >= neighborNodes.length) { + return NO_MORE_DOCS; + } + return neighborNodes[arcUpTo++]; + } + + } + +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java new file mode 100644 index 00000000000..b308a86fa91 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import java.io.IOException; +import java.util.Random; + +import org.apache.lucene.index.KnnGraphValues; +import org.apache.lucene.index.RandomAccessVectorValues; +import org.apache.lucene.index.RandomAccessVectorValuesProducer; +import org.apache.lucene.index.VectorValues; +import org.apache.lucene.util.BytesRef; + +/** + * Builder for HNSW graph. 
See {@link HnswGraph} for a gloss on the algorithm and the meaning of the hyperparameters. + */ +public final class HnswGraphBuilder { + + // default random seed for level generation + private static final long DEFAULT_RAND_SEED = System.currentTimeMillis(); + + // expose for testing. + public static long randSeed = DEFAULT_RAND_SEED; + + // These "default" hyperparameter settings are exposed (and nonfinal) to enable performance testing + // since the indexing API doesn't provide any control over them. + + // default max connections per node + public static int DEFAULT_MAX_CONN = 16; + + // default candidate list size + static int DEFAULT_BEAM_WIDTH = 16; + + private final int maxConn; + private final int beamWidth; + + private final BoundedVectorValues boundedVectors; + private final VectorValues.SearchStrategy searchStrategy; + private final HnswGraph hnsw; + private final Random random; + + /** + * Reads all the vectors from a VectorValues, builds a graph connecting them by their dense ordinals, using default + * hyperparameter settings, and returns the resulting graph. + * @param vectorValues the vectors whose relations are represented by the graph + */ + public static HnswGraph build(RandomAccessVectorValuesProducer vectorValues) throws IOException { + HnswGraphBuilder builder = new HnswGraphBuilder(vectorValues); + return builder.build(vectorValues.randomAccess()); + } + + /** + * Reads all the vectors from a VectorValues, builds a graph connecting them by their dense ordinals, using the given + * hyperparameter settings, and returns the resulting graph. + * @param vectorValues the vectors whose relations are represented by the graph + * @param maxConn the number of connections to make when adding a new graph node; roughly speaking the graph fanout. + * @param beamWidth the size of the beam search to use when finding nearest neighbors. + * @param seed the seed for a random number generator used during graph construction. 
Provide this to ensure repeatable construction. + */ + public static HnswGraph build(RandomAccessVectorValuesProducer vectorValues, int maxConn, int beamWidth, long seed) throws IOException { + HnswGraphBuilder builder = new HnswGraphBuilder(vectorValues, maxConn, beamWidth, seed); + return builder.build(vectorValues.randomAccess()); + } + + /** + * Reads all the vectors from two copies of a random access VectorValues. Providing two copies enables efficient retrieval + * without extra data copying, while avoiding collision of the returned values. + * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet accessor for the vectors + */ + HnswGraph build(RandomAccessVectorValues vectors) throws IOException { + if (vectors == boundedVectors.raDelegate) { + throw new IllegalArgumentException("Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); + } + for (int node = 1; node < vectors.size(); node++) { + insert(vectors.vectorValue(node)); + } + return hnsw; + } + + /** Construct the builder with default configurations */ + private HnswGraphBuilder(RandomAccessVectorValuesProducer vectors) { + this(vectors, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, randSeed); + } + + /** Full constructor */ + HnswGraphBuilder(RandomAccessVectorValuesProducer vectors, int maxConn, int beamWidth, long seed) { + RandomAccessVectorValues vectorValues = vectors.randomAccess(); + searchStrategy = vectorValues.searchStrategy(); + if (searchStrategy == VectorValues.SearchStrategy.NONE) { + throw new IllegalStateException("No distance function"); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + boundedVectors = new BoundedVectorValues(vectorValues); + this.hnsw = new HnswGraph(maxConn, searchStrategy); + random = new Random(seed); + } + + /** Inserts a doc with vector value to the graph */ + private void insert(float[] value) throws IOException { + addGraphNode(value); + + // add the vector value + 
boundedVectors.inc(); + } + + private void addGraphNode(float[] value) throws IOException { + KnnGraphValues graphValues = hnsw.getGraphValues(); + Neighbors candidates = HnswGraph.search(value, beamWidth, 2 * beamWidth, boundedVectors, graphValues, random); + + int node = hnsw.addNode(); + + // connect the nearest neighbors to the just inserted node + addNearestNeighbors(node, candidates); + } + + private void addNearestNeighbors(int newNode, Neighbors neighbors) { + // connect the nearest neighbors, relying on the graph's Neighbors' priority queues to drop off distant neighbors + for (Neighbor neighbor : neighbors) { + if (hnsw.connect(newNode, neighbor.node(), neighbor.score())) { + hnsw.connect(neighbor.node(), newNode, neighbor.score()); + } + } + } + + /** + * Provides a random access VectorValues view over a delegate VectorValues, bounding the maximum ord. + * TODO: get rid of this, all it does is track a counter + */ + private static class BoundedVectorValues implements RandomAccessVectorValues { + + final RandomAccessVectorValues raDelegate; + + int size; + + BoundedVectorValues(RandomAccessVectorValues delegate) { + raDelegate = delegate; + if (delegate.size() > 0) { + // we implicitly add the first node + size = 1; + } + } + + void inc() { + ++size; + } + + @Override + public int size() { + return size; + } + + @Override + public int dimension() { return raDelegate.dimension(); } + + @Override + public VectorValues.SearchStrategy searchStrategy() { + return raDelegate.searchStrategy(); + } + + @Override + public float[] vectorValue(int target) throws IOException { + return raDelegate.vectorValue(target); + } + + @Override + public BytesRef binaryValue(int targetOrd) throws IOException { + throw new UnsupportedOperationException(); + } + } + + +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbor.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbor.java new file mode 100644 index 00000000000..01cf2311e59 --- /dev/null +++ 
b/lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbor.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +/** A neighbor node in the HNSW graph; holds the node ordinal and its distance score. 
*/ +public class Neighbor implements Comparable { + + private int node; + + private float score; + + public Neighbor(int node, float score) { + this.node = node; + this.score = score; + } + + public int node() { + return node; + } + + public float score() { + return score; + } + + void update(int node, float score) { + this.node = node; + this.score = score; + } + + @Override + public int compareTo(Neighbor o) { + if (score == o.score) { + return o.node - node; + } else { + assert node != o.node : "attempt to add the same node " + node + " twice with different scores: " + score + " != " + o.score; + return Float.compare(score, o.score); + } + } + + @Override + public boolean equals(Object other) { + return other instanceof Neighbor + && ((Neighbor) other).node == node; + } + + @Override + public int hashCode() { + return 39 + 61 * node; + } + + @Override + public String toString() { + return "(" + node + ", " + score + ")"; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbors.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbors.java new file mode 100644 index 00000000000..6ca761b39e0 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbors.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import org.apache.lucene.util.PriorityQueue; + +/** Neighbors queue. */ +public abstract class Neighbors extends PriorityQueue { + + public static Neighbors create(int maxSize, boolean reversed) { + if (reversed) { + return new ReverseNeighbors(maxSize); + } else { + return new ForwardNeighbors(maxSize); + } + } + + public abstract boolean reversed(); + + // Used to track the number of neighbors visited during a single graph traversal + private int visitedCount; + + private Neighbors(int maxSize) { + super(maxSize); + } + + private static class ForwardNeighbors extends Neighbors { + ForwardNeighbors(int maxSize) { + super(maxSize); + } + + @Override + protected boolean lessThan(Neighbor a, Neighbor b) { + if (a.score() == b.score()) { + return a.node() > b.node(); + } + return a.score() < b.score(); + } + + @Override + public boolean reversed() { return false; } + } + + private static class ReverseNeighbors extends Neighbors { + ReverseNeighbors(int maxSize) { + super(maxSize); + } + + @Override + protected boolean lessThan(Neighbor a, Neighbor b) { + if (a.score() == b.score()) { + return a.node() > b.node(); + } + return b.score() < a.score(); + } + + @Override + public boolean reversed() { return true; } + } + + void setVisitedCount(int visitedCount) { + this.visitedCount = visitedCount; + } + + public int visitedCount() { + return visitedCount; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("Neighbors=["); + this.iterator().forEachRemaining(sb::append); + sb.append("]"); + return sb.toString(); + } + +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/package-info.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/package-info.java new file mode 100644 index 00000000000..ba95e66a71f --- /dev/null +++ 
b/lucene/core/src/java/org/apache/lucene/util/hnsw/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Navigable Small-World graph, nominally Hierarchical but currently only has a single + * layer. Provides efficient approximate nearest neighbor search for high dimensional vectors. + */ +package org.apache.lucene.util.hnsw; diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java new file mode 100644 index 00000000000..d7f8fc91248 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -0,0 +1,352 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.index; + + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.lucene90.Lucene90VectorReader; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.StringField; +import org.apache.lucene.document.VectorField; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.LuceneTestCase; + +import org.apache.lucene.util.hnsw.HnswGraphBuilder; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Set; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.hnsw.HnswGraphBuilder.randSeed; + +/** Tests indexing of a knn-graph */ +public class TestKnnGraph extends LuceneTestCase { + + private static final String KNN_GRAPH_FIELD = "vector"; + + @Before + public void setup() { + randSeed = random().nextLong(); + } + + /** + * Basic test of creating documents in a graph + */ + public void testBasic() throws Exception { + try (Directory dir = newDirectory(); + IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) { + int numDoc = atLeast(10); + int dimension = atLeast(3); + float[][] values = new float[numDoc][]; + for (int i = 0; i < numDoc; i++) { + if (random().nextBoolean()) { + values[i] 
= new float[dimension]; + for (int j = 0; j < dimension; j++) { + values[i][j] = random().nextFloat(); + } + } + add(iw, i, values[i]); + } + assertConsistentGraph(iw, values); + } + } + + public void testSingleDocument() throws Exception { + try (Directory dir = newDirectory(); + IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) { + float[][] values = new float[][]{new float[]{0, 1, 2}}; + add(iw, 0, values[0]); + assertConsistentGraph(iw, values); + iw.commit(); + assertConsistentGraph(iw, values); + } + } + + /** + * Verify that the graph properties are preserved when merging + */ + public void testMerge() throws Exception { + try (Directory dir = newDirectory(); + IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) { + int numDoc = atLeast(100); + int dimension = atLeast(10); + float[][] values = new float[numDoc][]; + for (int i = 0; i < numDoc; i++) { + if (random().nextBoolean()) { + values[i] = new float[dimension]; + for (int j = 0; j < dimension; j++) { + values[i][j] = random().nextFloat(); + } + } + add(iw, i, values[i]); + if (random().nextInt(10) == 3) { + //System.out.println("commit @" + i); + iw.commit(); + } + } + if (random().nextBoolean()) { + iw.forceMerge(1); + } + assertConsistentGraph(iw, values); + } + } + + // TODO: testSorted + // TODO: testDeletions + + /** + * Verify that searching does something reasonable + */ + public void testSearch() throws Exception { + try (Directory dir = newDirectory(); + // don't allow random merges; they mess up the docid tie-breaking assertion + IndexWriter iw = new IndexWriter(dir, new IndexWriterConfig().setCodec(Codec.forName("Lucene90")))) { + // Add a document for every cartesian point in an NxN square so we can + // easily know which are the nearest neighbors to every point. 
Insert by iterating + // using a prime number that is not a divisor of N*N so that we will hit each point once, + // and chosen so that points will be inserted in a deterministic + // but somewhat distributed pattern + int n = 5, stepSize = 17; + float[][] values = new float[n * n][]; + int index = 0; + for (int i = 0; i < values.length; i++) { + // System.out.printf("%d: (%d, %d)\n", i, index % n, index / n); + values[i] = new float[]{index % n, index / n}; + index = (index + stepSize) % (n * n); + add(iw, i, values[i]); + if (i == 13) { + // create 2 segments + iw.commit(); + } + } + boolean forceMerge = random().nextBoolean(); + //System.out.println(""); + if (forceMerge) { + iw.forceMerge(1); + } + assertConsistentGraph(iw, values); + try (DirectoryReader dr = DirectoryReader.open(iw)) { + // results are ordered by score (descending) and docid (ascending); + // This is the insertion order: + // column major, origin at upper left + // 0 15 5 20 10 + // 3 18 8 23 13 + // 6 21 11 1 16 + // 9 24 14 4 19 + // 12 2 17 7 22 + + // For this small graph the "search" is exhaustive, so this mostly tests the APIs, the orientation of the + // various priority queues, the scoring function, but not so much the approximate KNN search algo + assertGraphSearch(new int[]{0, 15, 3, 18, 5}, new float[]{0f, 0.1f}, dr); + // test tiebreaking by docid + assertGraphSearch(new int[]{11, 1, 8, 14, 21}, new float[]{2, 2}, dr); + assertGraphSearch(new int[]{15, 18, 0, 3, 5},new float[]{0.3f, 0.8f}, dr); + } + } + } + + private void assertGraphSearch(int[] expected, float[] vector, IndexReader reader) throws IOException { + TopDocs results = doKnnSearch(reader, vector, 5); + for (ScoreDoc doc : results.scoreDocs) { + // map docId to insertion id + int id = Integer.parseInt(reader.document(doc.doc).get("id")); + doc.doc = id; + } + assertResults(expected, results); + } + + private static TopDocs doKnnSearch(IndexReader reader, float[] vector, int k) throws IOException { + TopDocs[] results = 
new TopDocs[reader.leaves().size()]; + for (LeafReaderContext ctx: reader.leaves()) { + results[ctx.ord] = ctx.reader().getVectorValues(KNN_GRAPH_FIELD) + .search(vector, k, 10); + if (ctx.docBase > 0) { + for (ScoreDoc doc : results[ctx.ord].scoreDocs) { + doc.doc += ctx.docBase; + } + } + } + return TopDocs.merge(k, results); + } + + private void assertResults(int[] expected, TopDocs results) { + assertEquals(results.toString(), expected.length, results.scoreDocs.length); + for (int i = expected.length - 1; i >= 0; i--) { + assertEquals(Arrays.toString(results.scoreDocs), expected[i], results.scoreDocs[i].doc); + } + } + + // For each leaf, verify that its graph nodes are 1-1 with vectors, that the vectors are the expected values, + // and that the graph is fully connected and symmetric. + // NOTE: when we impose max-fanout on the graph it wil no longer be symmetric, but should still + // be fully connected. Is there any other invariant we can test? Well, we can check that max fanout + // is respected. We can test *desirable* properties of the graph like small-world (the graph diameter + // should be tightly bounded). 
+ private void assertConsistentGraph(IndexWriter iw, float[][] values) throws IOException { + int totalGraphDocs = 0; + try (DirectoryReader dr = DirectoryReader.open(iw)) { + for (LeafReaderContext ctx: dr.leaves()) { + LeafReader reader = ctx.reader(); + VectorValues vectorValues = reader.getVectorValues(KNN_GRAPH_FIELD); + Lucene90VectorReader vectorReader = ((Lucene90VectorReader) ((CodecReader) reader).getVectorReader()); + if (vectorReader == null) { + continue; + } + KnnGraphValues graphValues = vectorReader.getGraphValues(KNN_GRAPH_FIELD); + assertTrue((vectorValues == null) == (graphValues == null)); + if (vectorValues == null) { + continue; + } + int[][] graph = new int[reader.maxDoc()][]; + boolean foundOrphan= false; + int graphSize = 0; + int node = -1; + for (int i = 0; i < reader.maxDoc(); i++) { + int nextDocWithVectors = vectorValues.advance(i); + //System.out.println("advanced to " + nextDocWithVectors); + while (i < nextDocWithVectors && i < reader.maxDoc()) { + int id = Integer.parseInt(reader.document(i).get("id")); + assertNull("document " + id + " has no vector, but was expected to", values[id]); + ++i; + } + if (nextDocWithVectors == NO_MORE_DOCS) { + break; + } + int id = Integer.parseInt(reader.document(i).get("id")); + graphValues.seek(++node); + // documents with KnnGraphValues have the expected vectors + float[] scratch = vectorValues.vectorValue(); + assertArrayEquals("vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch), + values[id], scratch, 0f); + // We collect neighbors for analysis below + List friends = new ArrayList<>(); + int arc; + while ((arc = graphValues.nextNeighbor()) != NO_MORE_DOCS) { + friends.add(arc); + } + if (friends.size() == 0) { + //System.out.printf("knngraph @%d is singleton (advance returns %d)\n", i, nextWithNeighbors); + foundOrphan = true; + } else { + // NOTE: these friends are dense ordinals, not docIds. 
+ int[] friendCopy = new int[friends.size()]; + for (int j = 0; j < friends.size(); j++) { + friendCopy[j] = friends.get(j); + } + graph[graphSize] = friendCopy; + //System.out.printf("knngraph @%d => %s\n", i, Arrays.toString(graph[i])); + } + graphSize++; + } + assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + if (foundOrphan) { + assertEquals("graph is not fully connected", 1, graphSize); + } else { + assertTrue("Graph has " + graphSize + " nodes, but one of them has no neighbors", graphSize > 1); + } + // assert that the graph in each leaf is connected and undirected (ie links are reciprocated) + // assertReciprocal(graph); + assertConnected(graph); + assertMaxConn(graph, HnswGraphBuilder.DEFAULT_MAX_CONN); + totalGraphDocs += graphSize; + } + } + int expectedCount = 0; + for (float[] friends : values) { + if (friends != null) { + ++expectedCount; + } + } + assertEquals(expectedCount, totalGraphDocs); + } + + private void assertMaxConn(int[][] graph, int maxConn) { + for (int i = 0; i < graph.length; i++) { + if (graph[i] != null) { + assert (graph[i].length <= maxConn); + for (int j = 0; j < graph[i].length; j++) { + int k = graph[i][j]; + assertNotNull(graph[k]); + } + } + } + } + + private void assertReciprocal(int[][] graph) { + // The graph is undirected: if a -> b then b -> a. 
+ for (int i = 0; i < graph.length; i++) { + if (graph[i] != null) { + for (int j = 0; j < graph[i].length; j++) { + int k = graph[i][j]; + assertNotNull(graph[k]); + assertTrue("" + i + "->" + k + " is not reciprocated", Arrays.binarySearch(graph[k], i) >= 0); + } + } + } + } + + private void assertConnected(int[][] graph) { + // every node in the graph is reachable from every other node + Set visited = new HashSet<>(); + List queue = new LinkedList<>(); + int count = 0; + for (int[] entry : graph) { + if (entry != null) { + if (queue.isEmpty()) { + queue.add(entry[0]); // start from any node + //System.out.println("start at " + entry[0]); + } + ++count; + } + } + while(queue.isEmpty() == false) { + int i = queue.remove(0); + assertNotNull("expected neighbors of " + i, graph[i]); + visited.add(i); + for (int j : graph[i]) { + if (visited.contains(j) == false) { + //System.out.println(" ... " + j); + queue.add(j); + } + } + } + // we visited each node exactly once + assertEquals("Attempted to walk entire graph but only visited " + visited.size(), count, visited.size()); + } + + + private void add(IndexWriter iw, int id, float[] vector) throws IOException { + Document doc = new Document(); + if (vector != null) { + // TODO: choose random search strategy + doc.add(new VectorField(KNN_GRAPH_FIELD, vector, VectorValues.SearchStrategy.EUCLIDEAN_HNSW)); + } + doc.add(new StringField("id", Integer.toString(id), Field.Store.YES)); + //System.out.println("add " + id + " " + Arrays.toString(vector)); + iw.addDocument(doc); + } + +} diff --git a/lucene/core/src/test/org/apache/lucene/index/TestVectorValues.java b/lucene/core/src/test/org/apache/lucene/index/TestVectorValues.java index 4b7ccbde301..a9cd946469c 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestVectorValues.java @@ -601,6 +601,51 @@ public class TestVectorValues extends LuceneTestCase { } } + public void 
testIndexMultipleVectorFields() throws Exception { + try (Directory dir = newDirectory(); + IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + float[] v = new float[]{1}; + doc.add(new VectorField("field1", v, SearchStrategy.EUCLIDEAN_HNSW)); + doc.add(new VectorField("field2", new float[]{1, 2, 3}, SearchStrategy.NONE)); + iw.addDocument(doc); + v[0] = 2; + iw.addDocument(doc); + doc = new Document(); + doc.add(new VectorField("field3", new float[]{1, 2, 3}, SearchStrategy.DOT_PRODUCT_HNSW)); + iw.addDocument(doc); + iw.forceMerge(1); + try (IndexReader reader = iw.getReader()) { + LeafReader leaf = reader.leaves().get(0).reader(); + + VectorValues vectorValues = leaf.getVectorValues("field1"); + assertEquals(1, vectorValues.dimension()); + assertEquals(2, vectorValues.size()); + vectorValues.nextDoc(); + assertEquals(1f, vectorValues.vectorValue()[0], 0); + vectorValues.nextDoc(); + assertEquals(2f, vectorValues.vectorValue()[0], 0); + assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + + VectorValues vectorValues2 = leaf.getVectorValues("field2"); + assertEquals(3, vectorValues2.dimension()); + assertEquals(2, vectorValues2.size()); + vectorValues2.nextDoc(); + assertEquals(2f, vectorValues2.vectorValue()[1], 0); + vectorValues2.nextDoc(); + assertEquals(2f, vectorValues2.vectorValue()[1], 0); + assertEquals(NO_MORE_DOCS, vectorValues2.nextDoc()); + + VectorValues vectorValues3 = leaf.getVectorValues("field3"); + assertEquals(3, vectorValues3.dimension()); + assertEquals(1, vectorValues3.size()); + vectorValues3.nextDoc(); + assertEquals(1f, vectorValues3.vectorValue()[0], 0); + assertEquals(NO_MORE_DOCS, vectorValues3.nextDoc()); + } + } + } + /** * Index random vectors, sometimes skipping documents, sometimes deleting a document, * sometimes merging, sometimes sorting the index, diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java 
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.io.OutputStream;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;

import org.apache.lucene.codecs.lucene90.Lucene90VectorReader;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.PrintStreamInfoStream;
import org.apache.lucene.util.SuppressForbidden;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

/**
 * For testing indexing and search performance of a knn-graph.
 *
 * java -cp .../lib/*.jar org.apache.lucene.util.hnsw.KnnGraphTester -ndoc 1000000 -search .../vectors.bin
 */
public class KnnGraphTester {

  private static final String KNN_FIELD = "knn";
  private static final String ID_FIELD = "id";
  private static final VectorValues.SearchStrategy SEARCH_STRATEGY = VectorValues.SearchStrategy.DOT_PRODUCT_HNSW;

  private int numDocs;         // number of document vectors to index
  private int dim;             // vector dimension
  private int topK;            // number of nearest neighbors to retrieve per query
  private int numIters;        // number of query vectors to run
  private int fanout;          // search-time expansion factor
  private Path indexPath;
  private boolean quiet;
  private boolean reindex;
  private int reindexTimeMsec; // wall time of the last reindex, for reporting

  @SuppressForbidden(reason = "uses Random()")
  private KnnGraphTester() {
    // set defaults
    numDocs = 1000;
    numIters = 1000;
    dim = 256;
    topK = 100;
    fanout = topK;
    indexPath = Paths.get("knn_test_index");
  }

  public static void main(String... args) throws Exception {
    new KnnGraphTester().run(args);
  }

  /** Parses and validates the integer value following a flag; iarg is the value's position. */
  private static int intArg(String[] args, int iarg, String flag) {
    if (iarg >= args.length) {
      throw new IllegalArgumentException(flag + " requires a following number");
    }
    return Integer.parseInt(args[iarg]);
  }

  private void run(String... args) throws Exception {
    String operation = null, docVectorsPath = null, queryPath = null;
    for (int iarg = 0; iarg < args.length; iarg++) {
      String arg = args[iarg];
      switch (arg) {
        case "-generate":
        case "-search":
        case "-check":
        case "-stats":
          if (operation != null) {
            throw new IllegalArgumentException("Specify only one operation, not both " + arg + " and " + operation);
          }
          if (iarg == args.length - 1) {
            throw new IllegalArgumentException("Operation " + arg + " requires a following pathname");
          }
          operation = arg;
          docVectorsPath = args[++iarg];
          if (operation.equals("-search")) {
            // fixed: was an unchecked args[++iarg], throwing ArrayIndexOutOfBoundsException
            if (iarg == args.length - 1) {
              throw new IllegalArgumentException("Operation -search requires a query pathname");
            }
            queryPath = args[++iarg];
          }
          break;
        case "-fanout":
          fanout = intArg(args, ++iarg, "-fanout");
          break;
        case "-beamWidthIndex":
          HnswGraphBuilder.DEFAULT_BEAM_WIDTH = intArg(args, ++iarg, "-beamWidthIndex");
          break;
        case "-maxConn":
          HnswGraphBuilder.DEFAULT_MAX_CONN = intArg(args, ++iarg, "-maxConn");
          break;
        case "-dim":
          dim = intArg(args, ++iarg, "-dim");
          break;
        case "-ndoc":
          numDocs = intArg(args, ++iarg, "-ndoc");
          break;
        case "-niter":
          numIters = intArg(args, ++iarg, "-niter");
          break;
        case "-reindex":
          reindex = true;
          break;
        case "-forceMerge":
          operation = "-forceMerge";
          break;
        case "-quiet":
          quiet = true;
          break;
        default:
          throw new IllegalArgumentException("unknown argument " + arg);
      }
    }
    if (operation == null) {
      usage();
    }
    if (reindex) {
      reindexTimeMsec = createIndex(Paths.get(docVectorsPath), indexPath);
    }
    switch (operation) {
      case "-search":
        testSearch(indexPath, Paths.get(queryPath), getNN(Paths.get(docVectorsPath), Paths.get(queryPath)));
        break;
      case "-forceMerge":
        forceMerge();
        break;
      case "-stats":
        printFanoutHist(indexPath);
        break;
      // NOTE: -generate and -check are accepted above but currently have no implementation
    }
  }

  @SuppressForbidden(reason = "Prints stuff")
  private void printFanoutHist(Path indexPath) throws IOException {
    try (Directory dir = FSDirectory.open(indexPath);
         DirectoryReader reader = DirectoryReader.open(dir)) {
      for (LeafReaderContext context : reader.leaves()) {
        LeafReader leafReader = context.reader();
        KnnGraphValues knnValues =
            ((Lucene90VectorReader) ((CodecReader) leafReader).getVectorReader()).getGraphValues(KNN_FIELD);
        System.out.printf("Leaf %d has %d documents\n", context.ord, leafReader.maxDoc());
        printGraphFanout(knnValues, leafReader.maxDoc());
      }
    }
  }

  @SuppressForbidden(reason = "Prints stuff")
  private void forceMerge() throws IOException {
    IndexWriterConfig iwc = new IndexWriterConfig()
        .setOpenMode(IndexWriterConfig.OpenMode.APPEND);
    iwc.setInfoStream(new PrintStreamInfoStream(System.out));
    System.out.println("Force merge index in " + indexPath);
    try (IndexWriter iw = new IndexWriter(FSDirectory.open(indexPath), iwc)) {
      iw.forceMerge(1);
    }
  }

  @SuppressForbidden(reason = "Prints stuff")
  private void printGraphFanout(KnnGraphValues knnValues, int numDocs) throws IOException {
    int min = Integer.MAX_VALUE, max = 0, total = 0;
    int count = 0;
    int[] leafHist = new int[numDocs];
    for (int node = 0; node < numDocs; node++) {
      knnValues.seek(node);
      int n = 0;
      while (knnValues.nextNeighbor() != NO_MORE_DOCS) {
        ++n;
      }
      ++leafHist[n];
      max = Math.max(max, n);
      min = Math.min(min, n);
      if (n > 0) {
        ++count;
        total += n;
      }
    }
    if (count == 0) {
      // guard: avoids NaN mean and min=Integer.MAX_VALUE for an empty/disconnected graph
      System.out.println("Graph has no connected nodes");
      return;
    }
    System.out.printf("Graph size=%d, Fanout min=%d, mean=%.2f, max=%d\n", count, min, total / (float) count, max);
    printHist(leafHist, max, count, 10);
  }

  /** Prints a percentile-style histogram of graph fanout. */
  @SuppressForbidden(reason = "Prints stuff")
  private void printHist(int[] hist, int max, int count, int nbuckets) {
    System.out.print("%");
    for (int i = 0; i <= nbuckets; i++) {
      System.out.printf("%4d", i * 100 / nbuckets);
    }
    System.out.printf("\n %4d", hist[0]);
    int total = 0, ibucket = 1;
    for (int i = 1; i <= max && ibucket <= nbuckets; i++) {
      total += hist[i];
      while (total >= count * ibucket / nbuckets) {
        System.out.printf("%4d", i);
        ++ibucket;
      }
    }
    System.out.println();
  }

  @SuppressForbidden(reason = "Prints stuff")
  private void testSearch(Path indexPath, Path queryPath, int[][] nn) throws IOException {
    TopDocs[] results = new TopDocs[numIters];
    long elapsed, totalCpuTime, totalVisited = 0;
    try (FileChannel q = FileChannel.open(queryPath)) {
      // widen to long before multiplying: int arithmetic overflows for large query files
      FloatBuffer targets = q.map(FileChannel.MapMode.READ_ONLY, 0, (long) numIters * dim * Float.BYTES)
          .order(ByteOrder.LITTLE_ENDIAN)
          .asFloatBuffer();
      float[] target = new float[dim];
      if (quiet == false) {
        System.out.println("running " + numIters + " targets; topK=" + topK + ", fanout=" + fanout);
      }
      long start;
      ThreadMXBean bean = ManagementFactory.getThreadMXBean();
      long cpuTimeStartNs;
      try (Directory dir = FSDirectory.open(indexPath);
           DirectoryReader reader = DirectoryReader.open(dir)) {
        // warm up. fixed: was an unconditional 1000 iterations, which overflowed results[]
        // and the mapped query buffer whenever numIters < 1000
        int warmups = Math.min(numIters, 1000);
        for (int i = 0; i < warmups; i++) {
          targets.get(target);
          results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout);
        }
        targets.position(0);
        start = System.nanoTime();
        cpuTimeStartNs = bean.getCurrentThreadCpuTime();
        for (int i = 0; i < numIters; i++) {
          targets.get(target);
          results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout);
        }
        totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
        elapsed = (System.nanoTime() - start) / 1_000_000; // ns -> ms
        for (int i = 0; i < numIters; i++) {
          totalVisited += results[i].totalHits.value;
          for (ScoreDoc doc : results[i].scoreDocs) {
            // map docid to the stored insertion id
            doc.doc = Integer.parseInt(reader.document(doc.doc).get("id"));
          }
        }
      }
      if (quiet == false) {
        System.out.println("completed " + numIters + " searches in " + elapsed + " ms: " + ((1000 * numIters) / elapsed) + " QPS "
            + "CPU time=" + totalCpuTime + "ms");
      }
    }
    if (quiet == false) {
      System.out.println("checking results");
    }
    float recall = checkResults(results, nn);
    totalVisited /= numIters;
    if (quiet) {
      System.out.printf(Locale.ROOT, "%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\n", recall, totalCpuTime / (float) numIters,
          numDocs, fanout, HnswGraphBuilder.DEFAULT_MAX_CONN, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, totalVisited, reindexTimeMsec);
    }
  }

  /** Per-leaf KNN search; rebases leaf docids to the global docid space and merges top-k. */
  private static TopDocs doKnnSearch(IndexReader reader, String field, float[] vector, int k, int fanout) throws IOException {
    TopDocs[] results = new TopDocs[reader.leaves().size()];
    for (LeafReaderContext ctx : reader.leaves()) {
      results[ctx.ord] = ctx.reader().getVectorValues(field).search(vector, k, fanout);
      int docBase = ctx.docBase;
      for (ScoreDoc scoreDoc : results[ctx.ord].scoreDocs) {
        scoreDoc.doc += docBase;
      }
    }
    return TopDocs.merge(k, results);
  }

  /** Returns recall: the fraction of returned results that appear in the true-NN lists. */
  private float checkResults(TopDocs[] results, int[][] nn) {
    int totalMatches = 0;
    int totalResults = 0;
    for (int i = 0; i < results.length; i++) {
      int n = results[i].scoreDocs.length;
      totalResults += n;
      totalMatches += compareNN(nn[i], results[i]);
    }
    if (quiet == false) {
      System.out.println("total matches = " + totalMatches + " out of " + totalResults);
      System.out.printf(Locale.ROOT, "Average overlap = %.2f%%\n", ((100.0 * totalMatches) / totalResults));
    }
    return totalMatches / (float) totalResults;
  }

  /** Counts how many of the results appear among the first results.length true neighbors. */
  private int compareNN(int[] expected, TopDocs results) {
    int matched = 0;
    Set<Integer> expectedSet = new HashSet<>();
    for (int i = 0; i < results.scoreDocs.length; i++) {
      expectedSet.add(expected[i]);
    }
    for (ScoreDoc scoreDoc : results.scoreDocs) {
      if (expectedSet.contains(scoreDoc.doc)) {
        ++matched;
      }
    }
    return matched;
  }

  /** Loads (or computes and caches) the exact nearest neighbors for the query set. */
  private int[][] getNN(Path docPath, Path queryPath) throws IOException {
    // look in working directory for cached nn file
    String nnFileName = "nn-" + numDocs + "-" + numIters + "-" + topK + "-" + dim + ".bin";
    Path nnPath = Paths.get(nnFileName);
    if (Files.exists(nnPath)) {
      return readNN(nnPath);
    } else {
      int[][] nn = computeNN(docPath, queryPath);
      writeNN(nn, nnPath);
      return nn;
    }
  }

  private int[][] readNN(Path nnPath) throws IOException {
    int[][] result = new int[numIters][];
    try (FileChannel in = FileChannel.open(nnPath)) {
      IntBuffer intBuffer = in.map(FileChannel.MapMode.READ_ONLY, 0, (long) numIters * topK * Integer.BYTES)
          .order(ByteOrder.LITTLE_ENDIAN)
          .asIntBuffer();
      for (int i = 0; i < numIters; i++) {
        result[i] = new int[topK];
        intBuffer.get(result[i]);
      }
    }
    return result;
  }

  private void writeNN(int[][] nn, Path nnPath) throws IOException {
    if (quiet == false) {
      System.out.println("writing true nearest neighbors to " + nnPath);
    }
    ByteBuffer tmp = ByteBuffer.allocate(nn[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN);
    try (OutputStream out = Files.newOutputStream(nnPath)) {
      for (int i = 0; i < numIters; i++) {
        tmp.asIntBuffer().put(nn[i]);
        out.write(tmp.array());
      }
    }
  }

  /** Brute-force exact KNN over the document vectors, mapped block-by-block. */
  private int[][] computeNN(Path docPath, Path queryPath) throws IOException {
    int[][] result = new int[numIters][];
    if (quiet == false) {
      System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
    }
    try (FileChannel in = FileChannel.open(docPath);
         FileChannel qIn = FileChannel.open(queryPath)) {
      FloatBuffer queries = qIn.map(FileChannel.MapMode.READ_ONLY, 0, (long) numIters * dim * Float.BYTES)
          .order(ByteOrder.LITTLE_ENDIAN)
          .asFloatBuffer();
      float[] vector = new float[dim];
      float[] query = new float[dim];
      long totalBytes = (long) numDocs * dim * Float.BYTES;
      // the largest mapping under 2GB that holds a whole number of vectors
      int blockSize = (int) Math.min(totalBytes, (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES));
      for (int i = 0; i < numIters; i++) {
        queries.get(query);
        long offset = 0; // long: int offset overflowed for doc files over 2GB
        int j = 0;
        // fixed: the queue was recreated and drained per mapped block, so for multi-block
        // files only the last block's neighbors survived; keep one queue across all blocks
        Neighbors queue = Neighbors.create(topK, SEARCH_STRATEGY.reversed);
        while (j < numDocs) {
          FloatBuffer vectors = in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
              .order(ByteOrder.LITTLE_ENDIAN)
              .asFloatBuffer();
          offset += blockSize;
          for (; j < numDocs && vectors.hasRemaining(); j++) {
            vectors.get(vector);
            float d = SEARCH_STRATEGY.compare(query, vector);
            queue.insertWithOverflow(new Neighbor(j, d));
          }
        }
        result[i] = new int[topK];
        for (int k = topK - 1; k >= 0; k--) {
          Neighbor n = queue.pop();
          result[i][k] = n.node();
        }
        if (quiet == false && (i + 1) % 10 == 0) {
          System.out.print(" " + (i + 1));
          System.out.flush();
        }
      }
    }
    return result;
  }

  /** Indexes the document vectors; returns the elapsed wall time in milliseconds. */
  private int createIndex(Path docsPath, Path indexPath) throws IOException {
    IndexWriterConfig iwc = new IndexWriterConfig()
        .setOpenMode(IndexWriterConfig.OpenMode.CREATE);
    iwc.setRAMBufferSizeMB(1994d); // large buffer to minimize flushes while indexing
    if (quiet == false) {
      iwc.setInfoStream(new PrintStreamInfoStream(System.out));
      System.out.println("creating index in " + indexPath);
    }
    long start = System.nanoTime();
    long totalBytes = (long) numDocs * dim * Float.BYTES, offset = 0;
    try (FSDirectory dir = FSDirectory.open(indexPath);
         IndexWriter iw = new IndexWriter(dir, iwc)) {
      int blockSize = (int) Math.min(totalBytes, (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES));
      float[] vector = new float[dim];
      try (FileChannel in = FileChannel.open(docsPath)) {
        int i = 0;
        while (i < numDocs) {
          FloatBuffer vectors = in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
              .order(ByteOrder.LITTLE_ENDIAN)
              .asFloatBuffer();
          offset += blockSize;
          for (; vectors.hasRemaining() && i < numDocs; i++) {
            vectors.get(vector);
            Document doc = new Document();
            doc.add(new VectorField(KNN_FIELD, vector, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW));
            doc.add(new StoredField(ID_FIELD, i));
            iw.addDocument(doc);
          }
        }
        if (quiet == false) {
          System.out.println("Done indexing " + numDocs + " documents; now flush");
        }
      }
    }
    long elapsed = System.nanoTime() - start;
    if (quiet == false) {
      System.out.println("Indexed " + numDocs + " documents in " + elapsed / 1_000_000_000 + "s");
    }
    return (int) (elapsed / 1_000_000);
  }

  private static void usage() {
    // fixed: the message previously named the wrong class (TestKnnGraph)
    String error = "Usage: KnnGraphTester -generate|-search|-stats|-check {datafile} [-beamWidth N]";
    System.err.println(error);
    System.exit(1);
  }

}
a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java new file mode 100644 index 00000000000..5a1b7325b7b --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java @@ -0,0 +1,459 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.util.hnsw; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Random; +import java.util.Set; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.lucene90.Lucene90VectorReader; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.StoredField; +import org.apache.lucene.document.VectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnGraphValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.RandomAccessVectorValues; +import org.apache.lucene.index.RandomAccessVectorValuesProducer; +import org.apache.lucene.index.VectorValues; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.LuceneTestCase; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** Tests HNSW KNN graphs */ +public class TestHnsw extends LuceneTestCase { + + // test writing out and reading in a graph gives the same graph + public void testReadWrite() throws IOException { + int dim = random().nextInt(100) + 1; + int nDoc = random().nextInt(100) + 1; + RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random()); + RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy(); + long seed = random().nextLong(); + HnswGraphBuilder.randSeed = seed; + HnswGraph hnsw = HnswGraphBuilder.build((RandomAccessVectorValuesProducer) vectors); + // Recreate the graph while indexing with the same random seed and write it out + HnswGraphBuilder.randSeed = seed; + try (Directory dir = newDirectory()) { + int nVec = 0, indexedDoc = 0; + // Don't merge randomly, create a single segment because we rely on the docid 
ordering for this test + IndexWriterConfig iwc = new IndexWriterConfig() + .setCodec(Codec.forName("Lucene90")); + try (IndexWriter iw = new IndexWriter(dir, iwc)) { + while (v2.nextDoc() != NO_MORE_DOCS) { + while (indexedDoc < v2.docID()) { + // increment docId in the index by adding empty documents + iw.addDocument(new Document()); + indexedDoc++; + } + Document doc = new Document(); + doc.add(new VectorField("field", v2.vectorValue(), v2.searchStrategy)); + doc.add(new StoredField("id", v2.docID())); + iw.addDocument(doc); + nVec++; + indexedDoc++; + } + } + try (IndexReader reader = DirectoryReader.open(dir)) { + for (LeafReaderContext ctx : reader.leaves()) { + VectorValues values = ctx.reader().getVectorValues("field"); + assertEquals(vectors.searchStrategy, values.searchStrategy()); + assertEquals(dim, values.dimension()); + assertEquals(nVec, values.size()); + assertEquals(indexedDoc, ctx.reader().maxDoc()); + assertEquals(indexedDoc, ctx.reader().numDocs()); + assertVectorsEqual(v3, values); + KnnGraphValues graphValues = ((Lucene90VectorReader) ((CodecReader) ctx.reader()).getVectorReader()).getGraphValues("field"); + assertGraphEqual(hnsw.getGraphValues(), graphValues, nVec); + } + } + } + } + + // Make sure we actually approximately find the closest k elements. 
Mostly this is about + // ensuring that we have all the distance functions, comparators, priority queues and so on + // oriented in the right directions + public void testAknn() throws IOException { + int nDoc = 100; + RandomAccessVectorValuesProducer vectors = new CircularVectorValues(nDoc); + HnswGraph hnsw = HnswGraphBuilder.build(vectors); + // run some searches + Neighbors nn = HnswGraph.search(new float[]{1, 0}, 10, 5, vectors.randomAccess(), hnsw.getGraphValues(), random()); + int sum = 0; + for (Neighbor n : nn) { + sum += n.node(); + } + // We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) = 45 + assertTrue("sum(result docs)=" + sum, sum < 75); + } + + public void testMaxConnections() throws Exception { + // verify that maxConnections is observed, and that the retained arcs point to the best-scoring neighbors + HnswGraph graph = new HnswGraph(1, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW); + graph.connectNodes(0, 1, 1); + assertArrayEquals(new int[]{1}, graph.getNeighbors(0)); + assertArrayEquals(new int[]{0}, graph.getNeighbors(1)); + graph.connectNodes(0, 2, 2); + assertArrayEquals(new int[]{2}, graph.getNeighbors(0)); + assertArrayEquals(new int[]{0}, graph.getNeighbors(1)); + assertArrayEquals(new int[]{0}, graph.getNeighbors(2)); + graph.connectNodes(2, 3, 1); + assertArrayEquals(new int[]{2}, graph.getNeighbors(0)); + assertArrayEquals(new int[]{0}, graph.getNeighbors(1)); + assertArrayEquals(new int[]{0}, graph.getNeighbors(2)); + assertArrayEquals(new int[]{2}, graph.getNeighbors(3)); + + graph = new HnswGraph(1, VectorValues.SearchStrategy.EUCLIDEAN_HNSW); + graph.connectNodes(0, 1, 1); + assertArrayEquals(new int[]{1}, graph.getNeighbors(0)); + assertArrayEquals(new int[]{0}, graph.getNeighbors(1)); + graph.connectNodes(0, 2, 2); + assertArrayEquals(new int[]{1}, graph.getNeighbors(0)); + assertArrayEquals(new int[]{0}, graph.getNeighbors(1)); + assertArrayEquals(new int[]{0}, graph.getNeighbors(2)); + 
graph.connectNodes(2, 3, 1); + assertArrayEquals(new int[]{1}, graph.getNeighbors(0)); + assertArrayEquals(new int[]{0}, graph.getNeighbors(1)); + assertArrayEquals(new int[]{3}, graph.getNeighbors(2)); + assertArrayEquals(new int[]{2}, graph.getNeighbors(3)); + } + + /** Returns vectors evenly distributed around the unit circle. + */ + class CircularVectorValues extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer { + private final int size; + private final float[] value; + + int doc = -1; + + CircularVectorValues(int size) { + this.size = size; + value = new float[2]; + } + + public CircularVectorValues copy() { + return new CircularVectorValues(size); + } + + @Override + public SearchStrategy searchStrategy() { + return SearchStrategy.DOT_PRODUCT_HNSW; + } + + @Override + public int dimension() { + return 2; + } + + @Override + public int size() { + return size; + } + + @Override + public float[] vectorValue() { + return vectorValue(doc); + } + + @Override + public RandomAccessVectorValues randomAccess() { + return new CircularVectorValues(size); + } + + @Override + public int docID() { + return doc; + } + + @Override + public int nextDoc() { + return advance(doc + 1); + } + + @Override + public int advance(int target) { + if (target >= 0 && target < size) { + doc = target; + } else { + doc = NO_MORE_DOCS; + } + return doc; + } + + @Override + public long cost() { + return size; + } + + @Override + public float[] vectorValue(int ord) { + value[0] = (float) Math.cos(Math.PI * ord / (double) size); + value[1] = (float) Math.sin(Math.PI * ord / (double) size); + return value; + } + + @Override + public BytesRef binaryValue(int ord) { + return null; + } + + @Override + public TopDocs search(float[] target, int k, int fanout) { + return null; + } + + } + + private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h, int size) throws IOException { + for (int node = 0; node < size; node ++) { + g.seek(node); + h.seek(node); + 
assertEquals("arcs differ for node " + node, getNeighbors(g), getNeighbors(h)); + } + } + + private Set getNeighbors(KnnGraphValues g) throws IOException { + Set neighbors = new HashSet<>(); + for (int n = g.nextNeighbor(); n != NO_MORE_DOCS; n = g.nextNeighbor()) { + neighbors.add(n); + } + return neighbors; + } + + private void assertVectorsEqual(VectorValues u, VectorValues v) throws IOException { + int uDoc, vDoc; + while (true) { + uDoc = u.nextDoc(); + vDoc = v.nextDoc(); + assertEquals(uDoc, vDoc); + if (uDoc == NO_MORE_DOCS) { + break; + } + assertArrayEquals("vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), 1e-4f); + } + } + + public void testNeighbors() { + // make sure we have the sign correct + Neighbors nn = Neighbors.create(2, false); + Neighbor a = new Neighbor(1, 10); + Neighbor b = new Neighbor(2, 20); + Neighbor c = new Neighbor(3, 30); + assertNull(nn.insertWithOverflow(b)); + assertNull(nn.insertWithOverflow(a)); + assertSame(a, nn.insertWithOverflow(c)); + assertEquals(20, (int) nn.top().score()); + assertEquals(20, (int) nn.pop().score()); + assertEquals(30, (int) nn.top().score()); + assertEquals(30, (int) nn.pop().score()); + + Neighbors fn = Neighbors.create(2, true); + assertNull(fn.insertWithOverflow(b)); + assertNull(fn.insertWithOverflow(a)); + assertSame(c, fn.insertWithOverflow(c)); + assertEquals(20, (int) fn.top().score()); + assertEquals(20, (int) fn.pop().score()); + assertEquals(10, (int) fn.top().score()); + assertEquals(10, (int) fn.pop().score()); + } + + @SuppressWarnings("SelfComparison") + public void testNeighbor() { + Neighbor a = new Neighbor(1, 10); + Neighbor b = new Neighbor(2, 20); + Neighbor c = new Neighbor(3, 20); + assertEquals(0, a.compareTo(a)); + assertEquals(-1, a.compareTo(b)); + assertEquals(1, b.compareTo(a)); + assertEquals(1, b.compareTo(c)); + assertEquals(-1, c.compareTo(b)); + } + + private static float[] randomVector(Random random, int dim) { + float[] vec = new float[dim]; + 
for (int i = 0; i < dim; i++) { + vec[i] = random.nextFloat(); + } + return vec; + } + + /** + * Produces random vectors and caches them for random-access. + */ + class RandomVectorValues extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer { + + private final int dimension; + private final float[][] denseValues; + private final float[][] values; + private final float[] scratch; + private final SearchStrategy searchStrategy; + + final int numVectors; + final int maxDoc; + + private int pos = -1; + + RandomVectorValues(int size, int dimension, Random random) { + this.dimension = dimension; + values = new float[size][]; + denseValues = new float[size][]; + scratch = new float[dimension]; + int sz = 0; + int md = -1; + for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) { + values[offset] = randomVector(random, dimension); + denseValues[sz++] = values[offset]; + md = offset; + } + numVectors = sz; + maxDoc = md; + // get a random SearchStrategy other than NONE (0) + searchStrategy = SearchStrategy.values()[random.nextInt(SearchStrategy.values().length - 1) + 1]; + } + + private RandomVectorValues(int dimension, SearchStrategy searchStrategy, float[][] denseValues, float[][] values, int size) { + this.dimension = dimension; + this.searchStrategy = searchStrategy; + this.values = values; + this.denseValues = denseValues; + scratch = new float[dimension]; + numVectors = size; + maxDoc = values.length - 1; + } + + public RandomVectorValues copy() { + return new RandomVectorValues(dimension, searchStrategy, denseValues, values, numVectors); + } + + @Override + public int size() { + return numVectors; + } + + @Override + public SearchStrategy searchStrategy() { + return searchStrategy; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public float[] vectorValue() { + if(random().nextBoolean()) { + return values[pos]; + } else { + // Sometimes use the same scratch array repeatedly, 
mimicking what the codec will do. + // This should help us catch cases of aliasing where the same VectorValues source is used twice in a + // single computation. + System.arraycopy(values[pos], 0, scratch, 0, dimension); + return scratch; + } + } + + @Override + public RandomAccessVectorValues randomAccess() { + return copy(); + } + + @Override + public float[] vectorValue(int targetOrd) { + return denseValues[targetOrd]; + } + + @Override + public BytesRef binaryValue(int targetOrd) { + return null; + } + + @Override + public TopDocs search(float[] target, int k, int fanout) { + return null; + } + + private boolean seek(int target) { + if (target >= 0 && target < values.length && values[target] != null) { + pos = target; + return true; + } else { + return false; + } + } + + @Override + public int docID() { + return pos; + } + + @Override + public int nextDoc() { + return advance(pos + 1); + } + + public int advance(int target) { + while (++pos < values.length) { + if (seek(pos)) { + return pos; + } + } + return NO_MORE_DOCS; + } + + @Override + public long cost() { + return size(); + } + + } + + public void testBoundsCheckerMax() { + BoundsChecker max = BoundsChecker.create(false); + float f = random().nextFloat() - 0.5f; + // any float > -MAX_VALUE is in bounds + assertFalse(max.check(f)); + // f is now the bound (minus some delta) + max.update(f); + assertFalse(max.check(f)); // f is not out of bounds + assertFalse(max.check(f + 1)); // anything greater than f is in bounds + assertTrue(max.check(f - 1e-5f)); // delta is zero initially + } + + public void testBoundsCheckerMin() { + BoundsChecker min = BoundsChecker.create(true); + float f = random().nextFloat() - 0.5f; + // any float < MAX_VALUE is in bounds + assertFalse(min.check(f)); + // f is now the bound (minus some delta) + min.update(f); + assertFalse(min.check(f)); // f is not out of bounds + assertFalse(min.check(f - 1)); // anything less than f is in bounds + assertTrue(min.check(f + 1e-5f)); // delta 
is zero initially + } + +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/intervals/IntervalQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/intervals/IntervalQuery.java index 10bff723b17..a3c4e5df3d0 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/intervals/IntervalQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/intervals/IntervalQuery.java @@ -99,7 +99,7 @@ public final class IntervalQuery extends Query { private IntervalQuery(String field, IntervalsSource intervalsSource, IntervalScoreFunction scoreFunction) { Objects.requireNonNull(field, "null field aren't accepted"); Objects.requireNonNull(intervalsSource, "null intervalsSource aren't accepted"); - Objects.requireNonNull(scoreFunction, "null scoreFunction aren't accepted"); + Objects.requireNonNull(scoreFunction, "null scoreFunction aren't accepted"); this.field = field; this.intervalsSource = intervalsSource; this.scoreFunction = scoreFunction;