From 6993fb9a9985372b0f0984b8bdd7434aaa33ad26 Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Mon, 16 Aug 2021 17:44:17 +0300 Subject: [PATCH] LUCENE-10040: Handle deletions in nearest vector search (#239) This PR extends VectorReader#search to take a parameter specifying the live docs. LeafReader#searchNearestVectors then always returns the k nearest undeleted docs. To implement this, the HNSW algorithm will only add a candidate to the result set if it is a live doc. The graph search still visits and traverses deleted docs as it gathers candidates. --- lucene/CHANGES.txt | 4 +- .../SimpleTextKnnVectorsReader.java | 3 +- .../lucene/codecs/KnnVectorsFormat.java | 3 +- .../lucene/codecs/KnnVectorsReader.java | 6 +- .../lucene90/Lucene90HnswVectorsReader.java | 21 +++++- .../perfield/PerFieldKnnVectorsFormat.java | 5 +- .../org/apache/lucene/index/CodecReader.java | 5 +- .../lucene/index/DocValuesLeafReader.java | 3 +- .../apache/lucene/index/FilterLeafReader.java | 5 +- .../org/apache/lucene/index/LeafReader.java | 4 +- .../lucene/index/MergeReaderWrapper.java | 5 +- .../lucene/index/ParallelLeafReader.java | 5 +- .../lucene/index/SlowCodecReaderWrapper.java | 5 +- .../lucene/index/SortingCodecReader.java | 2 +- .../apache/lucene/search/KnnVectorQuery.java | 4 +- .../apache/lucene/util/hnsw/HnswGraph.java | 23 ++++-- .../lucene/util/hnsw/HnswGraphBuilder.java | 3 +- .../lucene/util/hnsw/NeighborQueue.java | 7 -- .../TestPerFieldKnnVectorsFormat.java | 15 ++-- .../org/apache/lucene/index/TestKnnGraph.java | 4 +- .../index/TestSegmentToThreadMapping.java | 2 +- .../lucene/search/TestKnnVectorQuery.java | 74 +++++++++++++++++++ .../lucene/util/hnsw/KnnGraphTester.java | 4 +- .../{TestHnsw.java => TestHnswGraph.java} | 58 +++++++++++++-- .../highlight/TermVectorLeafReader.java | 2 +- .../lucene/index/memory/MemoryIndex.java | 2 +- .../asserting/AssertingKnnVectorsFormat.java | 5 +- .../org/apache/lucene/search/QueryUtils.java | 2 +- 28 files changed, 222 insertions(+), 59 deletions(-) rename lucene/core/src/test/org/apache/lucene/util/hnsw/{TestHnsw.java => TestHnswGraph.java} (89%) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 6460112c7d6..5f970ebc0fd 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -7,9 +7,9 @@ http://s.apache.org/luceneversions New Features -* LUCENE-9322 LUCENE-9855: Vector-valued fields, Lucene90 Codec (Mike Sokolov, Julie Tibshirani, Tomoko Uchida) +* LUCENE-9322, LUCENE-9855: Vector-valued fields, Lucene90 Codec (Mike Sokolov, Julie Tibshirani, Tomoko Uchida) -* LUCENE-9004: Approximate nearest vector search via NSW graphs (Mike Sokolov, Tomoko Uchida et al.) +* LUCENE-9004, LUCENE-10040: Approximate nearest vector search via NSW graphs (Mike Sokolov, Tomoko Uchida et al.) * LUCENE-9659: SpanPayloadCheckQuery now supports inequalities. (Kevin Watters, Gus Heck) diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index dcc85188255..7fdf266fe15 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -37,6 +37,7 @@ import org.apache.lucene.store.BufferedChecksumIndexInput; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRefBuilder; import org.apache.lucene.util.IOUtils; @@ -138,7 +139,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { } @Override - public TopDocs search(String field, float[] target, int k) throws IOException { + public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java index 3d0f2640883..4b58f2dc6c0 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java @@ -23,6 +23,7 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorValues; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocsCollector; +import org.apache.lucene.util.Bits; import org.apache.lucene.util.NamedSPILoader; /** @@ -99,7 +100,7 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI { } @Override - public TopDocs search(String field, float[] target, int k) { + public TopDocs search(String field, float[] target, int k, Bits acceptDocs) { return TopDocsCollector.EMPTY_TOPDOCS; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java index beca0060beb..b692ace84b3 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java @@ -22,6 +22,7 @@ import java.io.IOException; import org.apache.lucene.index.VectorValues; import org.apache.lucene.search.TopDocs; import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.Bits; /** Reads vectors from an index. */ public abstract class KnnVectorsReader implements Closeable, Accountable { @@ -51,9 +52,12 @@ public abstract class KnnVectorsReader implements Closeable, Accountable { * @param field the vector field to search * @param target the vector-valued query * @param k the number of docs to return + * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} + * if they are all allowed to match. * @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores. */ - public abstract TopDocs search(String field, float[] target, int k) throws IOException; + public abstract TopDocs search(String field, float[] target, int k, Bits acceptDocs) + throws IOException; /** * Returns an instance optimized for merging. This instance may only be consumed in the thread diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java index 6a69ab9f4e2..70e386d0b61 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java @@ -43,6 +43,7 @@ import org.apache.lucene.search.TotalHits; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; @@ -232,7 +233,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { } @Override - public TopDocs search(String field, float[] target, int k) throws IOException { + public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); if (fieldEntry == null || fieldEntry.dimension == 0) { return null; @@ -250,6 +251,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { vectorValues, fieldEntry.similarityFunction, getGraphValues(fieldEntry), + getAcceptOrds(acceptDocs, fieldEntry), random); int i = 0; ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)]; @@ -276,6 +278,23 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { return new OffHeapVectorValues(fieldEntry, bytesSlice); } + private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) { + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(fieldEntry.ordToDoc[index]); + } + + @Override + public int length() { + return fieldEntry.ordToDoc.length; + } + }; + } + public KnnGraphValues getGraphValues(String field) throws IOException { FieldInfo info = fieldInfos.fieldInfo(field); if (info == null) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java index 060f0328a1a..0e5cb00d8ce 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java @@ -33,6 +33,7 @@ import org.apache.lucene.index.VectorValues; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; /** @@ -240,12 +241,12 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat { } @Override - public TopDocs search(String field, float[] target, int k) throws IOException { + public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException { KnnVectorsReader knnVectorsReader = fields.get(field); if (knnVectorsReader == null) { return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); } else { - return knnVectorsReader.search(field, target, k); + return knnVectorsReader.search(field, target, k, acceptDocs); } } diff --git a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java index b25087abc55..694205167cf 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java @@ -26,6 +26,7 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.util.Bits; /** LeafReader implemented by codec APIs. */ public abstract class CodecReader extends LeafReader { @@ -211,7 +212,7 @@ public abstract class CodecReader extends LeafReader { } @Override - public final TopDocs searchNearestVectors(String field, float[] target, int k) + public final TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) throws IOException { ensureOpen(); FieldInfo fi = getFieldInfos().fieldInfo(field); @@ -220,7 +221,7 @@ public abstract class CodecReader extends LeafReader { return null; } - return getVectorReader().search(field, target, k); + return getVectorReader().search(field, target, k, acceptDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java index 4f0ace4240e..f618c6cd7b7 100644 --- a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java @@ -53,7 +53,8 @@ abstract class DocValuesLeafReader extends LeafReader { } @Override - public TopDocs searchNearestVectors(String field, float[] target, int k) throws IOException { + public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) + throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java index c08a559c182..cba9b998e09 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java @@ -345,8 +345,9 @@ public abstract class FilterLeafReader extends LeafReader { } @Override - public TopDocs searchNearestVectors(String field, float[] target, int k) throws IOException { - return in.searchNearestVectors(field, target, k); + public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) + throws IOException { + return in.searchNearestVectors(field, target, k, acceptDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java index 95f41764df2..729db64b71e 100644 --- a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java @@ -222,10 +222,12 @@ public abstract class LeafReader extends IndexReader { * @param field the vector field to search * @param target the vector-valued query * @param k the number of docs to return + * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} + * if they are all allowed to match. * @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores. * @lucene.experimental */ - public abstract TopDocs searchNearestVectors(String field, float[] target, int k) + public abstract TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) throws IOException; /** diff --git a/lucene/core/src/java/org/apache/lucene/index/MergeReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/MergeReaderWrapper.java index 3d925de3779..ef4d462aeaf 100644 --- a/lucene/core/src/java/org/apache/lucene/index/MergeReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/MergeReaderWrapper.java @@ -209,8 +209,9 @@ class MergeReaderWrapper extends LeafReader { } @Override - public TopDocs searchNearestVectors(String field, float[] target, int k) throws IOException { - return in.searchNearestVectors(field, target, k); + public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) + throws IOException { + return in.searchNearestVectors(field, target, k, acceptDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java index 6ab727b3fb6..c8d10052008 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java @@ -398,10 +398,11 @@ public class ParallelLeafReader extends LeafReader { } @Override - public TopDocs searchNearestVectors(String fieldName, float[] target, int k) throws IOException { + public TopDocs searchNearestVectors(String fieldName, float[] target, int k, Bits acceptDocs) + throws IOException { ensureOpen(); LeafReader reader = fieldToReader.get(fieldName); - return reader == null ? null : reader.searchNearestVectors(fieldName, target, k); + return reader == null ? null : reader.searchNearestVectors(fieldName, target, k, acceptDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java index de629654fc4..3363dc04c05 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java @@ -167,8 +167,9 @@ public final class SlowCodecReaderWrapper { } @Override - public TopDocs search(String field, float[] target, int k) throws IOException { - return reader.searchNearestVectors(field, target, k); + public TopDocs search(String field, float[] target, int k, Bits acceptDocs) + throws IOException { + return reader.searchNearestVectors(field, target, k, acceptDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index 479df738093..f808c90ea1b 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -315,7 +315,7 @@ public final class SortingCodecReader extends FilterCodecReader { } @Override - public TopDocs search(String field, float[] target, int k) { + public TopDocs search(String field, float[] target, int k, Bits acceptDocs) { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java index 5dccb8042e4..6050920c851 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java @@ -26,6 +26,7 @@ import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.util.Bits; /** Uses {@link KnnVectorsReader#search} to perform nearest Neighbour search. */ public class KnnVectorQuery extends Query { @@ -70,7 +71,8 @@ public class KnnVectorQuery extends Query { } private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException { - TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf); + Bits liveDocs = ctx.reader().getLiveDocs(); + TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf, liveDocs); if (results == null) { return NO_RESULTS; } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java index 49f2c95d1f8..d1f0420d3cf 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java @@ -26,6 +26,7 @@ import java.util.Random; import org.apache.lucene.index.KnnGraphValues; import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.Bits; import org.apache.lucene.util.SparseFixedBitSet; /** @@ -83,6 +84,8 @@ public final class HnswGraph extends KnnGraphValues { * @param vectors vector values * @param graphValues the graph values. May represent the entire graph, or a level in a * hierarchical graph. + * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or + * {@code null} if they are all allowed to match. * @param random a source of randomness, used for generating entry points to the graph * @return a priority queue holding the closest neighbors found */ @@ -93,12 +96,15 @@ public final class HnswGraph extends KnnGraphValues { RandomAccessVectorValues vectors, VectorSimilarityFunction similarityFunction, KnnGraphValues graphValues, + Bits acceptOrds, Random random) throws IOException { int size = graphValues.size(); // MIN heap, holding the top results NeighborQueue results = new NeighborQueue(numSeed, similarityFunction.reversed); + // MAX heap, from which to pull the candidate nodes + NeighborQueue candidates = new NeighborQueue(numSeed, !similarityFunction.reversed); // set of ordinals that have been visited by search on this layer, used to avoid backtracking SparseFixedBitSet visited = new SparseFixedBitSet(size); @@ -109,13 +115,14 @@ public final class HnswGraph extends KnnGraphValues { if (visited.get(entryPoint) == false) { visited.set(entryPoint); // explore the topK starting points of some random numSeed probes - results.add(entryPoint, similarityFunction.compare(query, vectors.vectorValue(entryPoint))); + float score = similarityFunction.compare(query, vectors.vectorValue(entryPoint)); + candidates.add(entryPoint, score); + if (acceptOrds == null || acceptOrds.get(entryPoint)) { + results.add(entryPoint, score); + } } } - // MAX heap, from which to pull the candidate nodes - NeighborQueue candidates = results.copy(!similarityFunction.reversed); - // Set the bound to the worst current result and below reject any newly-generated candidates // failing // to exceed this bound @@ -138,10 +145,14 @@ public final class HnswGraph extends KnnGraphValues { continue; } visited.set(friendOrd); + float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd)); - if (results.insertWithOverflow(friendOrd, score)) { + if (results.size() < numSeed || bound.check(score) == false) { candidates.add(friendOrd, score); - bound.set(results.topScore()); + if (acceptOrds == null || acceptOrds.get(friendOrd)) { + results.insertWithOverflow(friendOrd, score); + bound.set(results.topScore()); + } } } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index c7ff31a637a..d12a731ed98 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -134,9 +134,10 @@ public final class HnswGraphBuilder { /** Inserts a doc with vector value to the graph */ void addGraphNode(float[] value) throws IOException { + // We pass 'null' for acceptOrds because there are no deletions while building the graph NeighborQueue candidates = HnswGraph.search( - value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, random); + value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random); int node = hnsw.addNode(); diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java index bab361bc707..4102dff3c27 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java @@ -42,13 +42,6 @@ public class NeighborQueue { } } - NeighborQueue copy(boolean reversed) { - int size = size(); - NeighborQueue copy = new NeighborQueue(size, reversed); - copy.heap.pushAll(heap); - return copy; - } - /** @return the number of elements in the heap */ public int size() { return heap.size(); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java index b170b0d6954..bc0eae5c6e3 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java @@ -38,6 +38,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.RandomCodec; import org.apache.lucene.index.SegmentReadState; @@ -101,19 +102,13 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase { // Double-check the vectors were written try (IndexReader ireader = DirectoryReader.open(directory)) { + LeafReader reader = ireader.leaves().get(0).reader(); TopDocs hits1 = - ireader - .leaves() - .get(0) - .reader() - .searchNearestVectors("field1", new float[] {1, 2, 3}, 10); + reader.searchNearestVectors("field1", new float[] {1, 2, 3}, 10, reader.getLiveDocs()); assertEquals(1, hits1.scoreDocs.length); + TopDocs hits2 = - ireader - .leaves() - .get(0) - .reader() - .searchNearestVectors("field2", new float[] {1, 2, 3}, 10); + reader.searchNearestVectors("field2", new float[] {1, 2, 3}, 10, reader.getLiveDocs()); assertEquals(1, hits2.scoreDocs.length); } } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index ba43b7996f0..b035a2ff272 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -42,6 +42,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.VectorUtil; @@ -291,7 +292,8 @@ public class TestKnnGraph extends LuceneTestCase { private static TopDocs doKnnSearch(IndexReader reader, float[] vector, int k) throws IOException { TopDocs[] results = new TopDocs[reader.leaves().size()]; for (LeafReaderContext ctx : reader.leaves()) { - results[ctx.ord] = ctx.reader().searchNearestVectors(KNN_GRAPH_FIELD, vector, k); + Bits liveDocs = ctx.reader().getLiveDocs(); + results[ctx.ord] = ctx.reader().searchNearestVectors(KNN_GRAPH_FIELD, vector, k, liveDocs); if (ctx.docBase > 0) { for (ScoreDoc doc : results[ctx.ord].scoreDocs) { doc.doc += ctx.docBase; diff --git a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java index 74888fd6d01..c76968f87da 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java @@ -112,7 +112,7 @@ public class TestSegmentToThreadMapping extends LuceneTestCase { } @Override - public TopDocs searchNearestVectors(String field, float[] target, int k) { + public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) { return null; } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java index 862f8f7c114..6443c860fd0 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java @@ -16,10 +16,13 @@ */ package org.apache.lucene.search; +import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.util.TestVectorUtil.randomVector; import java.io.IOException; +import java.util.HashSet; +import java.util.Set; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.KnnVectorField; @@ -303,6 +306,77 @@ public class TestKnnVectorQuery extends LuceneTestCase { } } + public void testDeletes() throws IOException { + try (Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + final int numDocs = atLeast(100); + final int dim = 30; + int docIndex = 0; + for (int i = 0; i < numDocs; ++i) { + Document d = new Document(); + if (frequently()) { + d.add(new StringField("index", String.valueOf(docIndex), Field.Store.YES)); + d.add(new KnnVectorField("vector", randomVector(dim))); + docIndex++; + } else { + d.add(new StringField("other", "value" + (i % 5), Field.Store.NO)); + } + w.addDocument(d); + } + w.commit(); + + // Delete some documents at random, both those with and without vectors + Set toDelete = new HashSet<>(); + for (int i = 0; i < 20; i++) { + int index = random().nextInt(docIndex); + toDelete.add(new Term("index", String.valueOf(index))); + } + w.deleteDocuments(toDelete.toArray(new Term[0])); + w.deleteDocuments(new Term("other", "value" + random().nextInt(5))); + w.commit(); + + try (IndexReader reader = DirectoryReader.open(dir)) { + Set allIds = new HashSet<>(); + IndexSearcher searcher = new IndexSearcher(reader); + KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(dim), numDocs); + TopDocs topDocs = searcher.search(query, numDocs); + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + Document doc = reader.document(scoreDoc.doc, Set.of("index")); + String index = doc.get("index"); + assertFalse( + "search returned a deleted document: " + index, + toDelete.contains(new Term("index", index))); + allIds.add(index); + } + assertEquals("search missed some documents", docIndex - toDelete.size(), allIds.size()); + } + } + } + + public void testAllDeletes() throws IOException { + try (Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + final int numDocs = atLeast(100); + final int dim = 30; + for (int i = 0; i < numDocs; ++i) { + Document d = new Document(); + d.add(new KnnVectorField("vector", randomVector(dim))); + w.addDocument(d); + } + w.commit(); + + w.deleteDocuments(new MatchAllDocsQuery()); + w.commit(); + + try (IndexReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = new IndexSearcher(reader); + KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(dim), numDocs); + TopDocs topDocs = searcher.search(query, numDocs); + assertEquals(0, topDocs.scoreDocs.length); + } + } + } + private Directory getIndexStore(String field, float[]... contents) throws IOException { Directory indexStore = newDirectory(); RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore); diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java index caf94feb870..fcdf0aa8dff 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java @@ -58,6 +58,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IntroSorter; import org.apache.lucene.util.PrintStreamInfoStream; @@ -424,7 +425,8 @@ public class KnnGraphTester { IndexReader reader, String field, float[] vector, int k, int fanout) throws IOException { TopDocs[] results = new TopDocs[reader.leaves().size()]; for (LeafReaderContext ctx : reader.leaves()) { - results[ctx.ord] = ctx.reader().searchNearestVectors(field, vector, k + fanout); + Bits liveDocs = ctx.reader().getLiveDocs(); + results[ctx.ord] = ctx.reader().searchNearestVectors(field, vector, k + fanout, liveDocs); int docBase = ctx.docBase; for (ScoreDoc scoreDoc : results[ctx.ord].scoreDocs) { scoreDoc.doc += docBase; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java similarity index 89% rename from lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java rename to lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java index 676cae8c84c..bec0541a76f 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java @@ -45,12 +45,14 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorValues; import org.apache.lucene.store.Directory; import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.VectorUtil; /** Tests HNSW KNN graphs */ -public class TestHnsw extends LuceneTestCase { +public class TestHnswGraph extends LuceneTestCase { // test writing out and reading in a graph gives the expected graph public void testReadWrite() throws IOException { @@ -138,6 +140,7 @@ public class TestHnsw extends LuceneTestCase { vectors.randomAccess(), VectorSimilarityFunction.DOT_PRODUCT, hnsw, + null, random()); int sum = 0; for (int node : nn.nodes()) { @@ -156,6 +159,35 @@ public class TestHnsw extends LuceneTestCase { } } + public void testSearchWithAcceptOrds() throws IOException { + int nDoc = 100; + CircularVectorValues vectors = new CircularVectorValues(nDoc); + HnswGraphBuilder builder = + new HnswGraphBuilder( + vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt()); + HnswGraph hnsw = builder.build(vectors); + + Bits acceptOrds = createRandomAcceptOrds(vectors.size); + NeighborQueue nn = + HnswGraph.search( + new float[] {1, 0}, + 10, + 5, + vectors.randomAccess(), + VectorSimilarityFunction.DOT_PRODUCT, + hnsw, + acceptOrds, + random()); + int sum = 0; + for (int node : nn.nodes()) { + assertTrue("the results include a deleted document: " + node, acceptOrds.get(node)); + sum += node; + } + // We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) = + // 45 + assertTrue("sum(result docs)=" + sum, sum < 75); + } + public void testBoundsCheckerMax() { BoundsChecker max = BoundsChecker.create(false); float f = random().nextFloat() - 0.5f; @@ -279,16 +311,21 @@ public class TestHnsw extends LuceneTestCase { HnswGraphBuilder builder = new HnswGraphBuilder(vectors, similarityFunction, 10, 30, random().nextLong()); HnswGraph hnsw = builder.build(vectors); + Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(size); + int totalMatches = 0; for (int i = 0; i < 100; i++) { float[] query = randomVector(random(), dim); NeighborQueue actual = - HnswGraph.search(query, topK, 100, vectors, similarityFunction, hnsw, random()); + HnswGraph.search( + query, topK, 100, vectors, similarityFunction, hnsw, acceptOrds, random()); NeighborQueue expected = new NeighborQueue(topK, similarityFunction.reversed); for (int j = 0; j < size; j++) { - float[] v = vectors.vectorValue(j); - if (v != null) { - expected.insertWithOverflow(j, similarityFunction.compare(query, vectors.vectorValue(j))); + if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) { + expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j))); + if (expected.size() > topK) { + expected.pop(); + } } } assertEquals(topK, actual.size()); @@ -455,6 +492,17 @@ public class TestHnsw extends LuceneTestCase { } } + /** Generate a random bitset where each entry has a 2/3 probability of being set. */ + private static Bits createRandomAcceptOrds(int length) { + FixedBitSet bits = new FixedBitSet(length); + for (int i = 0; i < bits.length(); i++) { + if (random().nextFloat() < 0.667f) { + bits.set(i); + } + } + return bits; + } + private static float[] randomVector(Random random, int dim) { float[] vec = new float[dim]; for (int i = 0; i < dim; i++) { diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java index a307d679fcd..8a3e992ea14 100644 --- a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java +++ b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java @@ -162,7 +162,7 @@ public class TermVectorLeafReader extends LeafReader { } @Override - public TopDocs searchNearestVectors(String field, float[] target, int k) { + public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) { return null; } diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java index b22d4851a90..5b06e3f21d2 100644 --- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java +++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java @@ -1373,7 +1373,7 @@ public class MemoryIndex { } @Override - public TopDocs searchNearestVectors(String field, float[] target, int k) { + public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) { return null; } diff --git a/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingKnnVectorsFormat.java index 135a248f8e3..180c6dfcf93 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingKnnVectorsFormat.java +++ b/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingKnnVectorsFormat.java @@ -26,6 +26,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorValues; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.util.Bits; import org.apache.lucene.util.TestUtil; /** Wraps the default KnnVectorsFormat and provides additional assertions. */ @@ -98,8 +99,8 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat { } @Override - public TopDocs search(String field, float[] target, int k) throws IOException { - TopDocs hits = delegate.search(field, target, k); + public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException { + TopDocs hits = delegate.search(field, target, k, acceptDocs); assert hits != null; assert hits.scoreDocs.length <= k; return hits; diff --git a/lucene/test-framework/src/java/org/apache/lucene/search/QueryUtils.java b/lucene/test-framework/src/java/org/apache/lucene/search/QueryUtils.java index 06fa4bc1cc2..48ed7b261ae 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/search/QueryUtils.java +++ b/lucene/test-framework/src/java/org/apache/lucene/search/QueryUtils.java @@ -216,7 +216,7 @@ public class QueryUtils { } @Override - public TopDocs searchNearestVectors(String field, float[] target, int k) { + public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) { return null; }