LUCENE-10040: Handle deletions in nearest vector search (#239)

This PR extends VectorReader#search to take a parameter specifying the live
docs. LeafReader#searchNearestVectors then always returns the k nearest
undeleted docs.

To implement this, the HNSW algorithm will only add a candidate to the result
set if it is a live doc. The graph search still visits and traverses deleted
docs as it gathers candidates.
This commit is contained in:
Julie Tibshirani 2021-08-16 17:44:17 +03:00 committed by GitHub
parent 19e5c00a4f
commit 6993fb9a99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 222 additions and 59 deletions

View File

@ -7,9 +7,9 @@ http://s.apache.org/luceneversions
New Features
* LUCENE-9322 LUCENE-9855: Vector-valued fields, Lucene90 Codec (Mike Sokolov, Julie Tibshirani, Tomoko Uchida)
* LUCENE-9322, LUCENE-9855: Vector-valued fields, Lucene90 Codec (Mike Sokolov, Julie Tibshirani, Tomoko Uchida)
* LUCENE-9004: Approximate nearest vector search via NSW graphs (Mike Sokolov, Tomoko Uchida et al.)
* LUCENE-9004, LUCENE-10040: Approximate nearest vector search via NSW graphs (Mike Sokolov, Tomoko Uchida et al.)
* LUCENE-9659: SpanPayloadCheckQuery now supports inequalities. (Kevin Watters, Gus Heck)

View File

@ -37,6 +37,7 @@ import org.apache.lucene.store.BufferedChecksumIndexInput;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.IOUtils;
@ -138,7 +139,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
@Override
public TopDocs search(String field, float[] target, int k) throws IOException {
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
throw new UnsupportedOperationException();
}

View File

@ -23,6 +23,7 @@ import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.NamedSPILoader;
/**
@ -99,7 +100,7 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
}
@Override
public TopDocs search(String field, float[] target, int k) {
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) {
return TopDocsCollector.EMPTY_TOPDOCS;
}

View File

@ -22,6 +22,7 @@ import java.io.IOException;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
/** Reads vectors from an index. */
public abstract class KnnVectorsReader implements Closeable, Accountable {
@ -51,9 +52,12 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
* @param field the vector field to search
* @param target the vector-valued query
* @param k the number of docs to return
* @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
* if they are all allowed to match.
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
*/
public abstract TopDocs search(String field, float[] target, int k) throws IOException;
public abstract TopDocs search(String field, float[] target, int k, Bits acceptDocs)
throws IOException;
/**
* Returns an instance optimized for merging. This instance may only be consumed in the thread

View File

@ -43,6 +43,7 @@ import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
@ -232,7 +233,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
}
@Override
public TopDocs search(String field, float[] target, int k) throws IOException {
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null || fieldEntry.dimension == 0) {
return null;
@ -250,6 +251,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
vectorValues,
fieldEntry.similarityFunction,
getGraphValues(fieldEntry),
getAcceptOrds(acceptDocs, fieldEntry),
random);
int i = 0;
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
@ -276,6 +278,23 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
return new OffHeapVectorValues(fieldEntry, bytesSlice);
}
private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) {
if (acceptDocs == null) {
return null;
}
return new Bits() {
@Override
public boolean get(int index) {
return acceptDocs.get(fieldEntry.ordToDoc[index]);
}
@Override
public int length() {
return fieldEntry.ordToDoc.length;
}
};
}
public KnnGraphValues getGraphValues(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {

View File

@ -33,6 +33,7 @@ import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
/**
@ -240,12 +241,12 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
}
@Override
public TopDocs search(String field, float[] target, int k) throws IOException {
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
KnnVectorsReader knnVectorsReader = fields.get(field);
if (knnVectorsReader == null) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
} else {
return knnVectorsReader.search(field, target, k);
return knnVectorsReader.search(field, target, k, acceptDocs);
}
}

View File

@ -26,6 +26,7 @@ import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
/** LeafReader implemented by codec APIs. */
public abstract class CodecReader extends LeafReader {
@ -211,7 +212,7 @@ public abstract class CodecReader extends LeafReader {
}
@Override
public final TopDocs searchNearestVectors(String field, float[] target, int k)
public final TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
@ -220,7 +221,7 @@ public abstract class CodecReader extends LeafReader {
return null;
}
return getVectorReader().search(field, target, k);
return getVectorReader().search(field, target, k, acceptDocs);
}
@Override

View File

@ -53,7 +53,8 @@ abstract class DocValuesLeafReader extends LeafReader {
}
@Override
public TopDocs searchNearestVectors(String field, float[] target, int k) throws IOException {
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
throws IOException {
throw new UnsupportedOperationException();
}

View File

@ -345,8 +345,9 @@ public abstract class FilterLeafReader extends LeafReader {
}
@Override
public TopDocs searchNearestVectors(String field, float[] target, int k) throws IOException {
return in.searchNearestVectors(field, target, k);
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
throws IOException {
return in.searchNearestVectors(field, target, k, acceptDocs);
}
@Override

View File

@ -222,10 +222,12 @@ public abstract class LeafReader extends IndexReader {
* @param field the vector field to search
* @param target the vector-valued query
* @param k the number of docs to return
* @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
* if they are all allowed to match.
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
* @lucene.experimental
*/
public abstract TopDocs searchNearestVectors(String field, float[] target, int k)
public abstract TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
throws IOException;
/**

View File

@ -209,8 +209,9 @@ class MergeReaderWrapper extends LeafReader {
}
@Override
public TopDocs searchNearestVectors(String field, float[] target, int k) throws IOException {
return in.searchNearestVectors(field, target, k);
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
throws IOException {
return in.searchNearestVectors(field, target, k, acceptDocs);
}
@Override

View File

@ -398,10 +398,11 @@ public class ParallelLeafReader extends LeafReader {
}
@Override
public TopDocs searchNearestVectors(String fieldName, float[] target, int k) throws IOException {
public TopDocs searchNearestVectors(String fieldName, float[] target, int k, Bits acceptDocs)
throws IOException {
ensureOpen();
LeafReader reader = fieldToReader.get(fieldName);
return reader == null ? null : reader.searchNearestVectors(fieldName, target, k);
return reader == null ? null : reader.searchNearestVectors(fieldName, target, k, acceptDocs);
}
@Override

View File

@ -167,8 +167,9 @@ public final class SlowCodecReaderWrapper {
}
@Override
public TopDocs search(String field, float[] target, int k) throws IOException {
return reader.searchNearestVectors(field, target, k);
public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
throws IOException {
return reader.searchNearestVectors(field, target, k, acceptDocs);
}
@Override

View File

@ -315,7 +315,7 @@ public final class SortingCodecReader extends FilterCodecReader {
}
@Override
public TopDocs search(String field, float[] target, int k) {
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) {
throw new UnsupportedOperationException();
}

View File

@ -26,6 +26,7 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.Bits;
/** Uses {@link KnnVectorsReader#search} to perform nearest Neighbour search. */
public class KnnVectorQuery extends Query {
@ -70,7 +71,8 @@ public class KnnVectorQuery extends Query {
}
private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
Bits liveDocs = ctx.reader().getLiveDocs();
TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf, liveDocs);
if (results == null) {
return NO_RESULTS;
}

View File

@ -26,6 +26,7 @@ import java.util.Random;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.SparseFixedBitSet;
/**
@ -83,6 +84,8 @@ public final class HnswGraph extends KnnGraphValues {
* @param vectors vector values
* @param graphValues the graph values. May represent the entire graph, or a level in a
* hierarchical graph.
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
* {@code null} if they are all allowed to match.
* @param random a source of randomness, used for generating entry points to the graph
* @return a priority queue holding the closest neighbors found
*/
@ -93,12 +96,15 @@ public final class HnswGraph extends KnnGraphValues {
RandomAccessVectorValues vectors,
VectorSimilarityFunction similarityFunction,
KnnGraphValues graphValues,
Bits acceptOrds,
Random random)
throws IOException {
int size = graphValues.size();
// MIN heap, holding the top results
NeighborQueue results = new NeighborQueue(numSeed, similarityFunction.reversed);
// MAX heap, from which to pull the candidate nodes
NeighborQueue candidates = new NeighborQueue(numSeed, !similarityFunction.reversed);
// set of ordinals that have been visited by search on this layer, used to avoid backtracking
SparseFixedBitSet visited = new SparseFixedBitSet(size);
@ -109,12 +115,13 @@ public final class HnswGraph extends KnnGraphValues {
if (visited.get(entryPoint) == false) {
visited.set(entryPoint);
// explore the topK starting points of some random numSeed probes
results.add(entryPoint, similarityFunction.compare(query, vectors.vectorValue(entryPoint)));
float score = similarityFunction.compare(query, vectors.vectorValue(entryPoint));
candidates.add(entryPoint, score);
if (acceptOrds == null || acceptOrds.get(entryPoint)) {
results.add(entryPoint, score);
}
}
}
// MAX heap, from which to pull the candidate nodes
NeighborQueue candidates = results.copy(!similarityFunction.reversed);
// Set the bound to the worst current result and below reject any newly-generated candidates
// failing
@ -138,13 +145,17 @@ public final class HnswGraph extends KnnGraphValues {
continue;
}
visited.set(friendOrd);
float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
if (results.insertWithOverflow(friendOrd, score)) {
if (results.size() < numSeed || bound.check(score) == false) {
candidates.add(friendOrd, score);
if (acceptOrds == null || acceptOrds.get(friendOrd)) {
results.insertWithOverflow(friendOrd, score);
bound.set(results.topScore());
}
}
}
}
while (results.size() > topK) {
results.pop();
}

View File

@ -134,9 +134,10 @@ public final class HnswGraphBuilder {
/** Inserts a doc with vector value to the graph */
void addGraphNode(float[] value) throws IOException {
// We pass 'null' for acceptOrds because there are no deletions while building the graph
NeighborQueue candidates =
HnswGraph.search(
value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, random);
value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random);
int node = hnsw.addNode();

View File

@ -42,13 +42,6 @@ public class NeighborQueue {
}
}
NeighborQueue copy(boolean reversed) {
int size = size();
NeighborQueue copy = new NeighborQueue(size, reversed);
copy.heap.pushAll(heap);
return copy;
}
/** @return the number of elements in the heap */
public int size() {
return heap.size();

View File

@ -38,6 +38,7 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.RandomCodec;
import org.apache.lucene.index.SegmentReadState;
@ -101,19 +102,13 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
// Double-check the vectors were written
try (IndexReader ireader = DirectoryReader.open(directory)) {
LeafReader reader = ireader.leaves().get(0).reader();
TopDocs hits1 =
ireader
.leaves()
.get(0)
.reader()
.searchNearestVectors("field1", new float[] {1, 2, 3}, 10);
reader.searchNearestVectors("field1", new float[] {1, 2, 3}, 10, reader.getLiveDocs());
assertEquals(1, hits1.scoreDocs.length);
TopDocs hits2 =
ireader
.leaves()
.get(0)
.reader()
.searchNearestVectors("field2", new float[] {1, 2, 3}, 10);
reader.searchNearestVectors("field2", new float[] {1, 2, 3}, 10, reader.getLiveDocs());
assertEquals(1, hits2.scoreDocs.length);
}
}

View File

@ -42,6 +42,7 @@ import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
@ -291,7 +292,8 @@ public class TestKnnGraph extends LuceneTestCase {
private static TopDocs doKnnSearch(IndexReader reader, float[] vector, int k) throws IOException {
TopDocs[] results = new TopDocs[reader.leaves().size()];
for (LeafReaderContext ctx : reader.leaves()) {
results[ctx.ord] = ctx.reader().searchNearestVectors(KNN_GRAPH_FIELD, vector, k);
Bits liveDocs = ctx.reader().getLiveDocs();
results[ctx.ord] = ctx.reader().searchNearestVectors(KNN_GRAPH_FIELD, vector, k, liveDocs);
if (ctx.docBase > 0) {
for (ScoreDoc doc : results[ctx.ord].scoreDocs) {
doc.doc += ctx.docBase;

View File

@ -112,7 +112,7 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
}
@Override
public TopDocs searchNearestVectors(String field, float[] target, int k) {
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
return null;
}

View File

@ -16,10 +16,13 @@
*/
package org.apache.lucene.search;
import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.TestVectorUtil.randomVector;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnVectorField;
@ -303,6 +306,77 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
}
public void testDeletes() throws IOException {
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
final int numDocs = atLeast(100);
final int dim = 30;
int docIndex = 0;
for (int i = 0; i < numDocs; ++i) {
Document d = new Document();
if (frequently()) {
d.add(new StringField("index", String.valueOf(docIndex), Field.Store.YES));
d.add(new KnnVectorField("vector", randomVector(dim)));
docIndex++;
} else {
d.add(new StringField("other", "value" + (i % 5), Field.Store.NO));
}
w.addDocument(d);
}
w.commit();
// Delete some documents at random, both those with and without vectors
Set<Term> toDelete = new HashSet<>();
for (int i = 0; i < 20; i++) {
int index = random().nextInt(docIndex);
toDelete.add(new Term("index", String.valueOf(index)));
}
w.deleteDocuments(toDelete.toArray(new Term[0]));
w.deleteDocuments(new Term("other", "value" + random().nextInt(5)));
w.commit();
try (IndexReader reader = DirectoryReader.open(dir)) {
Set<String> allIds = new HashSet<>();
IndexSearcher searcher = new IndexSearcher(reader);
KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(dim), numDocs);
TopDocs topDocs = searcher.search(query, numDocs);
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
Document doc = reader.document(scoreDoc.doc, Set.of("index"));
String index = doc.get("index");
assertFalse(
"search returned a deleted document: " + index,
toDelete.contains(new Term("index", index)));
allIds.add(index);
}
assertEquals("search missed some documents", docIndex - toDelete.size(), allIds.size());
}
}
}
public void testAllDeletes() throws IOException {
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
final int numDocs = atLeast(100);
final int dim = 30;
for (int i = 0; i < numDocs; ++i) {
Document d = new Document();
d.add(new KnnVectorField("vector", randomVector(dim)));
w.addDocument(d);
}
w.commit();
w.deleteDocuments(new MatchAllDocsQuery());
w.commit();
try (IndexReader reader = DirectoryReader.open(dir)) {
IndexSearcher searcher = new IndexSearcher(reader);
KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(dim), numDocs);
TopDocs topDocs = searcher.search(query, numDocs);
assertEquals(0, topDocs.scoreDocs.length);
}
}
}
private Directory getIndexStore(String field, float[]... contents) throws IOException {
Directory indexStore = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);

View File

@ -58,6 +58,7 @@ import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.PrintStreamInfoStream;
@ -424,7 +425,8 @@ public class KnnGraphTester {
IndexReader reader, String field, float[] vector, int k, int fanout) throws IOException {
TopDocs[] results = new TopDocs[reader.leaves().size()];
for (LeafReaderContext ctx : reader.leaves()) {
results[ctx.ord] = ctx.reader().searchNearestVectors(field, vector, k + fanout);
Bits liveDocs = ctx.reader().getLiveDocs();
results[ctx.ord] = ctx.reader().searchNearestVectors(field, vector, k + fanout, liveDocs);
int docBase = ctx.docBase;
for (ScoreDoc scoreDoc : results[ctx.ord].scoreDocs) {
scoreDoc.doc += docBase;

View File

@ -45,12 +45,14 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
/** Tests HNSW KNN graphs */
public class TestHnsw extends LuceneTestCase {
public class TestHnswGraph extends LuceneTestCase {
// test writing out and reading in a graph gives the expected graph
public void testReadWrite() throws IOException {
@ -138,6 +140,7 @@ public class TestHnsw extends LuceneTestCase {
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
null,
random());
int sum = 0;
for (int node : nn.nodes()) {
@ -156,6 +159,35 @@ public class TestHnsw extends LuceneTestCase {
}
}
public void testSearchWithAcceptOrds() throws IOException {
int nDoc = 100;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
HnswGraph hnsw = builder.build(vectors);
Bits acceptOrds = createRandomAcceptOrds(vectors.size);
NeighborQueue nn =
HnswGraph.search(
new float[] {1, 0},
10,
5,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
acceptOrds,
random());
int sum = 0;
for (int node : nn.nodes()) {
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
sum += node;
}
// We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) =
// 45
assertTrue("sum(result docs)=" + sum, sum < 75);
}
public void testBoundsCheckerMax() {
BoundsChecker max = BoundsChecker.create(false);
float f = random().nextFloat() - 0.5f;
@ -279,16 +311,21 @@ public class TestHnsw extends LuceneTestCase {
HnswGraphBuilder builder =
new HnswGraphBuilder(vectors, similarityFunction, 10, 30, random().nextLong());
HnswGraph hnsw = builder.build(vectors);
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(size);
int totalMatches = 0;
for (int i = 0; i < 100; i++) {
float[] query = randomVector(random(), dim);
NeighborQueue actual =
HnswGraph.search(query, topK, 100, vectors, similarityFunction, hnsw, random());
HnswGraph.search(
query, topK, 100, vectors, similarityFunction, hnsw, acceptOrds, random());
NeighborQueue expected = new NeighborQueue(topK, similarityFunction.reversed);
for (int j = 0; j < size; j++) {
float[] v = vectors.vectorValue(j);
if (v != null) {
expected.insertWithOverflow(j, similarityFunction.compare(query, vectors.vectorValue(j)));
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
if (expected.size() > topK) {
expected.pop();
}
}
}
assertEquals(topK, actual.size());
@ -455,6 +492,17 @@ public class TestHnsw extends LuceneTestCase {
}
}
/** Generate a random bitset where each entry has a 2/3 probability of being set. */
private static Bits createRandomAcceptOrds(int length) {
FixedBitSet bits = new FixedBitSet(length);
for (int i = 0; i < bits.length(); i++) {
if (random().nextFloat() < 0.667f) {
bits.set(i);
}
}
return bits;
}
private static float[] randomVector(Random random, int dim) {
float[] vec = new float[dim];
for (int i = 0; i < dim; i++) {

View File

@ -162,7 +162,7 @@ public class TermVectorLeafReader extends LeafReader {
}
@Override
public TopDocs searchNearestVectors(String field, float[] target, int k) {
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
return null;
}

View File

@ -1373,7 +1373,7 @@ public class MemoryIndex {
}
@Override
public TopDocs searchNearestVectors(String field, float[] target, int k) {
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
return null;
}

View File

@ -26,6 +26,7 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.TestUtil;
/** Wraps the default KnnVectorsFormat and provides additional assertions. */
@ -98,8 +99,8 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
}
@Override
public TopDocs search(String field, float[] target, int k) throws IOException {
TopDocs hits = delegate.search(field, target, k);
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
TopDocs hits = delegate.search(field, target, k, acceptDocs);
assert hits != null;
assert hits.scoreDocs.length <= k;
return hits;

View File

@ -216,7 +216,7 @@ public class QueryUtils {
}
@Override
public TopDocs searchNearestVectors(String field, float[] target, int k) {
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
return null;
}