diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java index 59e9ad0a96f..be5dd7f4d2c 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java @@ -23,7 +23,6 @@ import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_MAX_CONN; import java.io.IOException; import java.util.LinkedList; import java.util.List; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopDocs; @@ -31,6 +30,8 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.HnswGraphSearcher; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; /** * The Word2VecSynonymProvider generates the list of sysnonyms of a term. @@ -41,7 +42,6 @@ public class Word2VecSynonymProvider { private static final VectorSimilarityFunction SIMILARITY_FUNCTION = VectorSimilarityFunction.DOT_PRODUCT; - private static final VectorEncoding VECTOR_ENCODING = VectorEncoding.FLOAT32; private final Word2VecModel word2VecModel; private final OnHeapHnswGraph hnswGraph; @@ -51,17 +51,13 @@ public class Word2VecSynonymProvider { * @param model containing the set of TermAndVector entries */ public Word2VecSynonymProvider(Word2VecModel model) throws IOException { - word2VecModel = model; - - HnswGraphBuilder builder = + this.word2VecModel = model; + RandomVectorScorerSupplier scorerSupplier = + RandomVectorScorerSupplier.createFloats(word2VecModel, SIMILARITY_FUNCTION); + HnswGraphBuilder builder = HnswGraphBuilder.create( - word2VecModel, - VECTOR_ENCODING, - SIMILARITY_FUNCTION, - DEFAULT_MAX_CONN, - DEFAULT_BEAM_WIDTH, - HnswGraphBuilder.randSeed); - this.hnswGraph = builder.build(word2VecModel.copy()); + scorerSupplier, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, HnswGraphBuilder.randSeed); + this.hnswGraph = builder.build(word2VecModel.size()); } public List getSynonyms( @@ -74,15 +70,14 @@ public class Word2VecSynonymProvider { LinkedList result = new LinkedList<>(); float[] query = word2VecModel.vectorValue(term); if (query != null) { + RandomVectorScorer scorer = + RandomVectorScorer.createFloats(word2VecModel, SIMILARITY_FUNCTION, query); KnnCollector synonyms = HnswGraphSearcher.search( - query, + scorer, // The query vector is in the model. When looking for the top-k // it's always the nearest neighbour of itself so, we look for the top-k+1 maxSynonymsPerTerm + 1, - word2VecModel, - VECTOR_ENCODING, - SIMILARITY_FUNCTION, hnswGraph, null, Integer.MAX_VALUE); diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index 76481c89c27..b72224d9b48 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -354,6 +354,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { final IndexInput dataIn; final int byteSize; + int lastOrd = -1; final float[] value; int ord = -1; @@ -380,9 +381,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { @Override public float[] vectorValue() throws IOException { - dataIn.seek((long) ord * byteSize); - dataIn.readFloats(value, 0, value.length); - return value; + return vectorValue(ord); } @Override @@ -423,8 +422,12 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { @Override public float[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return value; + } dataIn.seek((long) targetOrd * byteSize); dataIn.readFloats(value, 0, value.length); + lastOrd = targetOrd; return value; } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 4f8b9e922d3..48063d5761f 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -33,7 +33,6 @@ import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; @@ -44,7 +43,9 @@ import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraphSearcher; +import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; /** * Reads vectors from the index segments along with index data structures supporting KNN search. @@ -235,13 +236,11 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { } OffHeapFloatVectorValues vectorValues = getOffHeapVectorValues(fieldEntry); - + RandomVectorScorer scorer = + RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target); HnswGraphSearcher.search( - target, - knnCollector, - vectorValues, - VectorEncoding.FLOAT32, - fieldEntry.similarityFunction, + scorer, + new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), getAcceptOrds(acceptDocs, fieldEntry)); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java index 78e778242e3..cbf80cc0c97 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java @@ -32,7 +32,6 @@ import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; @@ -43,6 +42,8 @@ import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraphSearcher; +import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; /** @@ -231,13 +232,11 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader { } OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData); - + RandomVectorScorer scorer = + RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target); HnswGraphSearcher.search( - target, - knnCollector, - vectorValues, - VectorEncoding.FLOAT32, - fieldEntry.similarityFunction, + scorer, + new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), vectorValues.getAcceptOrds(acceptDocs)); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index 142c3d20ac0..9ca66cd47a8 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -34,6 +34,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues protected final int size; protected final IndexInput slice; protected final int byteSize; + protected int lastOrd = -1; protected final float[] value; OffHeapFloatVectorValues(int dimension, int size, IndexInput slice) { @@ -56,8 +57,12 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues @Override public float[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return value; + } slice.seek((long) targetOrd * byteSize); slice.readFloats(value, 0, value.length); + lastOrd = targetOrd; return value; } @@ -87,9 +92,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues @Override public float[] vectorValue() throws IOException { - slice.seek((long) doc * byteSize); - slice.readFloats(value, 0, value.length); - return value; + return vectorValue(doc); } @Override @@ -151,9 +154,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues @Override public float[] vectorValue() throws IOException { - slice.seek((long) (disi.index()) * byteSize); - slice.readFloats(value, 0, value.length); - return value; + return vectorValue(disi.index()); } @Override diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java index 8da8f934da1..47d56aec192 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java @@ -43,6 +43,8 @@ import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraphSearcher; +import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; /** @@ -267,13 +269,11 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader { } OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData); - + RandomVectorScorer scorer = + RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target); HnswGraphSearcher.search( - target, - knnCollector, - vectorValues, - fieldEntry.vectorEncoding, - fieldEntry.similarityFunction, + scorer, + new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), vectorValues.getAcceptOrds(acceptDocs)); } @@ -288,13 +288,11 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader { } OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData); - + RandomVectorScorer scorer = + RandomVectorScorer.createBytes(vectorValues, fieldEntry.similarityFunction, target); HnswGraphSearcher.search( - target, - knnCollector, - vectorValues, - fieldEntry.vectorEncoding, - fieldEntry.similarityFunction, + scorer, + new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), vectorValues.getAcceptOrds(acceptDocs)); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index 682eb6616ed..82a635e9c46 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -35,6 +35,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues protected final int dimension; protected final int size; protected final IndexInput slice; + protected int lastOrd = -1; protected final byte[] binaryValue; protected final ByteBuffer byteBuffer; protected final int byteSize; @@ -60,7 +61,10 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues @Override public byte[] vectorValue(int targetOrd) throws IOException { - readValue(targetOrd); + if (lastOrd != targetOrd) { + readValue(targetOrd); + lastOrd = targetOrd; + } return binaryValue; } @@ -97,9 +101,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues @Override public byte[] vectorValue() throws IOException { - slice.seek((long) doc * byteSize); - slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize); - return binaryValue; + return vectorValue(doc); } @Override @@ -164,9 +166,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues @Override public byte[] vectorValue() throws IOException { - slice.seek((long) (disi.index()) * byteSize); - slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false); - return binaryValue; + return vectorValue(disi.index()); } @Override diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index 1e909b1ca9f..bff4b4bf9a1 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -34,6 +34,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues protected final int size; protected final IndexInput slice; protected final int byteSize; + protected int lastOrd = -1; protected final float[] value; OffHeapFloatVectorValues(int dimension, int size, IndexInput slice, int byteSize) { @@ -56,8 +57,12 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues @Override public float[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return value; + } slice.seek((long) targetOrd * byteSize); slice.readFloats(value, 0, value.length); + lastOrd = targetOrd; return value; } @@ -93,9 +98,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues @Override public float[] vectorValue() throws IOException { - slice.seek((long) doc * byteSize); - slice.readFloats(value, 0, value.length); - return value; + return vectorValue(doc); } @Override @@ -160,9 +163,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues @Override public float[] vectorValue() throws IOException { - slice.seek((long) (disi.index()) * byteSize); - slice.readFloats(value, 0, value.length); - return value; + return vectorValue(disi.index()); } @Override diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java index dbd64a162b0..55738b61d7f 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java @@ -24,7 +24,6 @@ import java.util.Locale; import java.util.Objects; import java.util.SplittableRandom; import java.util.concurrent.TimeUnit; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.InfoStream; @@ -33,6 +32,7 @@ import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.HnswGraphSearcher; import org.apache.lucene.util.hnsw.NeighborQueue; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; /** * Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the @@ -57,7 +57,7 @@ public final class Lucene91HnswGraphBuilder { private final RandomAccessVectorValues vectorValues; private final SplittableRandom random; private final Lucene91BoundsChecker bound; - private final HnswGraphSearcher graphSearcher; + private final HnswGraphSearcher graphSearcher; final Lucene91OnHeapHnswGraph hnsw; @@ -103,11 +103,8 @@ public final class Lucene91HnswGraphBuilder { int levelOfFirstNode = getRandomGraphLevel(ml, random); this.hnsw = new Lucene91OnHeapHnswGraph(maxConn, levelOfFirstNode); this.graphSearcher = - new HnswGraphSearcher<>( - VectorEncoding.FLOAT32, - similarityFunction, - new NeighborQueue(beamWidth, true), - new FixedBitSet(vectorValues.size())); + new HnswGraphSearcher( + new NeighborQueue(beamWidth, true), new FixedBitSet(vectorValues.size())); bound = Lucene91BoundsChecker.create(false); scratch = new Lucene91NeighborArray(Math.max(beamWidth, maxConn + 1)); } @@ -147,6 +144,8 @@ public final class Lucene91HnswGraphBuilder { /** Inserts a doc with vector value to the graph */ void addGraphNode(int node, float[] value) throws IOException { + RandomVectorScorer scorer = + RandomVectorScorer.createFloats(vectorValues, similarityFunction, value); HnswGraphBuilder.GraphBuilderKnnCollector candidates; final int nodeLevel = getRandomGraphLevel(ml, random); int curMaxLevel = hnsw.numLevels() - 1; @@ -159,12 +158,12 @@ public final class Lucene91HnswGraphBuilder { // for levels > nodeLevel search with topk = 1 for (int level = curMaxLevel; level > nodeLevel; level--) { - candidates = graphSearcher.searchLevel(value, 1, level, eps, vectorValues, hnsw); + candidates = graphSearcher.searchLevel(scorer, 1, level, eps, hnsw); eps = new int[] {candidates.popNode()}; } // for levels <= nodeLevel search with topk = beamWidth, and add connections for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) { - candidates = graphSearcher.searchLevel(value, beamWidth, level, eps, vectorValues, hnsw); + candidates = graphSearcher.searchLevel(scorer, beamWidth, level, eps, hnsw); eps = candidates.popUntilNearestKNodes(); hnsw.addNode(level, node); addDiverseNeighbors(level, node, candidates); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java index b2e7629aed1..0794c101cac 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java @@ -34,7 +34,6 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.store.IndexInput; @@ -44,6 +43,7 @@ import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.packed.DirectMonotonicWriter; /** @@ -277,16 +277,13 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter { throws IOException { // build graph - HnswGraphBuilder hnswGraphBuilder = - HnswGraphBuilder.create( - vectorValues, - VectorEncoding.FLOAT32, - similarityFunction, - M, - beamWidth, - HnswGraphBuilder.randSeed); + RandomVectorScorerSupplier scorerSupplier = + RandomVectorScorerSupplier.createFloats(vectorValues, similarityFunction); + HnswGraphBuilder hnswGraphBuilder = + HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); - OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.copy()); + + OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.size()); // write vectors' neighbours on each level into the vectorIndex file int countOnLevel0 = graph.size(); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java index 9a2a156f98a..ecbe879a912 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java @@ -53,6 +53,7 @@ import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.packed.DirectMonotonicWriter; /** @@ -420,16 +421,14 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter { docsWithField.cardinality(), vectorDataInput, byteSize); - HnswGraphBuilder hnswGraphBuilder = + RandomVectorScorerSupplier scorerSupplier = + RandomVectorScorerSupplier.createBytes( + vectorValues, fieldInfo.getVectorSimilarityFunction()); + HnswGraphBuilder hnswGraphBuilder = HnswGraphBuilder.create( - vectorValues, - fieldInfo.getVectorEncoding(), - fieldInfo.getVectorSimilarityFunction(), - M, - beamWidth, - HnswGraphBuilder.randSeed); + scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); - yield hnswGraphBuilder.build(vectorValues.copy()); + yield hnswGraphBuilder.build(vectorValues.size()); } case FLOAT32 -> { OffHeapFloatVectorValues.DenseOffHeapVectorValues vectorValues = @@ -438,16 +437,13 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter { docsWithField.cardinality(), vectorDataInput, byteSize); - HnswGraphBuilder hnswGraphBuilder = + RandomVectorScorerSupplier scorerSupplier = + RandomVectorScorerSupplier.createFloats( + vectorValues, fieldInfo.getVectorSimilarityFunction()); + HnswGraphBuilder hnswGraphBuilder = HnswGraphBuilder.create( - vectorValues, - fieldInfo.getVectorEncoding(), - fieldInfo.getVectorSimilarityFunction(), - M, - beamWidth, - HnswGraphBuilder.randSeed); - hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); - yield hnswGraphBuilder.build(vectorValues.copy()); + scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); + yield hnswGraphBuilder.build(vectorValues.size()); } }; writeGraph(graph); @@ -630,7 +626,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter { private final int dim; private final DocsWithFieldSet docsWithField; private final List vectors; - private final HnswGraphBuilder hnswGraphBuilder; + private final HnswGraphBuilder hnswGraphBuilder; private int lastDocID = -1; private int node = 0; @@ -654,21 +650,25 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter { }; } + @SuppressWarnings("unchecked") FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream) throws IOException { this.fieldInfo = fieldInfo; this.dim = fieldInfo.getVectorDimension(); this.docsWithField = new DocsWithFieldSet(); vectors = new ArrayList<>(); - RAVectorValues raVectorValues = new RAVectorValues<>(vectors, dim); + RandomAccessVectorValues raVectors = new RAVectorValues<>(vectors, dim); + RandomVectorScorerSupplier scorerSupplier = + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> RandomVectorScorerSupplier.createBytes( + (RandomAccessVectorValues) raVectors, + fieldInfo.getVectorSimilarityFunction()); + case FLOAT32 -> RandomVectorScorerSupplier.createFloats( + (RandomAccessVectorValues) raVectors, + fieldInfo.getVectorSimilarityFunction()); + }; hnswGraphBuilder = - HnswGraphBuilder.create( - raVectorValues, - fieldInfo.getVectorEncoding(), - fieldInfo.getVectorSimilarityFunction(), - M, - beamWidth, - HnswGraphBuilder.randSeed); + HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); hnswGraphBuilder.setInfoStream(infoStream); } @@ -685,7 +685,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter { assert docID > lastDocID; docsWithField.add(docID); vectors.add(copyValue(vectorValue)); - hnswGraphBuilder.addGraphNode(node, vectorValue); + hnswGraphBuilder.addGraphNode(node); node++; lastDocID = docID; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java index e5913de0813..a97ce43f896 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java @@ -45,6 +45,8 @@ import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraphSearcher; +import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; /** @@ -274,12 +276,11 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader { } OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData); + RandomVectorScorer scorer = + RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target); HnswGraphSearcher.search( - target, - knnCollector, - vectorValues, - fieldEntry.vectorEncoding, - fieldEntry.similarityFunction, + scorer, + new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), vectorValues.getAcceptOrds(acceptDocs)); } @@ -296,12 +297,11 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader { } OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData); + RandomVectorScorer scorer = + RandomVectorScorer.createBytes(vectorValues, fieldEntry.similarityFunction, target); HnswGraphSearcher.search( - target, - knnCollector, - vectorValues, - fieldEntry.vectorEncoding, - fieldEntry.similarityFunction, + scorer, + new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), vectorValues.getAcceptOrds(acceptDocs)); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java index 5358d66f16e..fb94833c9fe 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -41,12 +41,8 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.*; -import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.hnsw.*; import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; -import org.apache.lucene.util.hnsw.HnswGraphBuilder; -import org.apache.lucene.util.hnsw.NeighborArray; -import org.apache.lucene.util.hnsw.OnHeapHnswGraph; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.packed.DirectMonotonicWriter; /** @@ -438,10 +434,13 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { docsWithField.cardinality(), vectorDataInput, byteSize); - HnswGraphBuilder hnswGraphBuilder = - createHnswGraphBuilder(mergeState, fieldInfo, vectorValues, initializerIndex); + RandomVectorScorerSupplier scorerSupplier = + RandomVectorScorerSupplier.createBytes( + vectorValues, fieldInfo.getVectorSimilarityFunction()); + HnswGraphBuilder hnswGraphBuilder = + createHnswGraphBuilder(mergeState, fieldInfo, scorerSupplier, initializerIndex); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); - yield hnswGraphBuilder.build(vectorValues.copy()); + yield hnswGraphBuilder.build(vectorValues.size()); } case FLOAT32 -> { OffHeapFloatVectorValues.DenseOffHeapVectorValues vectorValues = @@ -450,10 +449,13 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { docsWithField.cardinality(), vectorDataInput, byteSize); - HnswGraphBuilder hnswGraphBuilder = - createHnswGraphBuilder(mergeState, fieldInfo, vectorValues, initializerIndex); + RandomVectorScorerSupplier scorerSupplier = + RandomVectorScorerSupplier.createFloats( + vectorValues, fieldInfo.getVectorSimilarityFunction()); + HnswGraphBuilder hnswGraphBuilder = + createHnswGraphBuilder(mergeState, fieldInfo, scorerSupplier, initializerIndex); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); - yield hnswGraphBuilder.build(vectorValues.copy()); + yield hnswGraphBuilder.build(vectorValues.size()); } }; vectorIndexNodeOffsets = writeGraph(graph); @@ -482,20 +484,14 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { } } - private HnswGraphBuilder createHnswGraphBuilder( + private HnswGraphBuilder createHnswGraphBuilder( MergeState mergeState, FieldInfo fieldInfo, - RandomAccessVectorValues floatVectorValues, + RandomVectorScorerSupplier scorerSupplier, int initializerIndex) throws IOException { if (initializerIndex == -1) { - return HnswGraphBuilder.create( - floatVectorValues, - fieldInfo.getVectorEncoding(), - fieldInfo.getVectorSimilarityFunction(), - M, - beamWidth, - HnswGraphBuilder.randSeed); + return HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); } HnswGraph initializerGraph = @@ -503,14 +499,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { Map ordinalMapper = getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex); return HnswGraphBuilder.create( - floatVectorValues, - fieldInfo.getVectorEncoding(), - fieldInfo.getVectorSimilarityFunction(), - M, - beamWidth, - HnswGraphBuilder.randSeed, - initializerGraph, - ordinalMapper); + scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, initializerGraph, ordinalMapper); } private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo) @@ -868,7 +857,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { private final int dim; private final DocsWithFieldSet docsWithField; private final List vectors; - private final HnswGraphBuilder hnswGraphBuilder; + private final HnswGraphBuilder hnswGraphBuilder; private int lastDocID = -1; private int node = 0; @@ -892,20 +881,25 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { }; } + @SuppressWarnings("unchecked") FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream) throws IOException { this.fieldInfo = fieldInfo; this.dim = fieldInfo.getVectorDimension(); this.docsWithField = new DocsWithFieldSet(); vectors = new ArrayList<>(); + RAVectorValues raVectors = new RAVectorValues<>(vectors, dim); + RandomVectorScorerSupplier scorerSupplier = + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> RandomVectorScorerSupplier.createBytes( + (RandomAccessVectorValues) raVectors, + fieldInfo.getVectorSimilarityFunction()); + case FLOAT32 -> RandomVectorScorerSupplier.createFloats( + (RandomAccessVectorValues) raVectors, + fieldInfo.getVectorSimilarityFunction()); + }; hnswGraphBuilder = - HnswGraphBuilder.create( - new RAVectorValues<>(vectors, dim), - fieldInfo.getVectorEncoding(), - fieldInfo.getVectorSimilarityFunction(), - M, - beamWidth, - HnswGraphBuilder.randSeed); + HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); hnswGraphBuilder.setInfoStream(infoStream); } @@ -920,7 +914,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter { assert docID > lastDocID; docsWithField.add(docID); vectors.add(copyValue(vectorValue)); - hnswGraphBuilder.addGraphNode(node, vectorValue); + hnswGraphBuilder.addGraphNode(node); node++; lastDocID = docID; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 089148a3300..441444fe2c9 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -35,6 +35,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues protected final int dimension; protected final int size; protected final IndexInput slice; + protected int lastOrd = -1; protected final byte[] binaryValue; protected final ByteBuffer byteBuffer; protected final int byteSize; @@ -60,7 +61,10 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues @Override public byte[] vectorValue(int targetOrd) throws IOException { - readValue(targetOrd); + if (lastOrd != targetOrd) { + readValue(targetOrd); + lastOrd = targetOrd; + } return binaryValue; } @@ -97,9 +101,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues @Override public byte[] vectorValue() throws IOException { - slice.seek((long) doc * byteSize); - slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize); - return binaryValue; + return vectorValue(doc); } @Override @@ -164,9 +166,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues @Override public byte[] vectorValue() throws IOException { - slice.seek((long) (disi.index()) * byteSize); - slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false); - return binaryValue; + return vectorValue(disi.index()); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 0f53e3f9c8b..b212d18c7a6 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -35,6 +35,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues protected final int size; protected final IndexInput slice; protected final int byteSize; + protected int lastOrd = -1; protected final float[] value; OffHeapFloatVectorValues(int dimension, int size, IndexInput slice, int byteSize) { @@ -57,8 +58,12 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues @Override public float[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return value; + } slice.seek((long) targetOrd * byteSize); slice.readFloats(value, 0, value.length); + lastOrd = targetOrd; return value; } @@ -91,9 +96,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues @Override public float[] vectorValue() throws IOException { - slice.seek((long) doc * byteSize); - slice.readFloats(value, 0, value.length); - return value; + return vectorValue(doc); } @Override @@ -158,9 +161,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues @Override public float[] vectorValue() throws IOException { - slice.seek((long) (disi.index()) * byteSize); - slice.readFloats(value, 0, value.length); - return value; + return vectorValue(disi.index()); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index bcfdde529eb..39605989bb8 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -28,8 +28,6 @@ import java.util.Objects; import java.util.Set; import java.util.SplittableRandom; import java.util.concurrent.TimeUnit; -import org.apache.lucene.index.VectorEncoding; -import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopDocs; import org.apache.lucene.util.FixedBitSet; @@ -37,11 +35,9 @@ import org.apache.lucene.util.InfoStream; /** * Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the - * hyperparameters. - * - * @param the type of vector + * hyper-parameters. */ -public final class HnswGraphBuilder { +public final class HnswGraphBuilder { /** Default number of maximum connections per node */ public static final int DEFAULT_MAX_CONN = 16; @@ -64,11 +60,9 @@ public final class HnswGraphBuilder { private final double ml; private final NeighborArray scratch; - private final VectorSimilarityFunction similarityFunction; - private final VectorEncoding vectorEncoding; - private final RandomAccessVectorValues vectors; private final SplittableRandom random; - private final HnswGraphSearcher graphSearcher; + private final RandomVectorScorerSupplier scorerSupplier; + private final HnswGraphSearcher graphSearcher; private final GraphBuilderKnnCollector entryCandidates; // for upper levels of graph search private final GraphBuilderKnnCollector beamCandidates; // for levels of graph where we add the node @@ -77,34 +71,23 @@ public final class HnswGraphBuilder { private InfoStream infoStream = InfoStream.getDefault(); - // we need two sources of vectors in order to perform diversity check comparisons without - // colliding - private final RandomAccessVectorValues vectorsCopy; private final Set initializedNodes; - public static HnswGraphBuilder create( - RandomAccessVectorValues vectors, - VectorEncoding vectorEncoding, - VectorSimilarityFunction similarityFunction, - int M, - int beamWidth, - long seed) + public static HnswGraphBuilder create( + RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed) throws IOException { - return new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed); + return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed); } - public static HnswGraphBuilder create( - RandomAccessVectorValues vectors, - VectorEncoding vectorEncoding, - VectorSimilarityFunction similarityFunction, + public static HnswGraphBuilder create( + RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, HnswGraph initializerGraph, Map oldToNewOrdinalMap) throws IOException { - HnswGraphBuilder hnswGraphBuilder = - new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed); + HnswGraphBuilder hnswGraphBuilder = new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed); hnswGraphBuilder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap); return hnswGraphBuilder; } @@ -113,8 +96,7 @@ public final class HnswGraphBuilder { * Reads all the vectors from vector values, builds a graph connecting them by their dense * ordinals, using the given hyperparameter settings, and returns the resulting graph. * - * @param vectors the vectors whose relations are represented by the graph - must provide a - * different view over those vectors than the one used to add via addGraphNode. + * @param scorerSupplier a supplier to create vector scorer from ordinals. * @param M – graph fanout parameter used to calculate the maximum number of connections a node * can have – M on upper layers, and M * 2 on the lowest level. * @param beamWidth the size of the beam search to use when finding nearest neighbors. @@ -122,17 +104,8 @@ public final class HnswGraphBuilder { * to ensure repeatable construction. */ private HnswGraphBuilder( - RandomAccessVectorValues vectors, - VectorEncoding vectorEncoding, - VectorSimilarityFunction similarityFunction, - int M, - int beamWidth, - long seed) + RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed) throws IOException { - this.vectors = vectors; - this.vectorsCopy = vectors.copy(); - this.vectorEncoding = Objects.requireNonNull(vectorEncoding); - this.similarityFunction = Objects.requireNonNull(similarityFunction); if (M <= 0) { throw new IllegalArgumentException("maxConn must be positive"); } @@ -140,16 +113,15 @@ public final class HnswGraphBuilder { throw new IllegalArgumentException("beamWidth must be positive"); } this.M = M; + this.scorerSupplier = + Objects.requireNonNull(scorerSupplier, "scorer supplier must not be null"); // normalization factor for level generation; currently not configurable this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M); this.random = new SplittableRandom(seed); this.hnsw = new OnHeapHnswGraph(M); this.graphSearcher = - new HnswGraphSearcher<>( - vectorEncoding, - similarityFunction, - new NeighborQueue(beamWidth, true), - new FixedBitSet(this.vectors.size())); + new HnswGraphSearcher( + new NeighborQueue(beamWidth, true), new FixedBitSet(this.getGraph().size())); // in scratch we store candidates in reverse order: worse candidates are first scratch = new NeighborArray(Math.max(beamWidth, M + 1), false); entryCandidates = new GraphBuilderKnnCollector(1); @@ -158,22 +130,15 @@ public final class HnswGraphBuilder { } /** - * Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two - * copies enables efficient retrieval without extra data copying, while avoiding collision of the - * returned values. + * Adds all nodes to the graph up to the provided {@code maxOrd}. * - * @param vectorsToAdd the vectors for which to build a nearest neighbors graph. Must be an - * independent accessor for the vectors + * @param maxOrd The maximum ordinal of the nodes to be added. */ - public OnHeapHnswGraph build(RandomAccessVectorValues vectorsToAdd) throws IOException { - if (vectorsToAdd == this.vectors) { - throw new IllegalArgumentException( - "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); - } + public OnHeapHnswGraph build(int maxOrd) throws IOException { if (infoStream.isEnabled(HNSW_COMPONENT)) { - infoStream.message(HNSW_COMPONENT, "build graph from " + vectorsToAdd.size() + " vectors"); + infoStream.message(HNSW_COMPONENT, "build graph from " + maxOrd + " vectors"); } - addVectors(vectorsToAdd); + addVectors(maxOrd); return hnsw; } @@ -216,19 +181,6 @@ public final class HnswGraphBuilder { } } - private void addVectors(RandomAccessVectorValues vectorsToAdd) throws IOException { - long start = System.nanoTime(), t = start; - for (int node = 0; node < vectorsToAdd.size(); node++) { - if (initializedNodes.contains(node)) { - continue; - } - addGraphNode(node, vectorsToAdd); - if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) { - t = printGraphBuildStatus(node, start, t); - } - } - } - /** Set info-stream to output debugging information * */ public void setInfoStream(InfoStream infoStream) { this.infoStream = infoStream; @@ -238,8 +190,22 @@ public final class HnswGraphBuilder { return hnsw; } + private void addVectors(int maxOrd) throws IOException { + long start = System.nanoTime(), t = start; + for (int node = 0; node < maxOrd; node++) { + if (initializedNodes.contains(node)) { + continue; + } + addGraphNode(node); + if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) { + t = printGraphBuildStatus(node, start, t); + } + } + } + /** Inserts a doc with vector value to the graph */ - public void addGraphNode(int node, T value) throws IOException { + public void addGraphNode(int node) throws IOException { + RandomVectorScorer scorer = scorerSupplier.scorer(node); final int nodeLevel = getRandomGraphLevel(ml, random); int curMaxLevel = hnsw.numLevels() - 1; @@ -261,24 +227,20 @@ public final class HnswGraphBuilder { GraphBuilderKnnCollector candidates = entryCandidates; for (int level = curMaxLevel; level > nodeLevel; level--) { candidates.clear(); - graphSearcher.searchLevel(candidates, value, level, eps, vectors, hnsw, null); + graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null); eps = new int[] {candidates.popNode()}; } // for levels <= nodeLevel search with topk = beamWidth, and add connections candidates = beamCandidates; for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) { candidates.clear(); - graphSearcher.searchLevel(candidates, value, level, eps, vectors, hnsw, null); + graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null); eps = candidates.popUntilNearestKNodes(); hnsw.addNode(level, node); addDiverseNeighbors(level, node, candidates); } } - public void addGraphNode(int node, RandomAccessVectorValues values) throws IOException { - addGraphNode(node, values.vectorValue(node)); - } - private long printGraphBuildStatus(int node, long start, long t) { long now = System.nanoTime(); infoStream.message( @@ -353,36 +315,9 @@ public final class HnswGraphBuilder { */ private boolean diversityCheck(int candidate, float score, NeighborArray neighbors) throws IOException { - return isDiverse(candidate, neighbors, score); - } - - private boolean isDiverse(int candidate, NeighborArray neighbors, float score) - throws IOException { - return switch (vectorEncoding) { - case BYTE -> isDiverse((byte[]) vectors.vectorValue(candidate), neighbors, score); - case FLOAT32 -> isDiverse((float[]) vectors.vectorValue(candidate), neighbors, score); - }; - } - - private boolean isDiverse(float[] candidate, NeighborArray neighbors, float score) - throws IOException { + RandomVectorScorer scorer = scorerSupplier.scorer(candidate); for (int i = 0; i < neighbors.size(); i++) { - float neighborSimilarity = - similarityFunction.compare( - candidate, (float[]) vectorsCopy.vectorValue(neighbors.node[i])); - if (neighborSimilarity >= score) { - return false; - } - } - return true; - } - - private boolean isDiverse(byte[] candidate, NeighborArray neighbors, float score) - throws IOException { - for (int i = 0; i < neighbors.size(); i++) { - float neighborSimilarity = - similarityFunction.compare( - candidate, (byte[]) vectorsCopy.vectorValue(neighbors.node[i])); + float neighborSimilarity = scorer.score(neighbors.node[i]); if (neighborSimilarity >= score) { return false; } @@ -395,26 +330,8 @@ public final class HnswGraphBuilder { * neighbours */ private int findWorstNonDiverse(NeighborArray neighbors, int nodeOrd) throws IOException { - float[] vectorValue = null; - byte[] binaryValue = null; - switch (this.vectorEncoding) { - case FLOAT32 -> vectorValue = (float[]) vectors.vectorValue(nodeOrd); - case BYTE -> binaryValue = (byte[]) vectors.vectorValue(nodeOrd); - } - float[] finalVectorValue = vectorValue; - byte[] finalBinaryValue = binaryValue; - int[] uncheckedIndexes = - neighbors.sort( - nbrOrd -> { - float score = - switch (this.vectorEncoding) { - case FLOAT32 -> this.similarityFunction.compare( - finalVectorValue, (float[]) vectorsCopy.vectorValue(nbrOrd)); - case BYTE -> this.similarityFunction.compare( - finalBinaryValue, (byte[]) vectorsCopy.vectorValue(nbrOrd)); - }; - return score; - }); + RandomVectorScorer scorer = scorerSupplier.scorer(nodeOrd); + int[] uncheckedIndexes = neighbors.sort(scorer); if (uncheckedIndexes == null) { // all nodes are checked, we will directly return the most distant one return neighbors.size() - 1; @@ -438,37 +355,12 @@ public final class HnswGraphBuilder { private boolean isWorstNonDiverse( int candidateIndex, NeighborArray neighbors, int[] uncheckedIndexes, int uncheckedCursor) throws IOException { - int candidateNode = neighbors.node[candidateIndex]; - return switch (vectorEncoding) { - case BYTE -> isWorstNonDiverse( - candidateIndex, - (byte[]) vectors.vectorValue(candidateNode), - neighbors, - uncheckedIndexes, - uncheckedCursor); - case FLOAT32 -> isWorstNonDiverse( - candidateIndex, - (float[]) vectors.vectorValue(candidateNode), - neighbors, - uncheckedIndexes, - uncheckedCursor); - }; - } - - private boolean isWorstNonDiverse( - int candidateIndex, - float[] candidateVector, - NeighborArray neighbors, - int[] uncheckedIndexes, - int uncheckedCursor) - throws IOException { float minAcceptedSimilarity = neighbors.score[candidateIndex]; + RandomVectorScorer scorer = scorerSupplier.scorer(neighbors.node[candidateIndex]); if (candidateIndex == uncheckedIndexes[uncheckedCursor]) { // the candidate itself is unchecked for (int i = candidateIndex - 1; i >= 0; i--) { - float neighborSimilarity = - similarityFunction.compare( - candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i])); + float neighborSimilarity = scorer.score(neighbors.node[i]); // candidate node is too similar to node i given its score relative to the base node if (neighborSimilarity >= minAcceptedSimilarity) { return true; @@ -479,47 +371,7 @@ public final class HnswGraphBuilder { // inserted) unchecked nodes assert candidateIndex > uncheckedIndexes[uncheckedCursor]; for (int i = uncheckedCursor; i >= 0; i--) { - float neighborSimilarity = - similarityFunction.compare( - candidateVector, - (float[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]])); - // candidate node is too similar to node i given its score relative to the base node - if (neighborSimilarity >= minAcceptedSimilarity) { - return true; - } - } - } - return false; - } - - private boolean isWorstNonDiverse( - int candidateIndex, - byte[] candidateVector, - NeighborArray neighbors, - int[] uncheckedIndexes, - int uncheckedCursor) - throws IOException { - float minAcceptedSimilarity = neighbors.score[candidateIndex]; - if (candidateIndex == uncheckedIndexes[uncheckedCursor]) { - // the candidate itself is unchecked - for (int i = candidateIndex - 1; i >= 0; i--) { - float neighborSimilarity = - similarityFunction.compare( - candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i])); - // candidate node is too similar to node i given its score relative to the base node - if (neighborSimilarity >= minAcceptedSimilarity) { - return true; - } - } - } else { - // else we just need to make sure candidate does not violate diversity with the (newly - // inserted) unchecked nodes - assert candidateIndex > uncheckedIndexes[uncheckedCursor]; - for (int i = uncheckedCursor; i >= 0; i--) { - float neighborSimilarity = - similarityFunction.compare( - candidateVector, - (byte[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]])); + float neighborSimilarity = scorer.score(neighbors.node[uncheckedIndexes[i]]); // candidate node is too similar to node i given its score relative to the base node if (neighborSimilarity >= minAcceptedSimilarity) { return true; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 28c7ca2b163..aeddbeb56fa 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -20,8 +20,6 @@ package org.apache.lucene.util.hnsw; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; -import org.apache.lucene.index.VectorEncoding; -import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.util.BitSet; @@ -32,13 +30,8 @@ import org.apache.lucene.util.SparseFixedBitSet; /** * Searches an HNSW graph to find nearest neighbors to a query vector. For more background on the * search algorithm, see {@link HnswGraph}. - * - * @param the type of query vector */ -public class HnswGraphSearcher { - private final VectorSimilarityFunction similarityFunction; - private final VectorEncoding vectorEncoding; - +public class HnswGraphSearcher { /** * Scratch data structures that are used in each {@link #searchLevel} call. These can be expensive * to allocate, so they're cleared and reused across calls. @@ -50,17 +43,10 @@ public class HnswGraphSearcher { /** * Creates a new graph searcher. * - * @param similarityFunction the similarity function to compare vectors * @param candidates max heap that will track the candidate nodes to explore * @param visited bit set that will track nodes that have already been visited */ - public HnswGraphSearcher( - VectorEncoding vectorEncoding, - VectorSimilarityFunction similarityFunction, - NeighborQueue candidates, - BitSet visited) { - this.vectorEncoding = vectorEncoding; - this.similarityFunction = similarityFunction; + public HnswGraphSearcher(NeighborQueue candidates, BitSet visited) { this.candidates = candidates; this.visited = visited; } @@ -68,10 +54,27 @@ public class HnswGraphSearcher { /** * Searches HNSW graph for the nearest neighbors of a query vector. * - * @param query search query vector + * @param scorer the scorer to compare the query with the nodes + * @param knnCollector a collector of top knn results to be returned + * @param graph the graph values. May represent the entire graph, or a level in a hierarchical + * graph. + * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or + * {@code null} if they are all allowed to match. + */ + public static void search( + RandomVectorScorer scorer, KnnCollector knnCollector, HnswGraph graph, Bits acceptOrds) + throws IOException { + HnswGraphSearcher graphSearcher = + new HnswGraphSearcher( + new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(graph.size())); + search(scorer, knnCollector, graph, graphSearcher, acceptOrds); + } + + /** + * Search {@link OnHeapHnswGraph}, this method is thread safe. + * + * @param scorer the scorer to compare the query with the nodes * @param topK the number of nodes to be returned - * @param vectors the vector values - * @param similarityFunction the similarity function to compare vectors * @param graph the graph values. May represent the entire graph, or a level in a hierarchical * graph. * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or @@ -80,198 +83,36 @@ public class HnswGraphSearcher { * @return a set of collected vectors holding the nearest neighbors found */ public static KnnCollector search( - float[] query, - int topK, - RandomAccessVectorValues vectors, - VectorEncoding vectorEncoding, - VectorSimilarityFunction similarityFunction, - HnswGraph graph, - Bits acceptOrds, - int visitedLimit) + RandomVectorScorer scorer, int topK, OnHeapHnswGraph graph, Bits acceptOrds, int visitedLimit) throws IOException { KnnCollector knnCollector = new TopKnnCollector(topK, visitedLimit); - search(query, knnCollector, vectors, vectorEncoding, similarityFunction, graph, acceptOrds); + OnHeapHnswGraphSearcher graphSearcher = + new OnHeapHnswGraphSearcher( + new NeighborQueue(topK, true), new SparseFixedBitSet(graph.size())); + search(scorer, knnCollector, graph, graphSearcher, acceptOrds); return knnCollector; } - /** - * Searches HNSW graph for the nearest neighbors of a query vector. - * - * @param query search query vector - * @param knnCollector a collector of top knn results to be returned - * @param vectors the vector values - * @param similarityFunction the similarity function to compare vectors - * @param graph the graph values. May represent the entire graph, or a level in a hierarchical - * graph. - * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or - * {@code null} if they are all allowed to match. - */ - public static void search( - float[] query, + private static void search( + RandomVectorScorer scorer, KnnCollector knnCollector, - RandomAccessVectorValues vectors, - VectorEncoding vectorEncoding, - VectorSimilarityFunction similarityFunction, HnswGraph graph, - Bits acceptOrds) - throws IOException { - if (query.length != vectors.dimension()) { - throw new IllegalArgumentException( - "vector query dimension: " - + query.length - + " differs from field dimension: " - + vectors.dimension()); - } - HnswGraphSearcher graphSearcher = - new HnswGraphSearcher<>( - vectorEncoding, - similarityFunction, - new NeighborQueue(knnCollector.k(), true), - new SparseFixedBitSet(vectors.size())); - search(query, knnCollector, vectors, graph, graphSearcher, acceptOrds); - } - - /** - * Search {@link OnHeapHnswGraph}, this method is thread safe, for parameters please refer to - * {@link #search(float[], int, RandomAccessVectorValues, VectorEncoding, - * VectorSimilarityFunction, HnswGraph, Bits, int)} - */ - public static KnnCollector search( - float[] query, - int topK, - RandomAccessVectorValues vectors, - VectorEncoding vectorEncoding, - VectorSimilarityFunction similarityFunction, - OnHeapHnswGraph graph, - Bits acceptOrds, - int visitedLimit) - throws IOException { - KnnCollector knnCollector = new TopKnnCollector(topK, visitedLimit); - OnHeapHnswGraphSearcher graphSearcher = - new OnHeapHnswGraphSearcher<>( - vectorEncoding, - similarityFunction, - new NeighborQueue(topK, true), - new SparseFixedBitSet(vectors.size())); - search(query, knnCollector, vectors, graph, graphSearcher, acceptOrds); - return knnCollector; - } - - /** - * Searches HNSW graph for the nearest neighbors of a query vector. - * - * @param query search query vector - * @param topK the number of nodes to be returned - * @param vectors the vector values - * @param similarityFunction the similarity function to compare vectors - * @param graph the graph values. May represent the entire graph, or a level in a hierarchical - * graph. - * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or - * {@code null} if they are all allowed to match. - * @param visitedLimit the maximum number of nodes that the search is allowed to visit - * @return a set of collected vectors holding the nearest neighbors found - */ - public static KnnCollector search( - byte[] query, - int topK, - RandomAccessVectorValues vectors, - VectorEncoding vectorEncoding, - VectorSimilarityFunction similarityFunction, - HnswGraph graph, - Bits acceptOrds, - int visitedLimit) - throws IOException { - KnnCollector collector = new TopKnnCollector(topK, visitedLimit); - search(query, collector, vectors, vectorEncoding, similarityFunction, graph, acceptOrds); - return collector; - } - - /** - * Searches HNSW graph for the nearest neighbors of a query vector. - * - * @param query search query vector - * @param knnCollector a collector of top knn results to be returned - * @param vectors the vector values - * @param similarityFunction the similarity function to compare vectors - * @param graph the graph values. May represent the entire graph, or a level in a hierarchical - * graph. - * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or - * {@code null} if they are all allowed to match. - */ - public static void search( - byte[] query, - KnnCollector knnCollector, - RandomAccessVectorValues vectors, - VectorEncoding vectorEncoding, - VectorSimilarityFunction similarityFunction, - HnswGraph graph, - Bits acceptOrds) - throws IOException { - if (query.length != vectors.dimension()) { - throw new IllegalArgumentException( - "vector query dimension: " - + query.length - + " differs from field dimension: " - + vectors.dimension()); - } - HnswGraphSearcher graphSearcher = - new HnswGraphSearcher<>( - vectorEncoding, - similarityFunction, - new NeighborQueue(knnCollector.k(), true), - new SparseFixedBitSet(vectors.size())); - search(query, knnCollector, vectors, graph, graphSearcher, acceptOrds); - } - - /** - * Search {@link OnHeapHnswGraph}, this method is thread safe, for parameters please refer to - * {@link #search(byte[], int, RandomAccessVectorValues, VectorEncoding, VectorSimilarityFunction, - * HnswGraph, Bits, int)} - */ - public static KnnCollector search( - byte[] query, - int topK, - RandomAccessVectorValues vectors, - VectorEncoding vectorEncoding, - VectorSimilarityFunction similarityFunction, - OnHeapHnswGraph graph, - Bits acceptOrds, - int visitedLimit) - throws IOException { - OnHeapHnswGraphSearcher graphSearcher = - new OnHeapHnswGraphSearcher<>( - vectorEncoding, - similarityFunction, - new NeighborQueue(topK, true), - new SparseFixedBitSet(vectors.size())); - KnnCollector collector = new TopKnnCollector(topK, visitedLimit); - search(query, collector, vectors, graph, graphSearcher, acceptOrds); - return collector; - } - - private static void search( - T query, - KnnCollector knnCollector, - RandomAccessVectorValues vectors, - HnswGraph graph, - HnswGraphSearcher graphSearcher, + HnswGraphSearcher graphSearcher, Bits acceptOrds) throws IOException { int initialEp = graph.entryNode(); if (initialEp == -1) { return; } - int[] epAndVisited = - graphSearcher.findBestEntryPoint(query, vectors, graph, knnCollector.visitLimit()); + int[] epAndVisited = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector.visitLimit()); int numVisited = epAndVisited[1]; int ep = epAndVisited[0]; if (ep == -1) { knnCollector.incVisitedCount(numVisited); return; } - KnnCollector results = new OrdinalTranslatedKnnCollector(knnCollector, vectors::ordToDoc); - results.incVisitedCount(numVisited); - graphSearcher.searchLevel(results, query, 0, new int[] {ep}, vectors, graph, acceptOrds); + knnCollector.incVisitedCount(numVisited); + graphSearcher.searchLevel(knnCollector, scorer, 0, new int[] {ep}, graph, acceptOrds); } /** @@ -280,48 +121,40 @@ public class HnswGraphSearcher { *

If the search stops early because it reaches the visited nodes limit, then the results will * be marked incomplete through {@link NeighborQueue#incomplete()}. * - * @param query search query vector + * @param scorer the scorer to compare the query with the nodes * @param topK the number of nearest to query results to return * @param level level to search * @param eps the entry points for search at this level expressed as level 0th ordinals - * @param vectors vector values * @param graph the graph values * @return a set of collected vectors holding the nearest neighbors found */ public HnswGraphBuilder.GraphBuilderKnnCollector searchLevel( // Note: this is only public because Lucene91HnswGraphBuilder needs it - T query, - int topK, - int level, - final int[] eps, - RandomAccessVectorValues vectors, - HnswGraph graph) + RandomVectorScorer scorer, int topK, int level, final int[] eps, HnswGraph graph) throws IOException { HnswGraphBuilder.GraphBuilderKnnCollector results = new HnswGraphBuilder.GraphBuilderKnnCollector(topK); - searchLevel(results, query, level, eps, vectors, graph, null); + searchLevel(results, scorer, level, eps, graph, null); return results; } /** * Function to find the best entry point from which to search the zeroth graph layer. * - * @param query vector query with which to search - * @param vectors random access vector values + * @param scorer the scorer to compare the query with the nodes * @param graph the HNSWGraph * @param visitLimit How many vectors are allowed to be visited * @return An integer array whose first element is the best entry point, and second is the number * of candidates visited. Entry point of `-1` indicates visitation limit exceed * @throws IOException When accessing the vector fails */ - private int[] findBestEntryPoint( - T query, RandomAccessVectorValues vectors, HnswGraph graph, long visitLimit) + private int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, long visitLimit) throws IOException { int size = graph.size(); int visitedCount = 1; - prepareScratchState(vectors.size()); + prepareScratchState(graph.size()); int currentEp = graph.entryNode(); - float currentScore = compare(query, vectors, currentEp); + float currentScore = scorer.score(currentEp); boolean foundBetter; for (int level = graph.numLevels() - 1; level >= 1; level--) { foundBetter = true; @@ -339,7 +172,7 @@ public class HnswGraphSearcher { if (visitedCount >= visitLimit) { return new int[] {-1, visitedCount}; } - float friendSimilarity = compare(query, vectors, friendOrd); + float friendSimilarity = scorer.score(friendOrd); visitedCount++; if (friendSimilarity > currentScore || (friendSimilarity == currentScore && friendOrd < currentEp)) { @@ -361,23 +194,22 @@ public class HnswGraphSearcher { */ void searchLevel( KnnCollector results, - T query, + RandomVectorScorer scorer, int level, final int[] eps, - RandomAccessVectorValues vectors, HnswGraph graph, Bits acceptOrds) throws IOException { int size = graph.size(); - prepareScratchState(vectors.size()); + prepareScratchState(graph.size()); for (int ep : eps) { if (visited.getAndSet(ep) == false) { if (results.earlyTerminated()) { break; } - float score = compare(query, vectors, ep); + float score = scorer.score(ep); results.incVisitedCount(1); candidates.add(ep, score); if (acceptOrds == null || acceptOrds.get(ep)) { @@ -408,7 +240,7 @@ public class HnswGraphSearcher { if (results.earlyTerminated()) { break; } - float friendSimilarity = compare(query, vectors, friendOrd); + float friendSimilarity = scorer.score(friendOrd); results.incVisitedCount(1); if (friendSimilarity >= minAcceptedSimilarity) { candidates.add(friendOrd, friendSimilarity); @@ -422,14 +254,6 @@ public class HnswGraphSearcher { } } - private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException { - if (vectorEncoding == VectorEncoding.BYTE) { - return similarityFunction.compare((byte[]) query, (byte[]) vectors.vectorValue(ord)); - } else { - return similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(ord)); - } - } - private void prepareScratchState(int capacity) { candidates.clear(); if (visited.length() < capacity) { @@ -468,17 +292,13 @@ public class HnswGraphSearcher { *

Note the class itself is NOT thread safe, but since each search will create a new Searcher, * the search methods using this class are thread safe. */ - private static class OnHeapHnswGraphSearcher extends HnswGraphSearcher { + private static class OnHeapHnswGraphSearcher extends HnswGraphSearcher { private NeighborArray cur; private int upto; - private OnHeapHnswGraphSearcher( - VectorEncoding vectorEncoding, - VectorSimilarityFunction similarityFunction, - NeighborQueue candidates, - BitSet visited) { - super(vectorEncoding, similarityFunction, candidates, visited); + private OnHeapHnswGraphSearcher(NeighborQueue candidates, BitSet visited) { + super(candidates, visited); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java index d3fa753d32f..f6cd54fa53e 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java @@ -86,7 +86,7 @@ public class NeighborArray { * @return indexes of newly sorted (unchecked) nodes, in ascending order, or null if the array is * already fully sorted */ - public int[] sort(ScoringFunction scoringFunction) throws IOException { + public int[] sort(RandomVectorScorer scorer) throws IOException { if (size == sortedNodeSize) { // all nodes checked and sorted return null; @@ -95,8 +95,7 @@ public class NeighborArray { int[] uncheckedIndexes = new int[size - sortedNodeSize]; int count = 0; while (sortedNodeSize != size) { - uncheckedIndexes[count] = - insertSortedInternal(scoringFunction); // sortedNodeSize is increased inside + uncheckedIndexes[count] = insertSortedInternal(scorer); // sortedNodeSize is increased inside for (int i = 0; i < count; i++) { if (uncheckedIndexes[i] >= uncheckedIndexes[count]) { // the previous inserted nodes has been shifted @@ -110,13 +109,13 @@ public class NeighborArray { } /** insert the first unsorted node into its sorted position */ - private int insertSortedInternal(ScoringFunction scoringFunction) throws IOException { + private int insertSortedInternal(RandomVectorScorer scorer) throws IOException { assert sortedNodeSize < size : "Call this method only when there's unsorted node"; int tmpNode = node[sortedNodeSize]; float tmpScore = score[sortedNodeSize]; if (Float.isNaN(tmpScore)) { - tmpScore = scoringFunction.computeScore(tmpNode); + tmpScore = scorer.score(tmpNode); } int insertionPoint = @@ -204,20 +203,4 @@ public class NeighborArray { } return start; } - - /** - * ScoringFunction is a lambda function created in HnswGraphBuilder to allow for lazy computation - * of distance score. - */ - interface ScoringFunction { - /** - * Computes the distance score between the given node ID and the root node of this - * NeighborArray. - * - * @param nodeId The ID of the node for which to compute the distance score. - * @return The distance score as a float value. - * @throws IOException If an I/O error occurs during computation. - */ - float computeScore(int nodeId) throws IOException; - } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java index e529b22feaf..ed1a5ffb59f 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java @@ -24,12 +24,12 @@ import org.apache.lucene.search.TotalHits; /** * Wraps a provided KnnCollector object, translating the provided vectorId ordinal to a documentId */ -final class OrdinalTranslatedKnnCollector implements KnnCollector { +public final class OrdinalTranslatedKnnCollector implements KnnCollector { private final KnnCollector in; private final IntToIntFunction vectorOrdinalToDocId; - OrdinalTranslatedKnnCollector(KnnCollector in, IntToIntFunction vectorOrdinalToDocId) { + public OrdinalTranslatedKnnCollector(KnnCollector in, IntToIntFunction vectorOrdinalToDocId) { this.in = in; this.vectorOrdinalToDocId = vectorOrdinalToDocId; } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java new file mode 100644 index 00000000000..36b7e331dc9 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import java.io.IOException; +import org.apache.lucene.index.VectorSimilarityFunction; + +/** A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. */ +public interface RandomVectorScorer { + /** + * Returns the score between the query and the provided node. + * + * @param node a random node in the graph + * @return the computed score + */ + float score(int node) throws IOException; + + /** + * Creates a default scorer for float vectors. + * + *

WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid + * using it after calling this function. If you plan to use it again outside the returned {@link + * RandomVectorScorer}, think about passing a copied version ({@link + * RandomAccessVectorValues#copy}). + * + * @param vectors the underlying storage for vectors + * @param similarityFunction the similarity function to score vectors + * @param query the actual query + */ + static RandomVectorScorer createFloats( + final RandomAccessVectorValues vectors, + final VectorSimilarityFunction similarityFunction, + final float[] query) { + if (query.length != vectors.dimension()) { + throw new IllegalArgumentException( + "vector query dimension: " + + query.length + + " differs from field dimension: " + + vectors.dimension()); + } + return node -> similarityFunction.compare(query, vectors.vectorValue(node)); + } + + /** + * Creates a default scorer for byte vectors. + * + *

WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid + * using it after calling this function. If you plan to use it again outside the returned {@link + * RandomVectorScorer}, think about passing a copied version ({@link + * RandomAccessVectorValues#copy}). + * + * @param vectors the underlying storage for vectors + * @param similarityFunction the similarity function to use to score vectors + * @param query the actual query + */ + static RandomVectorScorer createBytes( + final RandomAccessVectorValues vectors, + final VectorSimilarityFunction similarityFunction, + final byte[] query) { + if (query.length != vectors.dimension()) { + throw new IllegalArgumentException( + "vector query dimension: " + + query.length + + " differs from field dimension: " + + vectors.dimension()); + } + return node -> similarityFunction.compare(query, vectors.vectorValue(node)); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java new file mode 100644 index 00000000000..a922e2fa663 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import java.io.IOException; +import org.apache.lucene.index.VectorSimilarityFunction; + +/** A supplier that creates {@link RandomVectorScorer} from an ordinal. */ +public interface RandomVectorScorerSupplier { + /** + * This creates a {@link RandomVectorScorer} for scoring random nodes in batches against the given + * ordinal. + * + * @param ord the ordinal of the node to compare + * @return a new {@link RandomVectorScorer} + */ + RandomVectorScorer scorer(int ord) throws IOException; + + /** + * Creates a {@link RandomVectorScorerSupplier} to compare float vectors. + * + *

WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid + * using it after calling this function. If you plan to use it again outside the returned {@link + * RandomVectorScorer}, think about passing a copied version ({@link + * RandomAccessVectorValues#copy}). + * + * @param vectors the underlying storage for vectors + * @param similarityFunction the similarity function to score vectors + */ + static RandomVectorScorerSupplier createFloats( + final RandomAccessVectorValues vectors, + final VectorSimilarityFunction similarityFunction) + throws IOException { + // We copy the provided random accessor just once during the supplier's initialization + // and then reuse it consistently across all scorers for conducting vector comparisons. + final RandomAccessVectorValues vectorsCopy = vectors.copy(); + return queryOrd -> + (RandomVectorScorer) + cand -> + similarityFunction.compare( + vectors.vectorValue(queryOrd), vectorsCopy.vectorValue(cand)); + } + + /** + * Creates a {@link RandomVectorScorerSupplier} to compare byte vectors. + * + *

WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid + * using it after calling this function. If you plan to use it again outside the returned {@link + * RandomVectorScorer}, think about passing a copied version ({@link + * RandomAccessVectorValues#copy}). + * + * @param vectors the underlying storage for vectors + * @param similarityFunction the similarity function to score vectors + */ + static RandomVectorScorerSupplier createBytes( + final RandomAccessVectorValues vectors, + final VectorSimilarityFunction similarityFunction) + throws IOException { + // We copy the provided random accessor just once during the supplier's initialization + // and then reuse it consistently across all scorers for conducting vector comparisons. + final RandomAccessVectorValues vectorsCopy = vectors.copy(); + return queryOrd -> + (RandomVectorScorer) + cand -> + similarityFunction.compare( + vectors.vectorValue(queryOrd), vectorsCopy.vectorValue(cand)); + } +} diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 82f61342428..b758b441c50 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -109,6 +109,29 @@ abstract class HnswGraphTestCase extends LuceneTestCase { abstract T getTargetVector(); + @SuppressWarnings("unchecked") + protected RandomVectorScorerSupplier buildScorerSupplier(RandomAccessVectorValues vectors) + throws IOException { + return switch (getVectorEncoding()) { + case BYTE -> RandomVectorScorerSupplier.createBytes( + (RandomAccessVectorValues) vectors, similarityFunction); + case FLOAT32 -> RandomVectorScorerSupplier.createFloats( + (RandomAccessVectorValues) vectors, similarityFunction); + }; + } + + @SuppressWarnings("unchecked") + protected RandomVectorScorer buildScorer(RandomAccessVectorValues vectors, T query) + throws IOException { + RandomAccessVectorValues vectorsCopy = vectors.copy(); + return switch (getVectorEncoding()) { + case BYTE -> RandomVectorScorer.createBytes( + (RandomAccessVectorValues) vectorsCopy, similarityFunction, (byte[]) query); + case FLOAT32 -> RandomVectorScorer.createFloats( + (RandomAccessVectorValues) vectorsCopy, similarityFunction, (float[]) query); + }; + } + // test writing out and reading in a graph gives the expected graph public void testReadWrite() throws IOException { int dim = random().nextInt(100) + 1; @@ -118,10 +141,9 @@ abstract class HnswGraphTestCase extends LuceneTestCase { long seed = random().nextLong(); AbstractMockVectorValues vectors = vectorValues(nDoc, dim); AbstractMockVectorValues v2 = vectors.copy(), v3 = vectors.copy(); - HnswGraphBuilder builder = - HnswGraphBuilder.create( - vectors, getVectorEncoding(), similarityFunction, M, beamWidth, seed); - HnswGraph hnsw = builder.build(vectors.copy()); + RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); + HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, seed); + HnswGraph hnsw = builder.build(vectors.size()); // Recreate the graph while indexing with the same random seed and write it out HnswGraphBuilder.randSeed = seed; @@ -349,33 +371,14 @@ abstract class HnswGraphTestCase extends LuceneTestCase { int nDoc = 100; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; RandomAccessVectorValues vectors = circularVectorValues(nDoc); - HnswGraphBuilder builder = - HnswGraphBuilder.create( - vectors, getVectorEncoding(), similarityFunction, 10, 100, random().nextInt()); - OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); + HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 100, random().nextInt()); + OnHeapHnswGraph hnsw = builder.build(vectors.size()); + // run some searches KnnCollector nn = - switch (getVectorEncoding()) { - case BYTE -> HnswGraphSearcher.search( - (byte[]) getTargetVector(), - 10, - (RandomAccessVectorValues) vectors.copy(), - getVectorEncoding(), - similarityFunction, - hnsw, - null, - Integer.MAX_VALUE); - case FLOAT32 -> HnswGraphSearcher.search( - (float[]) getTargetVector(), - 10, - (RandomAccessVectorValues) vectors.copy(), - getVectorEncoding(), - similarityFunction, - hnsw, - null, - Integer.MAX_VALUE); - }; - + HnswGraphSearcher.search( + buildScorer(vectors, getTargetVector()), 10, hnsw, null, Integer.MAX_VALUE); TopDocs topDocs = nn.topDocs(); assertEquals("Number of found results is not equal to [10].", 10, topDocs.scoreDocs.length); int sum = 0; @@ -401,33 +404,14 @@ abstract class HnswGraphTestCase extends LuceneTestCase { int nDoc = 100; RandomAccessVectorValues vectors = circularVectorValues(nDoc); similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; - HnswGraphBuilder builder = - HnswGraphBuilder.create( - vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt()); - OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); + HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); + OnHeapHnswGraph hnsw = builder.build(vectors.size()); // the first 10 docs must not be deleted to ensure the expected recall Bits acceptOrds = createRandomAcceptOrds(10, nDoc); KnnCollector nn = - switch (getVectorEncoding()) { - case BYTE -> HnswGraphSearcher.search( - (byte[]) getTargetVector(), - 10, - (RandomAccessVectorValues) vectors.copy(), - getVectorEncoding(), - similarityFunction, - hnsw, - acceptOrds, - Integer.MAX_VALUE); - case FLOAT32 -> HnswGraphSearcher.search( - (float[]) getTargetVector(), - 10, - (RandomAccessVectorValues) vectors.copy(), - getVectorEncoding(), - similarityFunction, - hnsw, - acceptOrds, - Integer.MAX_VALUE); - }; + HnswGraphSearcher.search( + buildScorer(vectors, getTargetVector()), 10, hnsw, acceptOrds, Integer.MAX_VALUE); TopDocs nodes = nn.topDocs(); assertEquals("Number of found results is not equal to [10].", 10, nodes.scoreDocs.length); int sum = 0; @@ -445,10 +429,9 @@ abstract class HnswGraphTestCase extends LuceneTestCase { int nDoc = 100; RandomAccessVectorValues vectors = circularVectorValues(nDoc); similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; - HnswGraphBuilder builder = - HnswGraphBuilder.create( - vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt()); - OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); + HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); + OnHeapHnswGraph hnsw = builder.build(vectors.size()); // Only mark a few vectors as accepted BitSet acceptOrds = new FixedBitSet(nDoc); for (int i = 0; i < nDoc; i += random().nextInt(15, 20)) { @@ -458,27 +441,12 @@ abstract class HnswGraphTestCase extends LuceneTestCase { // Check the search finds all accepted vectors int numAccepted = acceptOrds.cardinality(); KnnCollector nn = - switch (getVectorEncoding()) { - case FLOAT32 -> HnswGraphSearcher.search( - (float[]) getTargetVector(), - numAccepted, - (RandomAccessVectorValues) vectors.copy(), - getVectorEncoding(), - similarityFunction, - hnsw, - acceptOrds, - Integer.MAX_VALUE); - case BYTE -> HnswGraphSearcher.search( - (byte[]) getTargetVector(), - numAccepted, - (RandomAccessVectorValues) vectors.copy(), - getVectorEncoding(), - similarityFunction, - hnsw, - acceptOrds, - Integer.MAX_VALUE); - }; - + HnswGraphSearcher.search( + buildScorer(vectors, getTargetVector()), + numAccepted, + hnsw, + acceptOrds, + Integer.MAX_VALUE); TopDocs nodes = nn.topDocs(); assertEquals(numAccepted, nodes.scoreDocs.length); for (ScoreDoc node : nodes.scoreDocs) { @@ -565,32 +533,26 @@ abstract class HnswGraphTestCase extends LuceneTestCase { long seed = random().nextLong(); AbstractMockVectorValues initializerVectors = vectorValues(initializerSize, dim); - HnswGraphBuilder initializerBuilder = - HnswGraphBuilder.create( - initializerVectors, getVectorEncoding(), similarityFunction, 10, 30, seed); + RandomVectorScorerSupplier initialscorerSupplier = buildScorerSupplier(initializerVectors); + HnswGraphBuilder initializerBuilder = + HnswGraphBuilder.create(initialscorerSupplier, 10, 30, seed); - OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.copy()); + OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size()); AbstractMockVectorValues finalVectorValues = vectorValues(totalSize, dim, initializerVectors, docIdOffset); Map initializerOrdMap = createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset); - HnswGraphBuilder finalBuilder = + RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues); + HnswGraphBuilder finalBuilder = HnswGraphBuilder.create( - finalVectorValues, - getVectorEncoding(), - similarityFunction, - 10, - 30, - seed, - initializerGraph, - initializerOrdMap); + finalscorerSupplier, 10, 30, seed, initializerGraph, initializerOrdMap); // When offset is 0, the graphs should be identical before vectors are added assertGraphEqual(initializerGraph, finalBuilder.getGraph()); - OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.copy()); + OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.size()); assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap); } @@ -602,31 +564,26 @@ abstract class HnswGraphTestCase extends LuceneTestCase { long seed = random().nextLong(); AbstractMockVectorValues initializerVectors = vectorValues(initializerSize, dim); - HnswGraphBuilder initializerBuilder = - HnswGraphBuilder.create( - initializerVectors.copy(), getVectorEncoding(), similarityFunction, 10, 30, seed); - OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.copy()); + RandomVectorScorerSupplier initialscorerSupplier = buildScorerSupplier(initializerVectors); + HnswGraphBuilder initializerBuilder = + HnswGraphBuilder.create(initialscorerSupplier, 10, 30, seed); + + OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size()); AbstractMockVectorValues finalVectorValues = vectorValues(totalSize, dim, initializerVectors.copy(), docIdOffset); Map initializerOrdMap = createOffsetOrdinalMap(initializerSize, finalVectorValues.copy(), docIdOffset); - HnswGraphBuilder finalBuilder = + RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues); + HnswGraphBuilder finalBuilder = HnswGraphBuilder.create( - finalVectorValues, - getVectorEncoding(), - similarityFunction, - 10, - 30, - seed, - initializerGraph, - initializerOrdMap); + finalscorerSupplier, 10, 30, seed, initializerGraph, initializerOrdMap); assertGraphInitializedFromGraph(finalBuilder.getGraph(), initializerGraph, initializerOrdMap); // Confirm that the graph is appropriately constructed by checking that the nodes in the old // graph are present in the levels of the new graph - OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.copy()); + OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.size()); assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap); } @@ -718,65 +675,32 @@ abstract class HnswGraphTestCase extends LuceneTestCase { int nDoc = 500; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; RandomAccessVectorValues vectors = circularVectorValues(nDoc); - HnswGraphBuilder builder = - HnswGraphBuilder.create( - vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt()); - OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); + HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); + OnHeapHnswGraph hnsw = builder.build(vectors.size()); int topK = 50; int visitedLimit = topK + random().nextInt(5); KnnCollector nn = - switch (getVectorEncoding()) { - case FLOAT32 -> HnswGraphSearcher.search( - (float[]) getTargetVector(), - topK, - (RandomAccessVectorValues) vectors.copy(), - getVectorEncoding(), - similarityFunction, - hnsw, - createRandomAcceptOrds(0, nDoc), - visitedLimit); - case BYTE -> HnswGraphSearcher.search( - (byte[]) getTargetVector(), - topK, - (RandomAccessVectorValues) vectors.copy(), - getVectorEncoding(), - similarityFunction, - hnsw, - createRandomAcceptOrds(0, nDoc), - visitedLimit); - }; - + HnswGraphSearcher.search( + buildScorer(vectors, getTargetVector()), + topK, + hnsw, + createRandomAcceptOrds(0, nDoc), + visitedLimit); assertTrue(nn.earlyTerminated()); // The visited count shouldn't exceed the limit assertTrue(nn.visitedCount() <= visitedLimit); } - public void testHnswGraphBuilderInvalid() { - expectThrows( - NullPointerException.class, () -> HnswGraphBuilder.create(null, null, null, 0, 0, 0)); + public void testHnswGraphBuilderInvalid() throws IOException { + RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectorValues(1, 1)); // M must be > 0 expectThrows( - IllegalArgumentException.class, - () -> - HnswGraphBuilder.create( - vectorValues(1, 1), - getVectorEncoding(), - VectorSimilarityFunction.EUCLIDEAN, - 0, - 10, - 0)); + IllegalArgumentException.class, () -> HnswGraphBuilder.create(scorerSupplier, 0, 10, 0)); // beamWidth must be > 0 expectThrows( - IllegalArgumentException.class, - () -> - HnswGraphBuilder.create( - vectorValues(1, 1), - getVectorEncoding(), - VectorSimilarityFunction.EUCLIDEAN, - 10, - 0, - 0)); + IllegalArgumentException.class, () -> HnswGraphBuilder.create(scorerSupplier, 10, 0, 0)); } public void testRamUsageEstimate() throws IOException { @@ -784,14 +708,13 @@ abstract class HnswGraphTestCase extends LuceneTestCase { int dim = randomIntBetween(100, 1024); int M = randomIntBetween(4, 96); - VectorSimilarityFunction similarityFunction = - RandomizedTest.randomFrom(VectorSimilarityFunction.values()); + similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values()); RandomAccessVectorValues vectors = vectorValues(size, dim); - HnswGraphBuilder builder = - HnswGraphBuilder.create( - vectors, getVectorEncoding(), similarityFunction, M, M * 2, random().nextLong()); - OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); + HnswGraphBuilder builder = + HnswGraphBuilder.create(scorerSupplier, M, M * 2, random().nextLong()); + OnHeapHnswGraph hnsw = builder.build(vectors.size()); long estimated = RamUsageEstimator.sizeOfObject(hnsw); long actual = ramUsed(hnsw); @@ -813,21 +736,19 @@ abstract class HnswGraphTestCase extends LuceneTestCase { }; AbstractMockVectorValues vectors = vectorValues(values); // First add nodes until everybody gets a full neighbor list - HnswGraphBuilder builder = - HnswGraphBuilder.create( - vectors, getVectorEncoding(), similarityFunction, 2, 10, random().nextInt()); + RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); + HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 2, 10, random().nextInt()); // node 0 is added by the builder constructor - RandomAccessVectorValues vectorsCopy = vectors.copy(); - builder.addGraphNode(0, vectorsCopy); - builder.addGraphNode(1, vectorsCopy); - builder.addGraphNode(2, vectorsCopy); + builder.addGraphNode(0); + builder.addGraphNode(1); + builder.addGraphNode(2); // now every node has tried to attach every other node as a neighbor, but // some were excluded based on diversity check. assertLevel0Neighbors(builder.hnsw, 0, 1, 2); assertLevel0Neighbors(builder.hnsw, 1, 0); assertLevel0Neighbors(builder.hnsw, 2, 0); - builder.addGraphNode(3, vectorsCopy); + builder.addGraphNode(3); assertLevel0Neighbors(builder.hnsw, 0, 1, 2); // we added 3 here assertLevel0Neighbors(builder.hnsw, 1, 0, 3); @@ -835,7 +756,7 @@ abstract class HnswGraphTestCase extends LuceneTestCase { assertLevel0Neighbors(builder.hnsw, 3, 1); // supplant an existing neighbor - builder.addGraphNode(4, vectorsCopy); + builder.addGraphNode(4); // 4 is the same distance from 0 that 2 is; we leave the existing node in place assertLevel0Neighbors(builder.hnsw, 0, 1, 2); assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4); @@ -844,7 +765,7 @@ abstract class HnswGraphTestCase extends LuceneTestCase { assertLevel0Neighbors(builder.hnsw, 3, 1, 4); assertLevel0Neighbors(builder.hnsw, 4, 1, 3); - builder.addGraphNode(5, vectorsCopy); + builder.addGraphNode(5); assertLevel0Neighbors(builder.hnsw, 0, 1, 2); assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4, 5); assertLevel0Neighbors(builder.hnsw, 2, 0); @@ -869,20 +790,18 @@ abstract class HnswGraphTestCase extends LuceneTestCase { }; AbstractMockVectorValues vectors = vectorValues(values); // First add nodes until everybody gets a full neighbor list - HnswGraphBuilder builder = - HnswGraphBuilder.create( - vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt()); - RandomAccessVectorValues vectorsCopy = vectors.copy(); - builder.addGraphNode(0, vectorsCopy); - builder.addGraphNode(1, vectorsCopy); - builder.addGraphNode(2, vectorsCopy); + RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); + HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 1, 10, random().nextInt()); + builder.addGraphNode(0); + builder.addGraphNode(1); + builder.addGraphNode(2); assertLevel0Neighbors(builder.hnsw, 0, 1, 2); // 2 is closer to 0 than 1, so it is excluded as non-diverse assertLevel0Neighbors(builder.hnsw, 1, 0); // 1 is closer to 0 than 2, so it is excluded as non-diverse assertLevel0Neighbors(builder.hnsw, 2, 0); - builder.addGraphNode(3, vectorsCopy); + builder.addGraphNode(3); // this is one case we are testing; 2 has been displaced by 3 assertLevel0Neighbors(builder.hnsw, 0, 1, 3); assertLevel0Neighbors(builder.hnsw, 1, 0); @@ -901,20 +820,18 @@ abstract class HnswGraphTestCase extends LuceneTestCase { }; AbstractMockVectorValues vectors = vectorValues(values); // First add nodes until everybody gets a full neighbor list - HnswGraphBuilder builder = - HnswGraphBuilder.create( - vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt()); - RandomAccessVectorValues vectorsCopy = vectors.copy(); - builder.addGraphNode(0, vectorsCopy); - builder.addGraphNode(1, vectorsCopy); - builder.addGraphNode(2, vectorsCopy); + RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); + HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 1, 10, random().nextInt()); + builder.addGraphNode(0); + builder.addGraphNode(1); + builder.addGraphNode(2); assertLevel0Neighbors(builder.hnsw, 0, 1, 2); // 2 is closer to 0 than 1, so it is excluded as non-diverse assertLevel0Neighbors(builder.hnsw, 1, 0); // 1 is closer to 0 than 2, so it is excluded as non-diverse assertLevel0Neighbors(builder.hnsw, 2, 0); - builder.addGraphNode(3, vectorsCopy); + builder.addGraphNode(3); // this is one case we are testing; 1 has been displaced by 3 assertLevel0Neighbors(builder.hnsw, 0, 2, 3); assertLevel0Neighbors(builder.hnsw, 1, 0, 3); @@ -939,10 +856,9 @@ abstract class HnswGraphTestCase extends LuceneTestCase { int dim = atLeast(10); AbstractMockVectorValues vectors = vectorValues(size, dim); int topK = 5; - HnswGraphBuilder builder = - HnswGraphBuilder.create( - vectors, getVectorEncoding(), similarityFunction, 10, 30, random().nextLong()); - OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); + HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 30, random().nextLong()); + OnHeapHnswGraph hnsw = builder.build(vectors.size()); Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size); int totalMatches = 0; @@ -950,27 +866,8 @@ abstract class HnswGraphTestCase extends LuceneTestCase { KnnCollector actual; T query = randomVector(dim); actual = - switch (getVectorEncoding()) { - case BYTE -> HnswGraphSearcher.search( - (byte[]) query, - 100, - (RandomAccessVectorValues) vectors, - getVectorEncoding(), - similarityFunction, - hnsw, - acceptOrds, - Integer.MAX_VALUE); - case FLOAT32 -> HnswGraphSearcher.search( - (float[]) query, - 100, - (RandomAccessVectorValues) vectors, - getVectorEncoding(), - similarityFunction, - hnsw, - acceptOrds, - Integer.MAX_VALUE); - }; - + HnswGraphSearcher.search( + buildScorer(vectors, query), 100, hnsw, acceptOrds, Integer.MAX_VALUE); TopDocs topDocs = actual.topDocs(); NeighborQueue expected = new NeighborQueue(topK, false); for (int j = 0; j < size; j++) { @@ -1007,10 +904,9 @@ abstract class HnswGraphTestCase extends LuceneTestCase { int size = atLeast(100); int dim = atLeast(10); AbstractMockVectorValues vectors = vectorValues(size, dim); - HnswGraphBuilder builder = - HnswGraphBuilder.create( - vectors, getVectorEncoding(), similarityFunction, 10, 30, random().nextLong()); - OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); + HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 30, random().nextLong()); + OnHeapHnswGraph hnsw = builder.build(vectors.size()); Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size); List queries = new ArrayList<>(); @@ -1020,27 +916,8 @@ abstract class HnswGraphTestCase extends LuceneTestCase { T query = randomVector(dim); queries.add(query); expect = - switch (getVectorEncoding()) { - case BYTE -> HnswGraphSearcher.search( - (byte[]) query, - 100, - (RandomAccessVectorValues) vectors, - getVectorEncoding(), - similarityFunction, - hnsw, - acceptOrds, - Integer.MAX_VALUE); - case FLOAT32 -> HnswGraphSearcher.search( - (float[]) query, - 100, - (RandomAccessVectorValues) vectors, - getVectorEncoding(), - similarityFunction, - hnsw, - acceptOrds, - Integer.MAX_VALUE); - }; - + HnswGraphSearcher.search( + buildScorer(vectors, query), 100, hnsw, acceptOrds, Integer.MAX_VALUE); expects.add(expect); } @@ -1054,26 +931,8 @@ abstract class HnswGraphTestCase extends LuceneTestCase { KnnCollector actual; try { actual = - switch (getVectorEncoding()) { - case BYTE -> HnswGraphSearcher.search( - (byte[]) query, - 100, - (RandomAccessVectorValues) vectors, - getVectorEncoding(), - similarityFunction, - hnsw, - acceptOrds, - Integer.MAX_VALUE); - case FLOAT32 -> HnswGraphSearcher.search( - (float[]) query, - 100, - (RandomAccessVectorValues) vectors, - getVectorEncoding(), - similarityFunction, - hnsw, - acceptOrds, - Integer.MAX_VALUE); - }; + HnswGraphSearcher.search( + buildScorer(vectors, query), 100, hnsw, acceptOrds, Integer.MAX_VALUE); } catch (IOException ioe) { throw new RuntimeException(ioe); } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java index 7ecb68295fb..5cab4cf3256 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java @@ -130,10 +130,9 @@ public class TestHnswFloatVectorGraph extends HnswGraphTestCase { int nDoc = 1000; similarityFunction = VectorSimilarityFunction.EUCLIDEAN; RandomAccessVectorValues vectors = circularVectorValues(nDoc); - HnswGraphBuilder builder = - HnswGraphBuilder.create( - vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt()); - OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); + HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); + OnHeapHnswGraph hnsw = builder.build(vectors.size()); // Skip over half of the documents that are closest to the query vector FixedBitSet acceptOrds = new FixedBitSet(nDoc); @@ -142,14 +141,7 @@ public class TestHnswFloatVectorGraph extends HnswGraphTestCase { } KnnCollector nn = HnswGraphSearcher.search( - getTargetVector(), - 10, - vectors.copy(), - getVectorEncoding(), - similarityFunction, - hnsw, - acceptOrds, - Integer.MAX_VALUE); + buildScorer(vectors, getTargetVector()), 10, hnsw, acceptOrds, Integer.MAX_VALUE); TopDocs nodes = nn.topDocs(); assertEquals("Number of found results is not equal to [10].", 10, nodes.scoreDocs.length);