Introduce a random vector scorer in HNSW builder/searcher (#12529)

This PR involves the refactoring of the HNSW builder and searcher, aiming to create an abstraction for the random access and vector comparisons conducted during graph traversal.

The newly added RandomVectorScorer provides a means to directly compare ordinals, eliminating the need to expose the raw vector primitive type.
This scorer takes charge of vector retrieval and comparison during the graph's construction and search processes.

The primary purpose of this abstraction is to enable the implementation of various strategies.
For example, it opens the door to constructing the graph using the original float vectors while performing searches using their quantized int8 vector counterparts.
This commit is contained in:
Jim Ferenczi 2023-09-12 13:57:07 +01:00 committed by GitHub
parent d77195d705
commit c26b0180bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 540 additions and 880 deletions

View File

@ -23,7 +23,6 @@ import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_MAX_CONN;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopDocs;
@ -31,6 +30,8 @@ import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
/**
* The Word2VecSynonymProvider generates the list of sysnonyms of a term.
@ -41,7 +42,6 @@ public class Word2VecSynonymProvider {
private static final VectorSimilarityFunction SIMILARITY_FUNCTION =
VectorSimilarityFunction.DOT_PRODUCT;
private static final VectorEncoding VECTOR_ENCODING = VectorEncoding.FLOAT32;
private final Word2VecModel word2VecModel;
private final OnHeapHnswGraph hnswGraph;
@ -51,17 +51,13 @@ public class Word2VecSynonymProvider {
* @param model containing the set of TermAndVector entries
*/
public Word2VecSynonymProvider(Word2VecModel model) throws IOException {
word2VecModel = model;
HnswGraphBuilder<float[]> builder =
this.word2VecModel = model;
RandomVectorScorerSupplier scorerSupplier =
RandomVectorScorerSupplier.createFloats(word2VecModel, SIMILARITY_FUNCTION);
HnswGraphBuilder builder =
HnswGraphBuilder.create(
word2VecModel,
VECTOR_ENCODING,
SIMILARITY_FUNCTION,
DEFAULT_MAX_CONN,
DEFAULT_BEAM_WIDTH,
HnswGraphBuilder.randSeed);
this.hnswGraph = builder.build(word2VecModel.copy());
scorerSupplier, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, HnswGraphBuilder.randSeed);
this.hnswGraph = builder.build(word2VecModel.size());
}
public List<TermAndBoost> getSynonyms(
@ -74,15 +70,14 @@ public class Word2VecSynonymProvider {
LinkedList<TermAndBoost> result = new LinkedList<>();
float[] query = word2VecModel.vectorValue(term);
if (query != null) {
RandomVectorScorer scorer =
RandomVectorScorer.createFloats(word2VecModel, SIMILARITY_FUNCTION, query);
KnnCollector synonyms =
HnswGraphSearcher.search(
query,
scorer,
// The query vector is in the model. When looking for the top-k
// it's always the nearest neighbour of itself so, we look for the top-k+1
maxSynonymsPerTerm + 1,
word2VecModel,
VECTOR_ENCODING,
SIMILARITY_FUNCTION,
hnswGraph,
null,
Integer.MAX_VALUE);

View File

@ -354,6 +354,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
final IndexInput dataIn;
final int byteSize;
int lastOrd = -1;
final float[] value;
int ord = -1;
@ -380,9 +381,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
@Override
public float[] vectorValue() throws IOException {
dataIn.seek((long) ord * byteSize);
dataIn.readFloats(value, 0, value.length);
return value;
return vectorValue(ord);
}
@Override
@ -423,8 +422,12 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
@Override
public float[] vectorValue(int targetOrd) throws IOException {
if (lastOrd == targetOrd) {
return value;
}
dataIn.seek((long) targetOrd * byteSize);
dataIn.readFloats(value, 0, value.length);
lastOrd = targetOrd;
return value;
}
}

View File

@ -33,7 +33,6 @@ import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
@ -44,7 +43,9 @@ import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
/**
* Reads vectors from the index segments along with index data structures supporting KNN search.
@ -235,13 +236,11 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
}
OffHeapFloatVectorValues vectorValues = getOffHeapVectorValues(fieldEntry);
RandomVectorScorer scorer =
RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target);
HnswGraphSearcher.search(
target,
knnCollector,
vectorValues,
VectorEncoding.FLOAT32,
fieldEntry.similarityFunction,
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
getAcceptOrds(acceptDocs, fieldEntry));
}

View File

@ -32,7 +32,6 @@ import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
@ -43,6 +42,8 @@ import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/**
@ -231,13 +232,11 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
}
OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData);
RandomVectorScorer scorer =
RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target);
HnswGraphSearcher.search(
target,
knnCollector,
vectorValues,
VectorEncoding.FLOAT32,
fieldEntry.similarityFunction,
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
}

View File

@ -34,6 +34,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
protected final int size;
protected final IndexInput slice;
protected final int byteSize;
protected int lastOrd = -1;
protected final float[] value;
OffHeapFloatVectorValues(int dimension, int size, IndexInput slice) {
@ -56,8 +57,12 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public float[] vectorValue(int targetOrd) throws IOException {
if (lastOrd == targetOrd) {
return value;
}
slice.seek((long) targetOrd * byteSize);
slice.readFloats(value, 0, value.length);
lastOrd = targetOrd;
return value;
}
@ -87,9 +92,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public float[] vectorValue() throws IOException {
slice.seek((long) doc * byteSize);
slice.readFloats(value, 0, value.length);
return value;
return vectorValue(doc);
}
@Override
@ -151,9 +154,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public float[] vectorValue() throws IOException {
slice.seek((long) (disi.index()) * byteSize);
slice.readFloats(value, 0, value.length);
return value;
return vectorValue(disi.index());
}
@Override

View File

@ -43,6 +43,8 @@ import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/**
@ -267,13 +269,11 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
}
OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData);
RandomVectorScorer scorer =
RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target);
HnswGraphSearcher.search(
target,
knnCollector,
vectorValues,
fieldEntry.vectorEncoding,
fieldEntry.similarityFunction,
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
}
@ -288,13 +288,11 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
}
OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData);
RandomVectorScorer scorer =
RandomVectorScorer.createBytes(vectorValues, fieldEntry.similarityFunction, target);
HnswGraphSearcher.search(
target,
knnCollector,
vectorValues,
fieldEntry.vectorEncoding,
fieldEntry.similarityFunction,
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
}

View File

@ -35,6 +35,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
protected final int dimension;
protected final int size;
protected final IndexInput slice;
protected int lastOrd = -1;
protected final byte[] binaryValue;
protected final ByteBuffer byteBuffer;
protected final int byteSize;
@ -60,7 +61,10 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
@Override
public byte[] vectorValue(int targetOrd) throws IOException {
readValue(targetOrd);
if (lastOrd != targetOrd) {
readValue(targetOrd);
lastOrd = targetOrd;
}
return binaryValue;
}
@ -97,9 +101,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
@Override
public byte[] vectorValue() throws IOException {
slice.seek((long) doc * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
return binaryValue;
return vectorValue(doc);
}
@Override
@ -164,9 +166,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
@Override
public byte[] vectorValue() throws IOException {
slice.seek((long) (disi.index()) * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false);
return binaryValue;
return vectorValue(disi.index());
}
@Override

View File

@ -34,6 +34,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
protected final int size;
protected final IndexInput slice;
protected final int byteSize;
protected int lastOrd = -1;
protected final float[] value;
OffHeapFloatVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
@ -56,8 +57,12 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public float[] vectorValue(int targetOrd) throws IOException {
if (lastOrd == targetOrd) {
return value;
}
slice.seek((long) targetOrd * byteSize);
slice.readFloats(value, 0, value.length);
lastOrd = targetOrd;
return value;
}
@ -93,9 +98,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public float[] vectorValue() throws IOException {
slice.seek((long) doc * byteSize);
slice.readFloats(value, 0, value.length);
return value;
return vectorValue(doc);
}
@Override
@ -160,9 +163,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public float[] vectorValue() throws IOException {
slice.seek((long) (disi.index()) * byteSize);
slice.readFloats(value, 0, value.length);
return value;
return vectorValue(disi.index());
}
@Override

View File

@ -24,7 +24,6 @@ import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
@ -33,6 +32,7 @@ import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
/**
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
@ -57,7 +57,7 @@ public final class Lucene91HnswGraphBuilder {
private final RandomAccessVectorValues<float[]> vectorValues;
private final SplittableRandom random;
private final Lucene91BoundsChecker bound;
private final HnswGraphSearcher<float[]> graphSearcher;
private final HnswGraphSearcher graphSearcher;
final Lucene91OnHeapHnswGraph hnsw;
@ -103,11 +103,8 @@ public final class Lucene91HnswGraphBuilder {
int levelOfFirstNode = getRandomGraphLevel(ml, random);
this.hnsw = new Lucene91OnHeapHnswGraph(maxConn, levelOfFirstNode);
this.graphSearcher =
new HnswGraphSearcher<>(
VectorEncoding.FLOAT32,
similarityFunction,
new NeighborQueue(beamWidth, true),
new FixedBitSet(vectorValues.size()));
new HnswGraphSearcher(
new NeighborQueue(beamWidth, true), new FixedBitSet(vectorValues.size()));
bound = Lucene91BoundsChecker.create(false);
scratch = new Lucene91NeighborArray(Math.max(beamWidth, maxConn + 1));
}
@ -147,6 +144,8 @@ public final class Lucene91HnswGraphBuilder {
/** Inserts a doc with vector value to the graph */
void addGraphNode(int node, float[] value) throws IOException {
RandomVectorScorer scorer =
RandomVectorScorer.createFloats(vectorValues, similarityFunction, value);
HnswGraphBuilder.GraphBuilderKnnCollector candidates;
final int nodeLevel = getRandomGraphLevel(ml, random);
int curMaxLevel = hnsw.numLevels() - 1;
@ -159,12 +158,12 @@ public final class Lucene91HnswGraphBuilder {
// for levels > nodeLevel search with topk = 1
for (int level = curMaxLevel; level > nodeLevel; level--) {
candidates = graphSearcher.searchLevel(value, 1, level, eps, vectorValues, hnsw);
candidates = graphSearcher.searchLevel(scorer, 1, level, eps, hnsw);
eps = new int[] {candidates.popNode()};
}
// for levels <= nodeLevel search with topk = beamWidth, and add connections
for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) {
candidates = graphSearcher.searchLevel(value, beamWidth, level, eps, vectorValues, hnsw);
candidates = graphSearcher.searchLevel(scorer, beamWidth, level, eps, hnsw);
eps = candidates.popUntilNearestKNodes();
hnsw.addNode(level, node);
addDiverseNeighbors(level, node, candidates);

View File

@ -34,7 +34,6 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexInput;
@ -44,6 +43,7 @@ import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.packed.DirectMonotonicWriter;
/**
@ -277,16 +277,13 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
throws IOException {
// build graph
HnswGraphBuilder<float[]> hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
VectorEncoding.FLOAT32,
similarityFunction,
M,
beamWidth,
HnswGraphBuilder.randSeed);
RandomVectorScorerSupplier scorerSupplier =
RandomVectorScorerSupplier.createFloats(vectorValues, similarityFunction);
HnswGraphBuilder hnswGraphBuilder =
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.copy());
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.size());
// write vectors' neighbours on each level into the vectorIndex file
int countOnLevel0 = graph.size();

View File

@ -53,6 +53,7 @@ import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.packed.DirectMonotonicWriter;
/**
@ -420,16 +421,14 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
docsWithField.cardinality(),
vectorDataInput,
byteSize);
HnswGraphBuilder<byte[]> hnswGraphBuilder =
RandomVectorScorerSupplier scorerSupplier =
RandomVectorScorerSupplier.createBytes(
vectorValues, fieldInfo.getVectorSimilarityFunction());
HnswGraphBuilder hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.copy());
yield hnswGraphBuilder.build(vectorValues.size());
}
case FLOAT32 -> {
OffHeapFloatVectorValues.DenseOffHeapVectorValues vectorValues =
@ -438,16 +437,13 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
docsWithField.cardinality(),
vectorDataInput,
byteSize);
HnswGraphBuilder<float[]> hnswGraphBuilder =
RandomVectorScorerSupplier scorerSupplier =
RandomVectorScorerSupplier.createFloats(
vectorValues, fieldInfo.getVectorSimilarityFunction());
HnswGraphBuilder hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.copy());
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
yield hnswGraphBuilder.build(vectorValues.size());
}
};
writeGraph(graph);
@ -630,7 +626,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
private final int dim;
private final DocsWithFieldSet docsWithField;
private final List<T> vectors;
private final HnswGraphBuilder<T> hnswGraphBuilder;
private final HnswGraphBuilder hnswGraphBuilder;
private int lastDocID = -1;
private int node = 0;
@ -654,21 +650,25 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
};
}
@SuppressWarnings("unchecked")
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
throws IOException {
this.fieldInfo = fieldInfo;
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
RAVectorValues<T> raVectorValues = new RAVectorValues<>(vectors, dim);
RandomAccessVectorValues<T> raVectors = new RAVectorValues<>(vectors, dim);
RandomVectorScorerSupplier scorerSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> RandomVectorScorerSupplier.createBytes(
(RandomAccessVectorValues<byte[]>) raVectors,
fieldInfo.getVectorSimilarityFunction());
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
(RandomAccessVectorValues<float[]>) raVectors,
fieldInfo.getVectorSimilarityFunction());
};
hnswGraphBuilder =
HnswGraphBuilder.create(
raVectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(infoStream);
}
@ -685,7 +685,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
assert docID > lastDocID;
docsWithField.add(docID);
vectors.add(copyValue(vectorValue));
hnswGraphBuilder.addGraphNode(node, vectorValue);
hnswGraphBuilder.addGraphNode(node);
node++;
lastDocID = docID;
}

View File

@ -45,6 +45,8 @@ import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/**
@ -274,12 +276,11 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
}
OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData);
RandomVectorScorer scorer =
RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target);
HnswGraphSearcher.search(
target,
knnCollector,
vectorValues,
fieldEntry.vectorEncoding,
fieldEntry.similarityFunction,
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
}
@ -296,12 +297,11 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
}
OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData);
RandomVectorScorer scorer =
RandomVectorScorer.createBytes(vectorValues, fieldEntry.similarityFunction, target);
HnswGraphSearcher.search(
target,
knnCollector,
vectorValues,
fieldEntry.vectorEncoding,
fieldEntry.similarityFunction,
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
}

View File

@ -41,12 +41,8 @@ import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.*;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.*;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicWriter;
/**
@ -438,10 +434,13 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
docsWithField.cardinality(),
vectorDataInput,
byteSize);
HnswGraphBuilder<byte[]> hnswGraphBuilder =
createHnswGraphBuilder(mergeState, fieldInfo, vectorValues, initializerIndex);
RandomVectorScorerSupplier scorerSupplier =
RandomVectorScorerSupplier.createBytes(
vectorValues, fieldInfo.getVectorSimilarityFunction());
HnswGraphBuilder hnswGraphBuilder =
createHnswGraphBuilder(mergeState, fieldInfo, scorerSupplier, initializerIndex);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.copy());
yield hnswGraphBuilder.build(vectorValues.size());
}
case FLOAT32 -> {
OffHeapFloatVectorValues.DenseOffHeapVectorValues vectorValues =
@ -450,10 +449,13 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
docsWithField.cardinality(),
vectorDataInput,
byteSize);
HnswGraphBuilder<float[]> hnswGraphBuilder =
createHnswGraphBuilder(mergeState, fieldInfo, vectorValues, initializerIndex);
RandomVectorScorerSupplier scorerSupplier =
RandomVectorScorerSupplier.createFloats(
vectorValues, fieldInfo.getVectorSimilarityFunction());
HnswGraphBuilder hnswGraphBuilder =
createHnswGraphBuilder(mergeState, fieldInfo, scorerSupplier, initializerIndex);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.copy());
yield hnswGraphBuilder.build(vectorValues.size());
}
};
vectorIndexNodeOffsets = writeGraph(graph);
@ -482,20 +484,14 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
}
}
private <T> HnswGraphBuilder<T> createHnswGraphBuilder(
private HnswGraphBuilder createHnswGraphBuilder(
MergeState mergeState,
FieldInfo fieldInfo,
RandomAccessVectorValues<T> floatVectorValues,
RandomVectorScorerSupplier scorerSupplier,
int initializerIndex)
throws IOException {
if (initializerIndex == -1) {
return HnswGraphBuilder.create(
floatVectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
return HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
}
HnswGraph initializerGraph =
@ -503,14 +499,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
Map<Integer, Integer> ordinalMapper =
getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex);
return HnswGraphBuilder.create(
floatVectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed,
initializerGraph,
ordinalMapper);
scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, initializerGraph, ordinalMapper);
}
private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo)
@ -868,7 +857,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
private final int dim;
private final DocsWithFieldSet docsWithField;
private final List<T> vectors;
private final HnswGraphBuilder<T> hnswGraphBuilder;
private final HnswGraphBuilder hnswGraphBuilder;
private int lastDocID = -1;
private int node = 0;
@ -892,20 +881,25 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
};
}
@SuppressWarnings("unchecked")
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
throws IOException {
this.fieldInfo = fieldInfo;
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
RAVectorValues<T> raVectors = new RAVectorValues<>(vectors, dim);
RandomVectorScorerSupplier scorerSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> RandomVectorScorerSupplier.createBytes(
(RandomAccessVectorValues<byte[]>) raVectors,
fieldInfo.getVectorSimilarityFunction());
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
(RandomAccessVectorValues<float[]>) raVectors,
fieldInfo.getVectorSimilarityFunction());
};
hnswGraphBuilder =
HnswGraphBuilder.create(
new RAVectorValues<>(vectors, dim),
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(infoStream);
}
@ -920,7 +914,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
assert docID > lastDocID;
docsWithField.add(docID);
vectors.add(copyValue(vectorValue));
hnswGraphBuilder.addGraphNode(node, vectorValue);
hnswGraphBuilder.addGraphNode(node);
node++;
lastDocID = docID;
}

View File

@ -35,6 +35,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
protected final int dimension;
protected final int size;
protected final IndexInput slice;
protected int lastOrd = -1;
protected final byte[] binaryValue;
protected final ByteBuffer byteBuffer;
protected final int byteSize;
@ -60,7 +61,10 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
@Override
public byte[] vectorValue(int targetOrd) throws IOException {
readValue(targetOrd);
if (lastOrd != targetOrd) {
readValue(targetOrd);
lastOrd = targetOrd;
}
return binaryValue;
}
@ -97,9 +101,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
@Override
public byte[] vectorValue() throws IOException {
slice.seek((long) doc * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
return binaryValue;
return vectorValue(doc);
}
@Override
@ -164,9 +166,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
@Override
public byte[] vectorValue() throws IOException {
slice.seek((long) (disi.index()) * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false);
return binaryValue;
return vectorValue(disi.index());
}
@Override

View File

@ -35,6 +35,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
protected final int size;
protected final IndexInput slice;
protected final int byteSize;
protected int lastOrd = -1;
protected final float[] value;
OffHeapFloatVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
@ -57,8 +58,12 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public float[] vectorValue(int targetOrd) throws IOException {
if (lastOrd == targetOrd) {
return value;
}
slice.seek((long) targetOrd * byteSize);
slice.readFloats(value, 0, value.length);
lastOrd = targetOrd;
return value;
}
@ -91,9 +96,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public float[] vectorValue() throws IOException {
slice.seek((long) doc * byteSize);
slice.readFloats(value, 0, value.length);
return value;
return vectorValue(doc);
}
@Override
@ -158,9 +161,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
@Override
public float[] vectorValue() throws IOException {
slice.seek((long) (disi.index()) * byteSize);
slice.readFloats(value, 0, value.length);
return value;
return vectorValue(disi.index());
}
@Override

View File

@ -28,8 +28,6 @@ import java.util.Objects;
import java.util.Set;
import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.FixedBitSet;
@ -37,11 +35,9 @@ import org.apache.lucene.util.InfoStream;
/**
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
* hyperparameters.
*
* @param <T> the type of vector
* hyper-parameters.
*/
public final class HnswGraphBuilder<T> {
public final class HnswGraphBuilder {
/** Default number of maximum connections per node */
public static final int DEFAULT_MAX_CONN = 16;
@ -64,11 +60,9 @@ public final class HnswGraphBuilder<T> {
private final double ml;
private final NeighborArray scratch;
private final VectorSimilarityFunction similarityFunction;
private final VectorEncoding vectorEncoding;
private final RandomAccessVectorValues<T> vectors;
private final SplittableRandom random;
private final HnswGraphSearcher<T> graphSearcher;
private final RandomVectorScorerSupplier scorerSupplier;
private final HnswGraphSearcher graphSearcher;
private final GraphBuilderKnnCollector entryCandidates; // for upper levels of graph search
private final GraphBuilderKnnCollector
beamCandidates; // for levels of graph where we add the node
@ -77,34 +71,23 @@ public final class HnswGraphBuilder<T> {
private InfoStream infoStream = InfoStream.getDefault();
// we need two sources of vectors in order to perform diversity check comparisons without
// colliding
private final RandomAccessVectorValues<T> vectorsCopy;
private final Set<Integer> initializedNodes;
public static <T> HnswGraphBuilder<T> create(
RandomAccessVectorValues<T> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
int M,
int beamWidth,
long seed)
public static HnswGraphBuilder create(
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
throws IOException {
return new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed);
}
public static <T> HnswGraphBuilder<T> create(
RandomAccessVectorValues<T> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
public static HnswGraphBuilder create(
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
long seed,
HnswGraph initializerGraph,
Map<Integer, Integer> oldToNewOrdinalMap)
throws IOException {
HnswGraphBuilder<T> hnswGraphBuilder =
new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
HnswGraphBuilder hnswGraphBuilder = new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed);
hnswGraphBuilder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap);
return hnswGraphBuilder;
}
@ -113,8 +96,7 @@ public final class HnswGraphBuilder<T> {
* Reads all the vectors from vector values, builds a graph connecting them by their dense
* ordinals, using the given hyperparameter settings, and returns the resulting graph.
*
* @param vectors the vectors whose relations are represented by the graph - must provide a
* different view over those vectors than the one used to add via addGraphNode.
* @param scorerSupplier a supplier to create vector scorer from ordinals.
* @param M graph fanout parameter used to calculate the maximum number of connections a node
* can have M on upper layers, and M * 2 on the lowest level.
* @param beamWidth the size of the beam search to use when finding nearest neighbors.
@ -122,17 +104,8 @@ public final class HnswGraphBuilder<T> {
* to ensure repeatable construction.
*/
private HnswGraphBuilder(
RandomAccessVectorValues<T> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
int M,
int beamWidth,
long seed)
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
throws IOException {
this.vectors = vectors;
this.vectorsCopy = vectors.copy();
this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
this.similarityFunction = Objects.requireNonNull(similarityFunction);
if (M <= 0) {
throw new IllegalArgumentException("maxConn must be positive");
}
@ -140,16 +113,15 @@ public final class HnswGraphBuilder<T> {
throw new IllegalArgumentException("beamWidth must be positive");
}
this.M = M;
this.scorerSupplier =
Objects.requireNonNull(scorerSupplier, "scorer supplier must not be null");
// normalization factor for level generation; currently not configurable
this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
this.random = new SplittableRandom(seed);
this.hnsw = new OnHeapHnswGraph(M);
this.graphSearcher =
new HnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(beamWidth, true),
new FixedBitSet(this.vectors.size()));
new HnswGraphSearcher(
new NeighborQueue(beamWidth, true), new FixedBitSet(this.getGraph().size()));
// in scratch we store candidates in reverse order: worse candidates are first
scratch = new NeighborArray(Math.max(beamWidth, M + 1), false);
entryCandidates = new GraphBuilderKnnCollector(1);
@ -158,22 +130,15 @@ public final class HnswGraphBuilder<T> {
}
/**
* Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two
* copies enables efficient retrieval without extra data copying, while avoiding collision of the
* returned values.
* Adds all nodes to the graph up to the provided {@code maxOrd}.
*
* @param vectorsToAdd the vectors for which to build a nearest neighbors graph. Must be an
* independent accessor for the vectors
* @param maxOrd The maximum ordinal of the nodes to be added.
*/
public OnHeapHnswGraph build(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
if (vectorsToAdd == this.vectors) {
throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
}
public OnHeapHnswGraph build(int maxOrd) throws IOException {
if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(HNSW_COMPONENT, "build graph from " + vectorsToAdd.size() + " vectors");
infoStream.message(HNSW_COMPONENT, "build graph from " + maxOrd + " vectors");
}
addVectors(vectorsToAdd);
addVectors(maxOrd);
return hnsw;
}
@ -216,19 +181,6 @@ public final class HnswGraphBuilder<T> {
}
}
private void addVectors(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
long start = System.nanoTime(), t = start;
for (int node = 0; node < vectorsToAdd.size(); node++) {
if (initializedNodes.contains(node)) {
continue;
}
addGraphNode(node, vectorsToAdd);
if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
t = printGraphBuildStatus(node, start, t);
}
}
}
/** Set info-stream to output debugging information * */
public void setInfoStream(InfoStream infoStream) {
this.infoStream = infoStream;
@ -238,8 +190,22 @@ public final class HnswGraphBuilder<T> {
return hnsw;
}
private void addVectors(int maxOrd) throws IOException {
long start = System.nanoTime(), t = start;
for (int node = 0; node < maxOrd; node++) {
if (initializedNodes.contains(node)) {
continue;
}
addGraphNode(node);
if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
t = printGraphBuildStatus(node, start, t);
}
}
}
/** Inserts a doc with vector value to the graph */
public void addGraphNode(int node, T value) throws IOException {
public void addGraphNode(int node) throws IOException {
RandomVectorScorer scorer = scorerSupplier.scorer(node);
final int nodeLevel = getRandomGraphLevel(ml, random);
int curMaxLevel = hnsw.numLevels() - 1;
@ -261,24 +227,20 @@ public final class HnswGraphBuilder<T> {
GraphBuilderKnnCollector candidates = entryCandidates;
for (int level = curMaxLevel; level > nodeLevel; level--) {
candidates.clear();
graphSearcher.searchLevel(candidates, value, level, eps, vectors, hnsw, null);
graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null);
eps = new int[] {candidates.popNode()};
}
// for levels <= nodeLevel search with topk = beamWidth, and add connections
candidates = beamCandidates;
for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) {
candidates.clear();
graphSearcher.searchLevel(candidates, value, level, eps, vectors, hnsw, null);
graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null);
eps = candidates.popUntilNearestKNodes();
hnsw.addNode(level, node);
addDiverseNeighbors(level, node, candidates);
}
}
public void addGraphNode(int node, RandomAccessVectorValues<T> values) throws IOException {
addGraphNode(node, values.vectorValue(node));
}
private long printGraphBuildStatus(int node, long start, long t) {
long now = System.nanoTime();
infoStream.message(
@ -353,36 +315,9 @@ public final class HnswGraphBuilder<T> {
*/
private boolean diversityCheck(int candidate, float score, NeighborArray neighbors)
throws IOException {
return isDiverse(candidate, neighbors, score);
}
private boolean isDiverse(int candidate, NeighborArray neighbors, float score)
throws IOException {
return switch (vectorEncoding) {
case BYTE -> isDiverse((byte[]) vectors.vectorValue(candidate), neighbors, score);
case FLOAT32 -> isDiverse((float[]) vectors.vectorValue(candidate), neighbors, score);
};
}
private boolean isDiverse(float[] candidate, NeighborArray neighbors, float score)
throws IOException {
RandomVectorScorer scorer = scorerSupplier.scorer(candidate);
for (int i = 0; i < neighbors.size(); i++) {
float neighborSimilarity =
similarityFunction.compare(
candidate, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
if (neighborSimilarity >= score) {
return false;
}
}
return true;
}
private boolean isDiverse(byte[] candidate, NeighborArray neighbors, float score)
throws IOException {
for (int i = 0; i < neighbors.size(); i++) {
float neighborSimilarity =
similarityFunction.compare(
candidate, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
float neighborSimilarity = scorer.score(neighbors.node[i]);
if (neighborSimilarity >= score) {
return false;
}
@ -395,26 +330,8 @@ public final class HnswGraphBuilder<T> {
* neighbours
*/
private int findWorstNonDiverse(NeighborArray neighbors, int nodeOrd) throws IOException {
float[] vectorValue = null;
byte[] binaryValue = null;
switch (this.vectorEncoding) {
case FLOAT32 -> vectorValue = (float[]) vectors.vectorValue(nodeOrd);
case BYTE -> binaryValue = (byte[]) vectors.vectorValue(nodeOrd);
}
float[] finalVectorValue = vectorValue;
byte[] finalBinaryValue = binaryValue;
int[] uncheckedIndexes =
neighbors.sort(
nbrOrd -> {
float score =
switch (this.vectorEncoding) {
case FLOAT32 -> this.similarityFunction.compare(
finalVectorValue, (float[]) vectorsCopy.vectorValue(nbrOrd));
case BYTE -> this.similarityFunction.compare(
finalBinaryValue, (byte[]) vectorsCopy.vectorValue(nbrOrd));
};
return score;
});
RandomVectorScorer scorer = scorerSupplier.scorer(nodeOrd);
int[] uncheckedIndexes = neighbors.sort(scorer);
if (uncheckedIndexes == null) {
// all nodes are checked, we will directly return the most distant one
return neighbors.size() - 1;
@ -438,37 +355,12 @@ public final class HnswGraphBuilder<T> {
private boolean isWorstNonDiverse(
int candidateIndex, NeighborArray neighbors, int[] uncheckedIndexes, int uncheckedCursor)
throws IOException {
int candidateNode = neighbors.node[candidateIndex];
return switch (vectorEncoding) {
case BYTE -> isWorstNonDiverse(
candidateIndex,
(byte[]) vectors.vectorValue(candidateNode),
neighbors,
uncheckedIndexes,
uncheckedCursor);
case FLOAT32 -> isWorstNonDiverse(
candidateIndex,
(float[]) vectors.vectorValue(candidateNode),
neighbors,
uncheckedIndexes,
uncheckedCursor);
};
}
private boolean isWorstNonDiverse(
int candidateIndex,
float[] candidateVector,
NeighborArray neighbors,
int[] uncheckedIndexes,
int uncheckedCursor)
throws IOException {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
RandomVectorScorer scorer = scorerSupplier.scorer(neighbors.node[candidateIndex]);
if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
// the candidate itself is unchecked
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
float neighborSimilarity = scorer.score(neighbors.node[i]);
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
@ -479,47 +371,7 @@ public final class HnswGraphBuilder<T> {
// inserted) unchecked nodes
assert candidateIndex > uncheckedIndexes[uncheckedCursor];
for (int i = uncheckedCursor; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector,
(float[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
}
}
}
return false;
}
private boolean isWorstNonDiverse(
int candidateIndex,
byte[] candidateVector,
NeighborArray neighbors,
int[] uncheckedIndexes,
int uncheckedCursor)
throws IOException {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
// the candidate itself is unchecked
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
}
}
} else {
// else we just need to make sure candidate does not violate diversity with the (newly
// inserted) unchecked nodes
assert candidateIndex > uncheckedIndexes[uncheckedCursor];
for (int i = uncheckedCursor; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector,
(byte[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]]));
float neighborSimilarity = scorer.score(neighbors.node[uncheckedIndexes[i]]);
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;

View File

@ -20,8 +20,6 @@ package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.util.BitSet;
@ -32,13 +30,8 @@ import org.apache.lucene.util.SparseFixedBitSet;
/**
* Searches an HNSW graph to find nearest neighbors to a query vector. For more background on the
* search algorithm, see {@link HnswGraph}.
*
* @param <T> the type of query vector
*/
public class HnswGraphSearcher<T> {
private final VectorSimilarityFunction similarityFunction;
private final VectorEncoding vectorEncoding;
public class HnswGraphSearcher {
/**
* Scratch data structures that are used in each {@link #searchLevel} call. These can be expensive
* to allocate, so they're cleared and reused across calls.
@ -50,17 +43,10 @@ public class HnswGraphSearcher<T> {
/**
* Creates a new graph searcher.
*
* @param similarityFunction the similarity function to compare vectors
* @param candidates max heap that will track the candidate nodes to explore
* @param visited bit set that will track nodes that have already been visited
*/
public HnswGraphSearcher(
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
NeighborQueue candidates,
BitSet visited) {
this.vectorEncoding = vectorEncoding;
this.similarityFunction = similarityFunction;
public HnswGraphSearcher(NeighborQueue candidates, BitSet visited) {
this.candidates = candidates;
this.visited = visited;
}
@ -68,10 +54,27 @@ public class HnswGraphSearcher<T> {
/**
* Searches HNSW graph for the nearest neighbors of a query vector.
*
* @param query search query vector
* @param scorer the scorer to compare the query with the nodes
* @param knnCollector a collector of top knn results to be returned
* @param graph the graph values. May represent the entire graph, or a level in a hierarchical
* graph.
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
* {@code null} if they are all allowed to match.
*/
public static void search(
RandomVectorScorer scorer, KnnCollector knnCollector, HnswGraph graph, Bits acceptOrds)
throws IOException {
HnswGraphSearcher graphSearcher =
new HnswGraphSearcher(
new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(graph.size()));
search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
}
/**
* Search {@link OnHeapHnswGraph}, this method is thread safe.
*
* @param scorer the scorer to compare the query with the nodes
* @param topK the number of nodes to be returned
* @param vectors the vector values
* @param similarityFunction the similarity function to compare vectors
* @param graph the graph values. May represent the entire graph, or a level in a hierarchical
* graph.
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
@ -80,198 +83,36 @@ public class HnswGraphSearcher<T> {
* @return a set of collected vectors holding the nearest neighbors found
*/
public static KnnCollector search(
float[] query,
int topK,
RandomAccessVectorValues<float[]> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
HnswGraph graph,
Bits acceptOrds,
int visitedLimit)
RandomVectorScorer scorer, int topK, OnHeapHnswGraph graph, Bits acceptOrds, int visitedLimit)
throws IOException {
KnnCollector knnCollector = new TopKnnCollector(topK, visitedLimit);
search(query, knnCollector, vectors, vectorEncoding, similarityFunction, graph, acceptOrds);
OnHeapHnswGraphSearcher graphSearcher =
new OnHeapHnswGraphSearcher(
new NeighborQueue(topK, true), new SparseFixedBitSet(graph.size()));
search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
return knnCollector;
}
/**
* Searches HNSW graph for the nearest neighbors of a query vector.
*
* @param query search query vector
* @param knnCollector a collector of top knn results to be returned
* @param vectors the vector values
* @param similarityFunction the similarity function to compare vectors
* @param graph the graph values. May represent the entire graph, or a level in a hierarchical
* graph.
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
* {@code null} if they are all allowed to match.
*/
public static void search(
float[] query,
private static void search(
RandomVectorScorer scorer,
KnnCollector knnCollector,
RandomAccessVectorValues<float[]> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
HnswGraph graph,
Bits acceptOrds)
throws IOException {
if (query.length != vectors.dimension()) {
throw new IllegalArgumentException(
"vector query dimension: "
+ query.length
+ " differs from field dimension: "
+ vectors.dimension());
}
HnswGraphSearcher<float[]> graphSearcher =
new HnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(knnCollector.k(), true),
new SparseFixedBitSet(vectors.size()));
search(query, knnCollector, vectors, graph, graphSearcher, acceptOrds);
}
/**
* Search {@link OnHeapHnswGraph}, this method is thread safe, for parameters please refer to
* {@link #search(float[], int, RandomAccessVectorValues, VectorEncoding,
* VectorSimilarityFunction, HnswGraph, Bits, int)}
*/
public static KnnCollector search(
float[] query,
int topK,
RandomAccessVectorValues<float[]> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
OnHeapHnswGraph graph,
Bits acceptOrds,
int visitedLimit)
throws IOException {
KnnCollector knnCollector = new TopKnnCollector(topK, visitedLimit);
OnHeapHnswGraphSearcher<float[]> graphSearcher =
new OnHeapHnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
search(query, knnCollector, vectors, graph, graphSearcher, acceptOrds);
return knnCollector;
}
/**
* Searches HNSW graph for the nearest neighbors of a query vector.
*
* @param query search query vector
* @param topK the number of nodes to be returned
* @param vectors the vector values
* @param similarityFunction the similarity function to compare vectors
* @param graph the graph values. May represent the entire graph, or a level in a hierarchical
* graph.
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
* {@code null} if they are all allowed to match.
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
* @return a set of collected vectors holding the nearest neighbors found
*/
public static KnnCollector search(
byte[] query,
int topK,
RandomAccessVectorValues<byte[]> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
HnswGraph graph,
Bits acceptOrds,
int visitedLimit)
throws IOException {
KnnCollector collector = new TopKnnCollector(topK, visitedLimit);
search(query, collector, vectors, vectorEncoding, similarityFunction, graph, acceptOrds);
return collector;
}
/**
* Searches HNSW graph for the nearest neighbors of a query vector.
*
* @param query search query vector
* @param knnCollector a collector of top knn results to be returned
* @param vectors the vector values
* @param similarityFunction the similarity function to compare vectors
* @param graph the graph values. May represent the entire graph, or a level in a hierarchical
* graph.
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
* {@code null} if they are all allowed to match.
*/
public static void search(
byte[] query,
KnnCollector knnCollector,
RandomAccessVectorValues<byte[]> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
HnswGraph graph,
Bits acceptOrds)
throws IOException {
if (query.length != vectors.dimension()) {
throw new IllegalArgumentException(
"vector query dimension: "
+ query.length
+ " differs from field dimension: "
+ vectors.dimension());
}
HnswGraphSearcher<byte[]> graphSearcher =
new HnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(knnCollector.k(), true),
new SparseFixedBitSet(vectors.size()));
search(query, knnCollector, vectors, graph, graphSearcher, acceptOrds);
}
/**
* Search {@link OnHeapHnswGraph}, this method is thread safe, for parameters please refer to
* {@link #search(byte[], int, RandomAccessVectorValues, VectorEncoding, VectorSimilarityFunction,
* HnswGraph, Bits, int)}
*/
public static KnnCollector search(
byte[] query,
int topK,
RandomAccessVectorValues<byte[]> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
OnHeapHnswGraph graph,
Bits acceptOrds,
int visitedLimit)
throws IOException {
OnHeapHnswGraphSearcher<byte[]> graphSearcher =
new OnHeapHnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
KnnCollector collector = new TopKnnCollector(topK, visitedLimit);
search(query, collector, vectors, graph, graphSearcher, acceptOrds);
return collector;
}
private static <T> void search(
T query,
KnnCollector knnCollector,
RandomAccessVectorValues<T> vectors,
HnswGraph graph,
HnswGraphSearcher<T> graphSearcher,
HnswGraphSearcher graphSearcher,
Bits acceptOrds)
throws IOException {
int initialEp = graph.entryNode();
if (initialEp == -1) {
return;
}
int[] epAndVisited =
graphSearcher.findBestEntryPoint(query, vectors, graph, knnCollector.visitLimit());
int[] epAndVisited = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector.visitLimit());
int numVisited = epAndVisited[1];
int ep = epAndVisited[0];
if (ep == -1) {
knnCollector.incVisitedCount(numVisited);
return;
}
KnnCollector results = new OrdinalTranslatedKnnCollector(knnCollector, vectors::ordToDoc);
results.incVisitedCount(numVisited);
graphSearcher.searchLevel(results, query, 0, new int[] {ep}, vectors, graph, acceptOrds);
knnCollector.incVisitedCount(numVisited);
graphSearcher.searchLevel(knnCollector, scorer, 0, new int[] {ep}, graph, acceptOrds);
}
/**
@ -280,48 +121,40 @@ public class HnswGraphSearcher<T> {
* <p>If the search stops early because it reaches the visited nodes limit, then the results will
* be marked incomplete through {@link NeighborQueue#incomplete()}.
*
* @param query search query vector
* @param scorer the scorer to compare the query with the nodes
* @param topK the number of nearest to query results to return
* @param level level to search
* @param eps the entry points for search at this level expressed as level 0th ordinals
* @param vectors vector values
* @param graph the graph values
* @return a set of collected vectors holding the nearest neighbors found
*/
public HnswGraphBuilder.GraphBuilderKnnCollector searchLevel(
// Note: this is only public because Lucene91HnswGraphBuilder needs it
T query,
int topK,
int level,
final int[] eps,
RandomAccessVectorValues<T> vectors,
HnswGraph graph)
RandomVectorScorer scorer, int topK, int level, final int[] eps, HnswGraph graph)
throws IOException {
HnswGraphBuilder.GraphBuilderKnnCollector results =
new HnswGraphBuilder.GraphBuilderKnnCollector(topK);
searchLevel(results, query, level, eps, vectors, graph, null);
searchLevel(results, scorer, level, eps, graph, null);
return results;
}
/**
* Function to find the best entry point from which to search the zeroth graph layer.
*
* @param query vector query with which to search
* @param vectors random access vector values
* @param scorer the scorer to compare the query with the nodes
* @param graph the HNSWGraph
* @param visitLimit How many vectors are allowed to be visited
* @return An integer array whose first element is the best entry point, and second is the number
* of candidates visited. Entry point of `-1` indicates visitation limit exceed
* @throws IOException When accessing the vector fails
*/
private int[] findBestEntryPoint(
T query, RandomAccessVectorValues<T> vectors, HnswGraph graph, long visitLimit)
private int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, long visitLimit)
throws IOException {
int size = graph.size();
int visitedCount = 1;
prepareScratchState(vectors.size());
prepareScratchState(graph.size());
int currentEp = graph.entryNode();
float currentScore = compare(query, vectors, currentEp);
float currentScore = scorer.score(currentEp);
boolean foundBetter;
for (int level = graph.numLevels() - 1; level >= 1; level--) {
foundBetter = true;
@ -339,7 +172,7 @@ public class HnswGraphSearcher<T> {
if (visitedCount >= visitLimit) {
return new int[] {-1, visitedCount};
}
float friendSimilarity = compare(query, vectors, friendOrd);
float friendSimilarity = scorer.score(friendOrd);
visitedCount++;
if (friendSimilarity > currentScore
|| (friendSimilarity == currentScore && friendOrd < currentEp)) {
@ -361,23 +194,22 @@ public class HnswGraphSearcher<T> {
*/
void searchLevel(
KnnCollector results,
T query,
RandomVectorScorer scorer,
int level,
final int[] eps,
RandomAccessVectorValues<T> vectors,
HnswGraph graph,
Bits acceptOrds)
throws IOException {
int size = graph.size();
prepareScratchState(vectors.size());
prepareScratchState(graph.size());
for (int ep : eps) {
if (visited.getAndSet(ep) == false) {
if (results.earlyTerminated()) {
break;
}
float score = compare(query, vectors, ep);
float score = scorer.score(ep);
results.incVisitedCount(1);
candidates.add(ep, score);
if (acceptOrds == null || acceptOrds.get(ep)) {
@ -408,7 +240,7 @@ public class HnswGraphSearcher<T> {
if (results.earlyTerminated()) {
break;
}
float friendSimilarity = compare(query, vectors, friendOrd);
float friendSimilarity = scorer.score(friendOrd);
results.incVisitedCount(1);
if (friendSimilarity >= minAcceptedSimilarity) {
candidates.add(friendOrd, friendSimilarity);
@ -422,14 +254,6 @@ public class HnswGraphSearcher<T> {
}
}
private float compare(T query, RandomAccessVectorValues<T> vectors, int ord) throws IOException {
if (vectorEncoding == VectorEncoding.BYTE) {
return similarityFunction.compare((byte[]) query, (byte[]) vectors.vectorValue(ord));
} else {
return similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(ord));
}
}
private void prepareScratchState(int capacity) {
candidates.clear();
if (visited.length() < capacity) {
@ -468,17 +292,13 @@ public class HnswGraphSearcher<T> {
* <p>Note the class itself is NOT thread safe, but since each search will create a new Searcher,
* the search methods using this class are thread safe.
*/
private static class OnHeapHnswGraphSearcher<C> extends HnswGraphSearcher<C> {
private static class OnHeapHnswGraphSearcher extends HnswGraphSearcher {
private NeighborArray cur;
private int upto;
private OnHeapHnswGraphSearcher(
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
NeighborQueue candidates,
BitSet visited) {
super(vectorEncoding, similarityFunction, candidates, visited);
private OnHeapHnswGraphSearcher(NeighborQueue candidates, BitSet visited) {
super(candidates, visited);
}
@Override

View File

@ -86,7 +86,7 @@ public class NeighborArray {
* @return indexes of newly sorted (unchecked) nodes, in ascending order, or null if the array is
* already fully sorted
*/
public int[] sort(ScoringFunction scoringFunction) throws IOException {
public int[] sort(RandomVectorScorer scorer) throws IOException {
if (size == sortedNodeSize) {
// all nodes checked and sorted
return null;
@ -95,8 +95,7 @@ public class NeighborArray {
int[] uncheckedIndexes = new int[size - sortedNodeSize];
int count = 0;
while (sortedNodeSize != size) {
uncheckedIndexes[count] =
insertSortedInternal(scoringFunction); // sortedNodeSize is increased inside
uncheckedIndexes[count] = insertSortedInternal(scorer); // sortedNodeSize is increased inside
for (int i = 0; i < count; i++) {
if (uncheckedIndexes[i] >= uncheckedIndexes[count]) {
// the previous inserted nodes has been shifted
@ -110,13 +109,13 @@ public class NeighborArray {
}
/** insert the first unsorted node into its sorted position */
private int insertSortedInternal(ScoringFunction scoringFunction) throws IOException {
private int insertSortedInternal(RandomVectorScorer scorer) throws IOException {
assert sortedNodeSize < size : "Call this method only when there's unsorted node";
int tmpNode = node[sortedNodeSize];
float tmpScore = score[sortedNodeSize];
if (Float.isNaN(tmpScore)) {
tmpScore = scoringFunction.computeScore(tmpNode);
tmpScore = scorer.score(tmpNode);
}
int insertionPoint =
@ -204,20 +203,4 @@ public class NeighborArray {
}
return start;
}
/**
* ScoringFunction is a lambda function created in HnswGraphBuilder to allow for lazy computation
* of distance score.
*/
interface ScoringFunction {
/**
* Computes the distance score between the given node ID and the root node of this
* NeighborArray.
*
* @param nodeId The ID of the node for which to compute the distance score.
* @return The distance score as a float value.
* @throws IOException If an I/O error occurs during computation.
*/
float computeScore(int nodeId) throws IOException;
}
}

View File

@ -24,12 +24,12 @@ import org.apache.lucene.search.TotalHits;
/**
* Wraps a provided KnnCollector object, translating the provided vectorId ordinal to a documentId
*/
final class OrdinalTranslatedKnnCollector implements KnnCollector {
public final class OrdinalTranslatedKnnCollector implements KnnCollector {
private final KnnCollector in;
private final IntToIntFunction vectorOrdinalToDocId;
OrdinalTranslatedKnnCollector(KnnCollector in, IntToIntFunction vectorOrdinalToDocId) {
public OrdinalTranslatedKnnCollector(KnnCollector in, IntToIntFunction vectorOrdinalToDocId) {
this.in = in;
this.vectorOrdinalToDocId = vectorOrdinalToDocId;
}

View File

@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import org.apache.lucene.index.VectorSimilarityFunction;
/** A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. */
public interface RandomVectorScorer {
/**
* Returns the score between the query and the provided node.
*
* @param node a random node in the graph
* @return the computed score
*/
float score(int node) throws IOException;
/**
* Creates a default scorer for float vectors.
*
* <p>WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid
* using it after calling this function. If you plan to use it again outside the returned {@link
* RandomVectorScorer}, think about passing a copied version ({@link
* RandomAccessVectorValues#copy}).
*
* @param vectors the underlying storage for vectors
* @param similarityFunction the similarity function to score vectors
* @param query the actual query
*/
static RandomVectorScorer createFloats(
final RandomAccessVectorValues<float[]> vectors,
final VectorSimilarityFunction similarityFunction,
final float[] query) {
if (query.length != vectors.dimension()) {
throw new IllegalArgumentException(
"vector query dimension: "
+ query.length
+ " differs from field dimension: "
+ vectors.dimension());
}
return node -> similarityFunction.compare(query, vectors.vectorValue(node));
}
/**
* Creates a default scorer for byte vectors.
*
* <p>WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid
* using it after calling this function. If you plan to use it again outside the returned {@link
* RandomVectorScorer}, think about passing a copied version ({@link
* RandomAccessVectorValues#copy}).
*
* @param vectors the underlying storage for vectors
* @param similarityFunction the similarity function to use to score vectors
* @param query the actual query
*/
static RandomVectorScorer createBytes(
final RandomAccessVectorValues<byte[]> vectors,
final VectorSimilarityFunction similarityFunction,
final byte[] query) {
if (query.length != vectors.dimension()) {
throw new IllegalArgumentException(
"vector query dimension: "
+ query.length
+ " differs from field dimension: "
+ vectors.dimension());
}
return node -> similarityFunction.compare(query, vectors.vectorValue(node));
}
}

View File

@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import org.apache.lucene.index.VectorSimilarityFunction;
/** A supplier that creates {@link RandomVectorScorer} from an ordinal. */
public interface RandomVectorScorerSupplier {
/**
* This creates a {@link RandomVectorScorer} for scoring random nodes in batches against the given
* ordinal.
*
* @param ord the ordinal of the node to compare
* @return a new {@link RandomVectorScorer}
*/
RandomVectorScorer scorer(int ord) throws IOException;
/**
* Creates a {@link RandomVectorScorerSupplier} to compare float vectors.
*
* <p>WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid
* using it after calling this function. If you plan to use it again outside the returned {@link
* RandomVectorScorer}, think about passing a copied version ({@link
* RandomAccessVectorValues#copy}).
*
* @param vectors the underlying storage for vectors
* @param similarityFunction the similarity function to score vectors
*/
static RandomVectorScorerSupplier createFloats(
final RandomAccessVectorValues<float[]> vectors,
final VectorSimilarityFunction similarityFunction)
throws IOException {
// We copy the provided random accessor just once during the supplier's initialization
// and then reuse it consistently across all scorers for conducting vector comparisons.
final RandomAccessVectorValues<float[]> vectorsCopy = vectors.copy();
return queryOrd ->
(RandomVectorScorer)
cand ->
similarityFunction.compare(
vectors.vectorValue(queryOrd), vectorsCopy.vectorValue(cand));
}
/**
* Creates a {@link RandomVectorScorerSupplier} to compare byte vectors.
*
* <p>WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid
* using it after calling this function. If you plan to use it again outside the returned {@link
* RandomVectorScorer}, think about passing a copied version ({@link
* RandomAccessVectorValues#copy}).
*
* @param vectors the underlying storage for vectors
* @param similarityFunction the similarity function to score vectors
*/
static RandomVectorScorerSupplier createBytes(
final RandomAccessVectorValues<byte[]> vectors,
final VectorSimilarityFunction similarityFunction)
throws IOException {
// We copy the provided random accessor just once during the supplier's initialization
// and then reuse it consistently across all scorers for conducting vector comparisons.
final RandomAccessVectorValues<byte[]> vectorsCopy = vectors.copy();
return queryOrd ->
(RandomVectorScorer)
cand ->
similarityFunction.compare(
vectors.vectorValue(queryOrd), vectorsCopy.vectorValue(cand));
}
}

View File

@ -109,6 +109,29 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
abstract T getTargetVector();
@SuppressWarnings("unchecked")
protected RandomVectorScorerSupplier buildScorerSupplier(RandomAccessVectorValues<T> vectors)
throws IOException {
return switch (getVectorEncoding()) {
case BYTE -> RandomVectorScorerSupplier.createBytes(
(RandomAccessVectorValues<byte[]>) vectors, similarityFunction);
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
(RandomAccessVectorValues<float[]>) vectors, similarityFunction);
};
}
@SuppressWarnings("unchecked")
protected RandomVectorScorer buildScorer(RandomAccessVectorValues<T> vectors, T query)
throws IOException {
RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
return switch (getVectorEncoding()) {
case BYTE -> RandomVectorScorer.createBytes(
(RandomAccessVectorValues<byte[]>) vectorsCopy, similarityFunction, (byte[]) query);
case FLOAT32 -> RandomVectorScorer.createFloats(
(RandomAccessVectorValues<float[]>) vectorsCopy, similarityFunction, (float[]) query);
};
}
// test writing out and reading in a graph gives the expected graph
public void testReadWrite() throws IOException {
int dim = random().nextInt(100) + 1;
@ -118,10 +141,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
long seed = random().nextLong();
AbstractMockVectorValues<T> vectors = vectorValues(nDoc, dim);
AbstractMockVectorValues<T> v2 = vectors.copy(), v3 = vectors.copy();
HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
vectors, getVectorEncoding(), similarityFunction, M, beamWidth, seed);
HnswGraph hnsw = builder.build(vectors.copy());
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, seed);
HnswGraph hnsw = builder.build(vectors.size());
// Recreate the graph while indexing with the same random seed and write it out
HnswGraphBuilder.randSeed = seed;
@ -349,33 +371,14 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
int nDoc = 100;
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
vectors, getVectorEncoding(), similarityFunction, 10, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.size());
// run some searches
KnnCollector nn =
switch (getVectorEncoding()) {
case BYTE -> HnswGraphSearcher.search(
(byte[]) getTargetVector(),
10,
(RandomAccessVectorValues<byte[]>) vectors.copy(),
getVectorEncoding(),
similarityFunction,
hnsw,
null,
Integer.MAX_VALUE);
case FLOAT32 -> HnswGraphSearcher.search(
(float[]) getTargetVector(),
10,
(RandomAccessVectorValues<float[]>) vectors.copy(),
getVectorEncoding(),
similarityFunction,
hnsw,
null,
Integer.MAX_VALUE);
};
HnswGraphSearcher.search(
buildScorer(vectors, getTargetVector()), 10, hnsw, null, Integer.MAX_VALUE);
TopDocs topDocs = nn.topDocs();
assertEquals("Number of found results is not equal to [10].", 10, topDocs.scoreDocs.length);
int sum = 0;
@ -401,33 +404,14 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
int nDoc = 100;
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.size());
// the first 10 docs must not be deleted to ensure the expected recall
Bits acceptOrds = createRandomAcceptOrds(10, nDoc);
KnnCollector nn =
switch (getVectorEncoding()) {
case BYTE -> HnswGraphSearcher.search(
(byte[]) getTargetVector(),
10,
(RandomAccessVectorValues<byte[]>) vectors.copy(),
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
case FLOAT32 -> HnswGraphSearcher.search(
(float[]) getTargetVector(),
10,
(RandomAccessVectorValues<float[]>) vectors.copy(),
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
};
HnswGraphSearcher.search(
buildScorer(vectors, getTargetVector()), 10, hnsw, acceptOrds, Integer.MAX_VALUE);
TopDocs nodes = nn.topDocs();
assertEquals("Number of found results is not equal to [10].", 10, nodes.scoreDocs.length);
int sum = 0;
@ -445,10 +429,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
int nDoc = 100;
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.size());
// Only mark a few vectors as accepted
BitSet acceptOrds = new FixedBitSet(nDoc);
for (int i = 0; i < nDoc; i += random().nextInt(15, 20)) {
@ -458,27 +441,12 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
// Check the search finds all accepted vectors
int numAccepted = acceptOrds.cardinality();
KnnCollector nn =
switch (getVectorEncoding()) {
case FLOAT32 -> HnswGraphSearcher.search(
(float[]) getTargetVector(),
numAccepted,
(RandomAccessVectorValues<float[]>) vectors.copy(),
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
case BYTE -> HnswGraphSearcher.search(
(byte[]) getTargetVector(),
numAccepted,
(RandomAccessVectorValues<byte[]>) vectors.copy(),
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
};
HnswGraphSearcher.search(
buildScorer(vectors, getTargetVector()),
numAccepted,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
TopDocs nodes = nn.topDocs();
assertEquals(numAccepted, nodes.scoreDocs.length);
for (ScoreDoc node : nodes.scoreDocs) {
@ -565,32 +533,26 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
long seed = random().nextLong();
AbstractMockVectorValues<T> initializerVectors = vectorValues(initializerSize, dim);
HnswGraphBuilder<T> initializerBuilder =
HnswGraphBuilder.create(
initializerVectors, getVectorEncoding(), similarityFunction, 10, 30, seed);
RandomVectorScorerSupplier initialscorerSupplier = buildScorerSupplier(initializerVectors);
HnswGraphBuilder initializerBuilder =
HnswGraphBuilder.create(initialscorerSupplier, 10, 30, seed);
OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.copy());
OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size());
AbstractMockVectorValues<T> finalVectorValues =
vectorValues(totalSize, dim, initializerVectors, docIdOffset);
Map<Integer, Integer> initializerOrdMap =
createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset);
HnswGraphBuilder<T> finalBuilder =
RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
HnswGraphBuilder finalBuilder =
HnswGraphBuilder.create(
finalVectorValues,
getVectorEncoding(),
similarityFunction,
10,
30,
seed,
initializerGraph,
initializerOrdMap);
finalscorerSupplier, 10, 30, seed, initializerGraph, initializerOrdMap);
// When offset is 0, the graphs should be identical before vectors are added
assertGraphEqual(initializerGraph, finalBuilder.getGraph());
OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.copy());
OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.size());
assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap);
}
@ -602,31 +564,26 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
long seed = random().nextLong();
AbstractMockVectorValues<T> initializerVectors = vectorValues(initializerSize, dim);
HnswGraphBuilder<T> initializerBuilder =
HnswGraphBuilder.create(
initializerVectors.copy(), getVectorEncoding(), similarityFunction, 10, 30, seed);
OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.copy());
RandomVectorScorerSupplier initialscorerSupplier = buildScorerSupplier(initializerVectors);
HnswGraphBuilder initializerBuilder =
HnswGraphBuilder.create(initialscorerSupplier, 10, 30, seed);
OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size());
AbstractMockVectorValues<T> finalVectorValues =
vectorValues(totalSize, dim, initializerVectors.copy(), docIdOffset);
Map<Integer, Integer> initializerOrdMap =
createOffsetOrdinalMap(initializerSize, finalVectorValues.copy(), docIdOffset);
HnswGraphBuilder<T> finalBuilder =
RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
HnswGraphBuilder finalBuilder =
HnswGraphBuilder.create(
finalVectorValues,
getVectorEncoding(),
similarityFunction,
10,
30,
seed,
initializerGraph,
initializerOrdMap);
finalscorerSupplier, 10, 30, seed, initializerGraph, initializerOrdMap);
assertGraphInitializedFromGraph(finalBuilder.getGraph(), initializerGraph, initializerOrdMap);
// Confirm that the graph is appropriately constructed by checking that the nodes in the old
// graph are present in the levels of the new graph
OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.copy());
OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.size());
assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap);
}
@ -718,65 +675,32 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
int nDoc = 500;
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.size());
int topK = 50;
int visitedLimit = topK + random().nextInt(5);
KnnCollector nn =
switch (getVectorEncoding()) {
case FLOAT32 -> HnswGraphSearcher.search(
(float[]) getTargetVector(),
topK,
(RandomAccessVectorValues<float[]>) vectors.copy(),
getVectorEncoding(),
similarityFunction,
hnsw,
createRandomAcceptOrds(0, nDoc),
visitedLimit);
case BYTE -> HnswGraphSearcher.search(
(byte[]) getTargetVector(),
topK,
(RandomAccessVectorValues<byte[]>) vectors.copy(),
getVectorEncoding(),
similarityFunction,
hnsw,
createRandomAcceptOrds(0, nDoc),
visitedLimit);
};
HnswGraphSearcher.search(
buildScorer(vectors, getTargetVector()),
topK,
hnsw,
createRandomAcceptOrds(0, nDoc),
visitedLimit);
assertTrue(nn.earlyTerminated());
// The visited count shouldn't exceed the limit
assertTrue(nn.visitedCount() <= visitedLimit);
}
public void testHnswGraphBuilderInvalid() {
expectThrows(
NullPointerException.class, () -> HnswGraphBuilder.create(null, null, null, 0, 0, 0));
public void testHnswGraphBuilderInvalid() throws IOException {
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectorValues(1, 1));
// M must be > 0
expectThrows(
IllegalArgumentException.class,
() ->
HnswGraphBuilder.create(
vectorValues(1, 1),
getVectorEncoding(),
VectorSimilarityFunction.EUCLIDEAN,
0,
10,
0));
IllegalArgumentException.class, () -> HnswGraphBuilder.create(scorerSupplier, 0, 10, 0));
// beamWidth must be > 0
expectThrows(
IllegalArgumentException.class,
() ->
HnswGraphBuilder.create(
vectorValues(1, 1),
getVectorEncoding(),
VectorSimilarityFunction.EUCLIDEAN,
10,
0,
0));
IllegalArgumentException.class, () -> HnswGraphBuilder.create(scorerSupplier, 10, 0, 0));
}
public void testRamUsageEstimate() throws IOException {
@ -784,14 +708,13 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
int dim = randomIntBetween(100, 1024);
int M = randomIntBetween(4, 96);
VectorSimilarityFunction similarityFunction =
RandomizedTest.randomFrom(VectorSimilarityFunction.values());
similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
RandomAccessVectorValues<T> vectors = vectorValues(size, dim);
HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
vectors, getVectorEncoding(), similarityFunction, M, M * 2, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder =
HnswGraphBuilder.create(scorerSupplier, M, M * 2, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors.size());
long estimated = RamUsageEstimator.sizeOfObject(hnsw);
long actual = ramUsed(hnsw);
@ -813,21 +736,19 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
};
AbstractMockVectorValues<T> vectors = vectorValues(values);
// First add nodes until everybody gets a full neighbor list
HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
vectors, getVectorEncoding(), similarityFunction, 2, 10, random().nextInt());
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 2, 10, random().nextInt());
// node 0 is added by the builder constructor
RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
builder.addGraphNode(0, vectorsCopy);
builder.addGraphNode(1, vectorsCopy);
builder.addGraphNode(2, vectorsCopy);
builder.addGraphNode(0);
builder.addGraphNode(1);
builder.addGraphNode(2);
// now every node has tried to attach every other node as a neighbor, but
// some were excluded based on diversity check.
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
assertLevel0Neighbors(builder.hnsw, 1, 0);
assertLevel0Neighbors(builder.hnsw, 2, 0);
builder.addGraphNode(3, vectorsCopy);
builder.addGraphNode(3);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
// we added 3 here
assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
@ -835,7 +756,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
assertLevel0Neighbors(builder.hnsw, 3, 1);
// supplant an existing neighbor
builder.addGraphNode(4, vectorsCopy);
builder.addGraphNode(4);
// 4 is the same distance from 0 that 2 is; we leave the existing node in place
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4);
@ -844,7 +765,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
assertLevel0Neighbors(builder.hnsw, 3, 1, 4);
assertLevel0Neighbors(builder.hnsw, 4, 1, 3);
builder.addGraphNode(5, vectorsCopy);
builder.addGraphNode(5);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4, 5);
assertLevel0Neighbors(builder.hnsw, 2, 0);
@ -869,20 +790,18 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
};
AbstractMockVectorValues<T> vectors = vectorValues(values);
// First add nodes until everybody gets a full neighbor list
HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt());
RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
builder.addGraphNode(0, vectorsCopy);
builder.addGraphNode(1, vectorsCopy);
builder.addGraphNode(2, vectorsCopy);
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 1, 10, random().nextInt());
builder.addGraphNode(0);
builder.addGraphNode(1);
builder.addGraphNode(2);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
// 2 is closer to 0 than 1, so it is excluded as non-diverse
assertLevel0Neighbors(builder.hnsw, 1, 0);
// 1 is closer to 0 than 2, so it is excluded as non-diverse
assertLevel0Neighbors(builder.hnsw, 2, 0);
builder.addGraphNode(3, vectorsCopy);
builder.addGraphNode(3);
// this is one case we are testing; 2 has been displaced by 3
assertLevel0Neighbors(builder.hnsw, 0, 1, 3);
assertLevel0Neighbors(builder.hnsw, 1, 0);
@ -901,20 +820,18 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
};
AbstractMockVectorValues<T> vectors = vectorValues(values);
// First add nodes until everybody gets a full neighbor list
HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt());
RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
builder.addGraphNode(0, vectorsCopy);
builder.addGraphNode(1, vectorsCopy);
builder.addGraphNode(2, vectorsCopy);
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 1, 10, random().nextInt());
builder.addGraphNode(0);
builder.addGraphNode(1);
builder.addGraphNode(2);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
// 2 is closer to 0 than 1, so it is excluded as non-diverse
assertLevel0Neighbors(builder.hnsw, 1, 0);
// 1 is closer to 0 than 2, so it is excluded as non-diverse
assertLevel0Neighbors(builder.hnsw, 2, 0);
builder.addGraphNode(3, vectorsCopy);
builder.addGraphNode(3);
// this is one case we are testing; 1 has been displaced by 3
assertLevel0Neighbors(builder.hnsw, 0, 2, 3);
assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
@ -939,10 +856,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
int dim = atLeast(10);
AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
int topK = 5;
HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
vectors, getVectorEncoding(), similarityFunction, 10, 30, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 30, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors.size());
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
int totalMatches = 0;
@ -950,27 +866,8 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
KnnCollector actual;
T query = randomVector(dim);
actual =
switch (getVectorEncoding()) {
case BYTE -> HnswGraphSearcher.search(
(byte[]) query,
100,
(RandomAccessVectorValues<byte[]>) vectors,
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
case FLOAT32 -> HnswGraphSearcher.search(
(float[]) query,
100,
(RandomAccessVectorValues<float[]>) vectors,
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
};
HnswGraphSearcher.search(
buildScorer(vectors, query), 100, hnsw, acceptOrds, Integer.MAX_VALUE);
TopDocs topDocs = actual.topDocs();
NeighborQueue expected = new NeighborQueue(topK, false);
for (int j = 0; j < size; j++) {
@ -1007,10 +904,9 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
int size = atLeast(100);
int dim = atLeast(10);
AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
vectors, getVectorEncoding(), similarityFunction, 10, 30, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 30, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors.size());
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
List<T> queries = new ArrayList<>();
@ -1020,27 +916,8 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
T query = randomVector(dim);
queries.add(query);
expect =
switch (getVectorEncoding()) {
case BYTE -> HnswGraphSearcher.search(
(byte[]) query,
100,
(RandomAccessVectorValues<byte[]>) vectors,
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
case FLOAT32 -> HnswGraphSearcher.search(
(float[]) query,
100,
(RandomAccessVectorValues<float[]>) vectors,
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
};
HnswGraphSearcher.search(
buildScorer(vectors, query), 100, hnsw, acceptOrds, Integer.MAX_VALUE);
expects.add(expect);
}
@ -1054,26 +931,8 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
KnnCollector actual;
try {
actual =
switch (getVectorEncoding()) {
case BYTE -> HnswGraphSearcher.search(
(byte[]) query,
100,
(RandomAccessVectorValues<byte[]>) vectors,
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
case FLOAT32 -> HnswGraphSearcher.search(
(float[]) query,
100,
(RandomAccessVectorValues<float[]>) vectors,
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
};
HnswGraphSearcher.search(
buildScorer(vectors, query), 100, hnsw, acceptOrds, Integer.MAX_VALUE);
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}

View File

@ -130,10 +130,9 @@ public class TestHnswFloatVectorGraph extends HnswGraphTestCase<float[]> {
int nDoc = 1000;
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
RandomAccessVectorValues<float[]> vectors = circularVectorValues(nDoc);
HnswGraphBuilder<float[]> builder =
HnswGraphBuilder.create(
vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.size());
// Skip over half of the documents that are closest to the query vector
FixedBitSet acceptOrds = new FixedBitSet(nDoc);
@ -142,14 +141,7 @@ public class TestHnswFloatVectorGraph extends HnswGraphTestCase<float[]> {
}
KnnCollector nn =
HnswGraphSearcher.search(
getTargetVector(),
10,
vectors.copy(),
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
buildScorer(vectors, getTargetVector()), 10, hnsw, acceptOrds, Integer.MAX_VALUE);
TopDocs nodes = nn.topDocs();
assertEquals("Number of found results is not equal to [10].", 10, nodes.scoreDocs.length);