mirror of https://github.com/apache/lucene.git
LUCENE-10040: Handle deletions in nearest vector search (#239)
This PR extends VectorReader#search to take a parameter specifying the live docs. LeafReader#searchNearestVectors then always returns the k nearest undeleted docs. To implement this, the HNSW algorithm will only add a candidate to the result set if it is a live doc. The graph search still visits and traverses deleted docs as it gathers candidates.
This commit is contained in:
parent 19e5c00a4f
commit 6993fb9a99
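The gist of the API change, as a minimal usage sketch in Java. The class and helper-method names below are illustrative, not part of the patch; the signatures and the "pass getLiveDocs() as acceptDocs" pattern follow the diff itself, where KnnVectorQuery, TestKnnGraph and KnnGraphTester do the same thing:

import java.io.IOException;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;

// Hypothetical helper showing the new 4-argument searchNearestVectors entry point.
final class NearestVectorsUsageSketch {
  // Collects the k nearest *undeleted* docs from every segment by forwarding each segment's
  // live docs as the new acceptDocs argument; null means all docs are allowed to match.
  static TopDocs[] searchPerLeaf(IndexReader reader, String field, float[] target, int k)
      throws IOException {
    TopDocs[] perLeaf = new TopDocs[reader.leaves().size()];
    for (LeafReaderContext ctx : reader.leaves()) {
      Bits liveDocs = ctx.reader().getLiveDocs(); // null when the segment has no deletions
      perLeaf[ctx.ord] = ctx.reader().searchNearestVectors(field, target, k, liveDocs);
    }
    return perLeaf;
  }
}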
@@ -7,9 +7,9 @@ http://s.apache.org/luceneversions
 
 New Features
 
-* LUCENE-9322 LUCENE-9855: Vector-valued fields, Lucene90 Codec (Mike Sokolov, Julie Tibshirani, Tomoko Uchida)
+* LUCENE-9322, LUCENE-9855: Vector-valued fields, Lucene90 Codec (Mike Sokolov, Julie Tibshirani, Tomoko Uchida)
 
-* LUCENE-9004: Approximate nearest vector search via NSW graphs (Mike Sokolov, Tomoko Uchida et al.)
+* LUCENE-9004, LUCENE-10040: Approximate nearest vector search via NSW graphs (Mike Sokolov, Tomoko Uchida et al.)
 
 * LUCENE-9659: SpanPayloadCheckQuery now supports inequalities. (Kevin Watters, Gus Heck)
@@ -37,6 +37,7 @@ import org.apache.lucene.store.BufferedChecksumIndexInput;
 import org.apache.lucene.store.ChecksumIndexInput;
 import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.BytesRefBuilder;
 import org.apache.lucene.util.IOUtils;
@@ -138,7 +139,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
   }
 
   @Override
-  public TopDocs search(String field, float[] target, int k) throws IOException {
+  public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
     throw new UnsupportedOperationException();
   }
 
@@ -23,6 +23,7 @@ import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TopDocsCollector;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.NamedSPILoader;
 
 /**
@@ -99,7 +100,7 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
       }
 
       @Override
-      public TopDocs search(String field, float[] target, int k) {
+      public TopDocs search(String field, float[] target, int k, Bits acceptDocs) {
        return TopDocsCollector.EMPTY_TOPDOCS;
      }
 
@@ -22,6 +22,7 @@ import java.io.IOException;
 import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.util.Accountable;
+import org.apache.lucene.util.Bits;
 
 /** Reads vectors from an index. */
 public abstract class KnnVectorsReader implements Closeable, Accountable {
@@ -51,9 +52,12 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
   * @param field the vector field to search
   * @param target the vector-valued query
   * @param k the number of docs to return
+   * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
+   *     if they are all allowed to match.
   * @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
   */
-  public abstract TopDocs search(String field, float[] target, int k) throws IOException;
+  public abstract TopDocs search(String field, float[] target, int k, Bits acceptDocs)
+      throws IOException;
 
  /**
   * Returns an instance optimized for merging. This instance may only be consumed in the thread
@@ -43,6 +43,7 @@ import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.store.ChecksumIndexInput;
 import org.apache.lucene.store.DataInput;
 import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.IOUtils;
 import org.apache.lucene.util.RamUsageEstimator;
@@ -232,7 +233,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
   }
 
   @Override
-  public TopDocs search(String field, float[] target, int k) throws IOException {
+  public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
     FieldEntry fieldEntry = fields.get(field);
     if (fieldEntry == null || fieldEntry.dimension == 0) {
       return null;
@@ -250,6 +251,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
             vectorValues,
             fieldEntry.similarityFunction,
             getGraphValues(fieldEntry),
+            getAcceptOrds(acceptDocs, fieldEntry),
             random);
     int i = 0;
     ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
@@ -276,6 +278,23 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
     return new OffHeapVectorValues(fieldEntry, bytesSlice);
   }
 
+  private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) {
+    if (acceptDocs == null) {
+      return null;
+    }
+    return new Bits() {
+      @Override
+      public boolean get(int index) {
+        return acceptDocs.get(fieldEntry.ordToDoc[index]);
+      }
+
+      @Override
+      public int length() {
+        return fieldEntry.ordToDoc.length;
+      }
+    };
+  }
+
   public KnnGraphValues getGraphValues(String field) throws IOException {
     FieldInfo info = fieldInfos.fieldInfo(field);
     if (info == null) {
@@ -33,6 +33,7 @@ import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.IOUtils;
 
 /**
@@ -240,12 +241,12 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
     }
 
     @Override
-    public TopDocs search(String field, float[] target, int k) throws IOException {
+    public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
       KnnVectorsReader knnVectorsReader = fields.get(field);
       if (knnVectorsReader == null) {
         return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
       } else {
-        return knnVectorsReader.search(field, target, k);
+        return knnVectorsReader.search(field, target, k, acceptDocs);
       }
     }
 
@@ -26,6 +26,7 @@ import org.apache.lucene.codecs.PointsReader;
 import org.apache.lucene.codecs.StoredFieldsReader;
 import org.apache.lucene.codecs.TermVectorsReader;
 import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.util.Bits;
 
 /** LeafReader implemented by codec APIs. */
 public abstract class CodecReader extends LeafReader {
@@ -211,7 +212,7 @@ public abstract class CodecReader extends LeafReader {
   }
 
   @Override
-  public final TopDocs searchNearestVectors(String field, float[] target, int k)
+  public final TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
       throws IOException {
     ensureOpen();
     FieldInfo fi = getFieldInfos().fieldInfo(field);
@@ -220,7 +221,7 @@ public abstract class CodecReader extends LeafReader {
       return null;
     }
 
-    return getVectorReader().search(field, target, k);
+    return getVectorReader().search(field, target, k, acceptDocs);
   }
 
   @Override
@@ -53,7 +53,8 @@ abstract class DocValuesLeafReader extends LeafReader {
   }
 
   @Override
-  public TopDocs searchNearestVectors(String field, float[] target, int k) throws IOException {
+  public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
+      throws IOException {
    throw new UnsupportedOperationException();
  }
 
@@ -345,8 +345,9 @@ public abstract class FilterLeafReader extends LeafReader {
   }
 
   @Override
-  public TopDocs searchNearestVectors(String field, float[] target, int k) throws IOException {
-    return in.searchNearestVectors(field, target, k);
+  public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
+      throws IOException {
+    return in.searchNearestVectors(field, target, k, acceptDocs);
   }
 
   @Override
@@ -222,10 +222,12 @@ public abstract class LeafReader extends IndexReader {
   * @param field the vector field to search
   * @param target the vector-valued query
   * @param k the number of docs to return
+   * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
+   *     if they are all allowed to match.
   * @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
   * @lucene.experimental
   */
-  public abstract TopDocs searchNearestVectors(String field, float[] target, int k)
+  public abstract TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
      throws IOException;
 
  /**
@@ -209,8 +209,9 @@ class MergeReaderWrapper extends LeafReader {
   }
 
   @Override
-  public TopDocs searchNearestVectors(String field, float[] target, int k) throws IOException {
-    return in.searchNearestVectors(field, target, k);
+  public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
+      throws IOException {
+    return in.searchNearestVectors(field, target, k, acceptDocs);
   }
 
   @Override
@@ -398,10 +398,11 @@ public class ParallelLeafReader extends LeafReader {
   }
 
   @Override
-  public TopDocs searchNearestVectors(String fieldName, float[] target, int k) throws IOException {
+  public TopDocs searchNearestVectors(String fieldName, float[] target, int k, Bits acceptDocs)
+      throws IOException {
     ensureOpen();
     LeafReader reader = fieldToReader.get(fieldName);
-    return reader == null ? null : reader.searchNearestVectors(fieldName, target, k);
+    return reader == null ? null : reader.searchNearestVectors(fieldName, target, k, acceptDocs);
   }
 
   @Override
@@ -167,8 +167,9 @@ public final class SlowCodecReaderWrapper {
       }
 
       @Override
-      public TopDocs search(String field, float[] target, int k) throws IOException {
-        return reader.searchNearestVectors(field, target, k);
+      public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
+          throws IOException {
+        return reader.searchNearestVectors(field, target, k, acceptDocs);
       }
 
       @Override
@@ -315,7 +315,7 @@ public final class SortingCodecReader extends FilterCodecReader {
       }
 
       @Override
-      public TopDocs search(String field, float[] target, int k) {
+      public TopDocs search(String field, float[] target, int k, Bits acceptDocs) {
         throw new UnsupportedOperationException();
       }
 
@@ -26,6 +26,7 @@ import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.document.KnnVectorField;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.util.Bits;
 
 /** Uses {@link KnnVectorsReader#search} to perform nearest Neighbour search. */
 public class KnnVectorQuery extends Query {
@@ -70,7 +71,8 @@ public class KnnVectorQuery extends Query {
   }
 
   private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
-    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+    Bits liveDocs = ctx.reader().getLiveDocs();
+    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf, liveDocs);
     if (results == null) {
       return NO_RESULTS;
     }
@@ -26,6 +26,7 @@ import java.util.Random;
 import org.apache.lucene.index.KnnGraphValues;
 import org.apache.lucene.index.RandomAccessVectorValues;
 import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.SparseFixedBitSet;
 
 /**
@@ -83,6 +84,8 @@ public final class HnswGraph extends KnnGraphValues {
   * @param vectors vector values
   * @param graphValues the graph values. May represent the entire graph, or a level in a
   *     hierarchical graph.
+   * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
+   *     {@code null} if they are all allowed to match.
   * @param random a source of randomness, used for generating entry points to the graph
   * @return a priority queue holding the closest neighbors found
   */
@@ -93,12 +96,15 @@ public final class HnswGraph extends KnnGraphValues {
       RandomAccessVectorValues vectors,
       VectorSimilarityFunction similarityFunction,
       KnnGraphValues graphValues,
+      Bits acceptOrds,
       Random random)
       throws IOException {
     int size = graphValues.size();
 
     // MIN heap, holding the top results
     NeighborQueue results = new NeighborQueue(numSeed, similarityFunction.reversed);
+    // MAX heap, from which to pull the candidate nodes
+    NeighborQueue candidates = new NeighborQueue(numSeed, !similarityFunction.reversed);
 
     // set of ordinals that have been visited by search on this layer, used to avoid backtracking
     SparseFixedBitSet visited = new SparseFixedBitSet(size);
@@ -109,13 +115,14 @@ public final class HnswGraph extends KnnGraphValues {
       if (visited.get(entryPoint) == false) {
        visited.set(entryPoint);
        // explore the topK starting points of some random numSeed probes
-        results.add(entryPoint, similarityFunction.compare(query, vectors.vectorValue(entryPoint)));
+        float score = similarityFunction.compare(query, vectors.vectorValue(entryPoint));
+        candidates.add(entryPoint, score);
+        if (acceptOrds == null || acceptOrds.get(entryPoint)) {
+          results.add(entryPoint, score);
+        }
      }
    }
 
-    // MAX heap, from which to pull the candidate nodes
-    NeighborQueue candidates = results.copy(!similarityFunction.reversed);
-
     // Set the bound to the worst current result and below reject any newly-generated candidates
     // failing
     // to exceed this bound
@@ -138,10 +145,14 @@ public final class HnswGraph extends KnnGraphValues {
           continue;
        }
        visited.set(friendOrd);
+
        float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
-        if (results.insertWithOverflow(friendOrd, score)) {
+        if (results.size() < numSeed || bound.check(score) == false) {
          candidates.add(friendOrd, score);
-          bound.set(results.topScore());
+          if (acceptOrds == null || acceptOrds.get(friendOrd)) {
+            results.insertWithOverflow(friendOrd, score);
+            bound.set(results.topScore());
+          }
        }
      }
    }
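The behavioral core of the HnswGraph change above, restated: ordinals rejected by acceptOrds are still visited and expanded, so the walk can pass through deleted docs to reach live ones, but only accepted ordinals are collected as results. The sketch below illustrates that filtering idea in isolation; it is a deliberately simplified breadth-first walk, not Lucene's score-ordered HNSW search, and every name in it is hypothetical.

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.function.IntPredicate;

// Illustrative sketch only (not the Lucene implementation) of "traverse everything,
// collect only accepted nodes".
final class FilteredGraphWalkSketch {
  // neighbors[i] lists the graph neighbors of node i; accept stands in for acceptOrds/live docs.
  static List<Integer> collectAccepted(int entry, int[][] neighbors, IntPredicate accept, int limit) {
    List<Integer> results = new ArrayList<>();
    Set<Integer> visited = new HashSet<>();
    Queue<Integer> candidates = new ArrayDeque<>();
    candidates.add(entry);
    visited.add(entry);
    while (!candidates.isEmpty() && results.size() < limit) {
      int node = candidates.poll();
      if (accept.test(node)) {
        results.add(node); // only accepted (live) nodes become results
      }
      for (int friend : neighbors[node]) { // every node, accepted or not, is still expanded
        if (visited.add(friend)) {
          candidates.add(friend);
        }
      }
    }
    return results;
  }
}

For example, collectAccepted(0, graph, ord -> liveOrds.get(ord), 10) would return up to ten live ordinals even when node 0's direct neighbors are all deleted, which is the same reachability property the patch preserves in HnswGraph.search.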
@@ -134,9 +134,10 @@ public final class HnswGraphBuilder {
 
   /** Inserts a doc with vector value to the graph */
   void addGraphNode(float[] value) throws IOException {
+    // We pass 'null' for acceptOrds because there are no deletions while building the graph
     NeighborQueue candidates =
         HnswGraph.search(
-            value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, random);
+            value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random);
 
     int node = hnsw.addNode();
 
@@ -42,13 +42,6 @@ public class NeighborQueue {
     }
   }
 
-  NeighborQueue copy(boolean reversed) {
-    int size = size();
-    NeighborQueue copy = new NeighborQueue(size, reversed);
-    copy.heap.pushAll(heap);
-    return copy;
-  }
-
   /** @return the number of elements in the heap */
   public int size() {
     return heap.size();
@@ -38,6 +38,7 @@ import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.NoMergePolicy;
 import org.apache.lucene.index.RandomCodec;
 import org.apache.lucene.index.SegmentReadState;
@@ -101,19 +102,13 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
 
       // Double-check the vectors were written
       try (IndexReader ireader = DirectoryReader.open(directory)) {
+        LeafReader reader = ireader.leaves().get(0).reader();
        TopDocs hits1 =
-            ireader
-                .leaves()
-                .get(0)
-                .reader()
-                .searchNearestVectors("field1", new float[] {1, 2, 3}, 10);
+            reader.searchNearestVectors("field1", new float[] {1, 2, 3}, 10, reader.getLiveDocs());
        assertEquals(1, hits1.scoreDocs.length);
 
        TopDocs hits2 =
-            ireader
-                .leaves()
-                .get(0)
-                .reader()
-                .searchNearestVectors("field2", new float[] {1, 2, 3}, 10);
+            reader.searchNearestVectors("field2", new float[] {1, 2, 3}, 10, reader.getLiveDocs());
        assertEquals(1, hits2.scoreDocs.length);
      }
    }
@@ -42,6 +42,7 @@ import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.LuceneTestCase;
 import org.apache.lucene.util.VectorUtil;
@@ -291,7 +292,8 @@ public class TestKnnGraph extends LuceneTestCase {
   private static TopDocs doKnnSearch(IndexReader reader, float[] vector, int k) throws IOException {
     TopDocs[] results = new TopDocs[reader.leaves().size()];
     for (LeafReaderContext ctx : reader.leaves()) {
-      results[ctx.ord] = ctx.reader().searchNearestVectors(KNN_GRAPH_FIELD, vector, k);
+      Bits liveDocs = ctx.reader().getLiveDocs();
+      results[ctx.ord] = ctx.reader().searchNearestVectors(KNN_GRAPH_FIELD, vector, k, liveDocs);
       if (ctx.docBase > 0) {
         for (ScoreDoc doc : results[ctx.ord].scoreDocs) {
           doc.doc += ctx.docBase;
@@ -112,7 +112,7 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
       }
 
       @Override
-      public TopDocs searchNearestVectors(String field, float[] target, int k) {
+      public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
         return null;
       }
 
@@ -16,10 +16,13 @@
  */
 package org.apache.lucene.search;
 
+import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 import static org.apache.lucene.util.TestVectorUtil.randomVector;
 
 import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.KnnVectorField;
@@ -303,6 +306,77 @@ public class TestKnnVectorQuery extends LuceneTestCase {
     }
   }
 
+  public void testDeletes() throws IOException {
+    try (Directory dir = newDirectory();
+        IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+      final int numDocs = atLeast(100);
+      final int dim = 30;
+      int docIndex = 0;
+      for (int i = 0; i < numDocs; ++i) {
+        Document d = new Document();
+        if (frequently()) {
+          d.add(new StringField("index", String.valueOf(docIndex), Field.Store.YES));
+          d.add(new KnnVectorField("vector", randomVector(dim)));
+          docIndex++;
+        } else {
+          d.add(new StringField("other", "value" + (i % 5), Field.Store.NO));
+        }
+        w.addDocument(d);
+      }
+      w.commit();
+
+      // Delete some documents at random, both those with and without vectors
+      Set<Term> toDelete = new HashSet<>();
+      for (int i = 0; i < 20; i++) {
+        int index = random().nextInt(docIndex);
+        toDelete.add(new Term("index", String.valueOf(index)));
+      }
+      w.deleteDocuments(toDelete.toArray(new Term[0]));
+      w.deleteDocuments(new Term("other", "value" + random().nextInt(5)));
+      w.commit();
+
+      try (IndexReader reader = DirectoryReader.open(dir)) {
+        Set<String> allIds = new HashSet<>();
+        IndexSearcher searcher = new IndexSearcher(reader);
+        KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(dim), numDocs);
+        TopDocs topDocs = searcher.search(query, numDocs);
+        for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
+          Document doc = reader.document(scoreDoc.doc, Set.of("index"));
+          String index = doc.get("index");
+          assertFalse(
+              "search returned a deleted document: " + index,
+              toDelete.contains(new Term("index", index)));
+          allIds.add(index);
+        }
+        assertEquals("search missed some documents", docIndex - toDelete.size(), allIds.size());
+      }
+    }
+  }
+
+  public void testAllDeletes() throws IOException {
+    try (Directory dir = newDirectory();
+        IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+      final int numDocs = atLeast(100);
+      final int dim = 30;
+      for (int i = 0; i < numDocs; ++i) {
+        Document d = new Document();
+        d.add(new KnnVectorField("vector", randomVector(dim)));
+        w.addDocument(d);
+      }
+      w.commit();
+
+      w.deleteDocuments(new MatchAllDocsQuery());
+      w.commit();
+
+      try (IndexReader reader = DirectoryReader.open(dir)) {
+        IndexSearcher searcher = new IndexSearcher(reader);
+        KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(dim), numDocs);
+        TopDocs topDocs = searcher.search(query, numDocs);
+        assertEquals(0, topDocs.scoreDocs.length);
+      }
+    }
+  }
+
   private Directory getIndexStore(String field, float[]... contents) throws IOException {
     Directory indexStore = newDirectory();
     RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
@@ -58,6 +58,7 @@ import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.FSDirectory;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.IntroSorter;
 import org.apache.lucene.util.PrintStreamInfoStream;
@@ -424,7 +425,8 @@ public class KnnGraphTester {
       IndexReader reader, String field, float[] vector, int k, int fanout) throws IOException {
     TopDocs[] results = new TopDocs[reader.leaves().size()];
     for (LeafReaderContext ctx : reader.leaves()) {
-      results[ctx.ord] = ctx.reader().searchNearestVectors(field, vector, k + fanout);
+      Bits liveDocs = ctx.reader().getLiveDocs();
+      results[ctx.ord] = ctx.reader().searchNearestVectors(field, vector, k + fanout, liveDocs);
       int docBase = ctx.docBase;
       for (ScoreDoc scoreDoc : results[ctx.ord].scoreDocs) {
         scoreDoc.doc += docBase;
@@ -45,12 +45,14 @@ import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.FixedBitSet;
 import org.apache.lucene.util.LuceneTestCase;
 import org.apache.lucene.util.VectorUtil;
 
 /** Tests HNSW KNN graphs */
-public class TestHnsw extends LuceneTestCase {
+public class TestHnswGraph extends LuceneTestCase {
 
   // test writing out and reading in a graph gives the expected graph
   public void testReadWrite() throws IOException {
@@ -138,6 +140,7 @@ public class TestHnsw extends LuceneTestCase {
             vectors.randomAccess(),
             VectorSimilarityFunction.DOT_PRODUCT,
             hnsw,
+            null,
             random());
     int sum = 0;
     for (int node : nn.nodes()) {
@@ -156,6 +159,35 @@
     }
   }
 
+  public void testSearchWithAcceptOrds() throws IOException {
+    int nDoc = 100;
+    CircularVectorValues vectors = new CircularVectorValues(nDoc);
+    HnswGraphBuilder builder =
+        new HnswGraphBuilder(
+            vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
+    HnswGraph hnsw = builder.build(vectors);
+
+    Bits acceptOrds = createRandomAcceptOrds(vectors.size);
+    NeighborQueue nn =
+        HnswGraph.search(
+            new float[] {1, 0},
+            10,
+            5,
+            vectors.randomAccess(),
+            VectorSimilarityFunction.DOT_PRODUCT,
+            hnsw,
+            acceptOrds,
+            random());
+    int sum = 0;
+    for (int node : nn.nodes()) {
+      assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
+      sum += node;
+    }
+    // We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) =
+    // 45
+    assertTrue("sum(result docs)=" + sum, sum < 75);
+  }
+
   public void testBoundsCheckerMax() {
     BoundsChecker max = BoundsChecker.create(false);
     float f = random().nextFloat() - 0.5f;
@@ -279,16 +311,21 @@
     HnswGraphBuilder builder =
         new HnswGraphBuilder(vectors, similarityFunction, 10, 30, random().nextLong());
     HnswGraph hnsw = builder.build(vectors);
+    Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(size);
+
     int totalMatches = 0;
     for (int i = 0; i < 100; i++) {
       float[] query = randomVector(random(), dim);
       NeighborQueue actual =
-          HnswGraph.search(query, topK, 100, vectors, similarityFunction, hnsw, random());
+          HnswGraph.search(
+              query, topK, 100, vectors, similarityFunction, hnsw, acceptOrds, random());
       NeighborQueue expected = new NeighborQueue(topK, similarityFunction.reversed);
       for (int j = 0; j < size; j++) {
-        float[] v = vectors.vectorValue(j);
-        if (v != null) {
-          expected.insertWithOverflow(j, similarityFunction.compare(query, vectors.vectorValue(j)));
+        if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
+          expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
+          if (expected.size() > topK) {
+            expected.pop();
+          }
        }
      }
      assertEquals(topK, actual.size());
@@ -455,6 +492,17 @@
     }
   }
 
+  /** Generate a random bitset where each entry has a 2/3 probability of being set. */
+  private static Bits createRandomAcceptOrds(int length) {
+    FixedBitSet bits = new FixedBitSet(length);
+    for (int i = 0; i < bits.length(); i++) {
+      if (random().nextFloat() < 0.667f) {
+        bits.set(i);
+      }
+    }
+    return bits;
+  }
+
   private static float[] randomVector(Random random, int dim) {
     float[] vec = new float[dim];
     for (int i = 0; i < dim; i++) {
@@ -162,7 +162,7 @@ public class TermVectorLeafReader extends LeafReader {
   }
 
   @Override
-  public TopDocs searchNearestVectors(String field, float[] target, int k) {
+  public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
     return null;
   }
 
@@ -1373,7 +1373,7 @@ public class MemoryIndex {
     }
 
     @Override
-    public TopDocs searchNearestVectors(String field, float[] target, int k) {
+    public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
       return null;
     }
 
@@ -26,6 +26,7 @@ import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.TestUtil;
 
 /** Wraps the default KnnVectorsFormat and provides additional assertions. */
@@ -98,8 +99,8 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
     }
 
     @Override
-    public TopDocs search(String field, float[] target, int k) throws IOException {
-      TopDocs hits = delegate.search(field, target, k);
+    public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
+      TopDocs hits = delegate.search(field, target, k, acceptDocs);
       assert hits != null;
       assert hits.scoreDocs.length <= k;
       return hits;
@@ -216,7 +216,7 @@ public class QueryUtils {
       }
 
       @Override
-      public TopDocs searchNearestVectors(String field, float[] target, int k) {
+      public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
         return null;
       }
 