diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java index 97fdd1c3999..45dbc0cd79e 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java @@ -144,7 +144,15 @@ public final class Lucene90HnswGraphBuilder { // We pass 'null' for acceptOrds because there are no deletions while building the graph NeighborQueue candidates = Lucene90OnHeapHnswGraph.search( - value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random); + value, + beamWidth, + beamWidth, + vectorValues, + similarityFunction, + hnsw, + null, + Integer.MAX_VALUE, + random); int node = hnsw.addNode(); diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index 2d2e2d57a1a..6d061534860 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -252,6 +252,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { fieldEntry.similarityFunction, getGraphValues(fieldEntry), getAcceptOrds(acceptDocs, fieldEntry), + visitedLimit, random); int i = 0; ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)]; @@ -261,11 +262,11 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { results.pop(); scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score); } - // always return >= the case where we can assert == is only when there are fewer than topK - // vectors in the index - return new TopDocs( - new TotalHits(results.visitedCount(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), - scoreDocs); + TotalHits.Relation relation = + results.incomplete() + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO; + return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs); } private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException { diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java index 340bcf24199..9de59301abb 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java @@ -80,6 +80,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph { VectorSimilarityFunction similarityFunction, HnswGraph graphValues, Bits acceptOrds, + int visitedLimit, SplittableRandom random) throws IOException { int size = graphValues.size(); @@ -89,6 +90,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph { // MAX heap, from which to pull the candidate nodes NeighborQueue candidates = new NeighborQueue(numSeed, !similarityFunction.reversed); + int numVisited = 0; // set of ordinals that have been visited by search on this layer, used to avoid backtracking SparseFixedBitSet visited = new SparseFixedBitSet(size); // get initial candidates at random @@ -96,12 +98,17 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph { for (int i = 0; i < boundedNumSeed; i++) { int entryPoint = random.nextInt(size); if (visited.getAndSet(entryPoint) == false) { + if (numVisited >= visitedLimit) { + results.markIncomplete(); + break; + } // explore the topK starting points of some random numSeed probes float score = similarityFunction.compare(query, vectors.vectorValue(entryPoint)); candidates.add(entryPoint, score); if (acceptOrds == null || acceptOrds.get(entryPoint)) { results.add(entryPoint, score); } + numVisited++; } } @@ -110,7 +117,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph { // to exceed this bound BoundsChecker bound = BoundsChecker.create(similarityFunction.reversed); bound.set(results.topScore()); - while (candidates.size() > 0) { + while (candidates.size() > 0 && results.incomplete() == false) { // get the best candidate (closest or best scoring) float topCandidateScore = candidates.topScore(); if (results.size() >= topK) { @@ -127,6 +134,11 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph { continue; } + if (numVisited >= visitedLimit) { + results.markIncomplete(); + break; + } + float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd)); if (results.size() < numSeed || bound.check(score) == false) { candidates.add(friendOrd, score); @@ -135,12 +147,13 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph { bound.set(results.topScore()); } } + numVisited++; } } while (results.size() > topK) { results.pop(); } - results.setVisitedCount(visited.approximateCardinality()); + results.setVisitedCount(numVisited); return results; } diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index 617a67f49c1..081503a51f8 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -154,20 +154,31 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { FieldInfo info = readState.fieldInfos.fieldInfo(field); VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction(); HitQueue topK = new HitQueue(k, false); + + int numVisited = 0; + TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO; + int doc; while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { if (acceptDocs != null && acceptDocs.get(doc) == false) { continue; } + + if (numVisited >= visitedLimit) { + relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + break; + } + float[] vector = values.vectorValue(); float score = vectorSimilarity.convertToScore(vectorSimilarity.compare(vector, target)); topK.insertWithOverflow(new ScoreDoc(doc, score)); + numVisited++; } ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()]; for (int i = topScoreDocs.length - 1; i >= 0; i--) { topScoreDocs[i] = topK.pop(); } - return new TopDocs(new TotalHits(values.size(), TotalHits.Relation.EQUAL_TO), topScoreDocs); + return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 69a2b74e6f5..e02fbcd9bda 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -140,12 +140,16 @@ public final class HnswGraphSearcher { int numVisited = 0; for (int ep : eps) { if (visited.getAndSet(ep) == false) { + if (numVisited >= visitedLimit) { + results.markIncomplete(); + break; + } float score = similarityFunction.compare(query, vectors.vectorValue(ep)); + numVisited++; candidates.add(ep, score); if (acceptOrds == null || acceptOrds.get(ep)) { results.add(ep, score); } - numVisited++; } } @@ -155,18 +159,13 @@ public final class HnswGraphSearcher { if (results.size() >= topK) { bound.set(results.topScore()); } - while (candidates.size() > 0) { + while (candidates.size() > 0 && results.incomplete() == false) { // get the best candidate (closest or best scoring) float topCandidateScore = candidates.topScore(); if (bound.check(topCandidateScore)) { break; } - if (numVisited >= visitedLimit) { - results.markIncomplete(); - break; - } - int topCandidateNode = candidates.pop(); graph.seek(level, topCandidateNode); int friendOrd; @@ -176,6 +175,10 @@ public final class HnswGraphSearcher { continue; } + if (numVisited >= visitedLimit) { + results.markIncomplete(); + break; + } float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd)); numVisited++; if (bound.check(score) == false) { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java index c0d327aa6ea..7c3aaa7f32b 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java @@ -44,6 +44,7 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.Bits; import org.apache.lucene.util.VectorUtil; @@ -491,7 +492,11 @@ public class TestKnnVectorQuery extends LuceneTestCase { int dimension = atLeast(5); int numIters = atLeast(10); try (Directory d = newDirectory()) { - RandomIndexWriter w = new RandomIndexWriter(random(), d); + // Always use the default kNN format to have predictable behavior around when it hits + // visitedLimit. This is fine since the test targets KnnVectorQuery logic, not the kNN format + // implementation. + IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()); + RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); for (int i = 0; i < numDocs; i++) { Document doc = new Document(); doc.add(new KnnVectorField("field", randomVector(dimension))); @@ -690,8 +695,7 @@ public class TestKnnVectorQuery extends LuceneTestCase { } @Override - protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) - throws IOException { + protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) { throw new UnsupportedOperationException("exact search is not supported"); } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java index 58b636f1ca6..a151d80b9c1 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java @@ -309,8 +309,8 @@ public class TestHnswGraph extends LuceneTestCase { createRandomAcceptOrds(0, vectors.size), visitedLimit); assertTrue(nn.incomplete()); - // The visited count shouldn't be much over the limit - assertTrue(nn.visitedCount() < visitedLimit + 3); + // The visited count shouldn't exceed the limit + assertTrue(nn.visitedCount() <= visitedLimit); } public void testBoundsCheckerMax() { diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 052679b3e78..95da6bb0554 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -24,6 +24,7 @@ import java.util.Arrays; import java.util.Set; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; @@ -45,6 +46,7 @@ import org.apache.lucene.index.VectorValues; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.tests.util.TestUtil; @@ -821,6 +823,68 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe } } + /** + * Tests whether {@link KnnVectorsReader#search} implementations obey the limit on the number of + * visited vectors. This test is a best-effort attempt to capture the right behavior, and isn't + * meant to define a strict requirement on behavior. + */ + public void testSearchWithVisitedLimit() throws Exception { + IndexWriterConfig iwc = newIndexWriterConfig(); + String fieldName = "field"; + try (Directory dir = newDirectory(); + IndexWriter iw = new IndexWriter(dir, iwc)) { + int numDoc = atLeast(300); + int dimension = atLeast(10); + for (int i = 0; i < numDoc; i++) { + float[] value; + if (random().nextInt(7) != 3) { + // usually index a vector value for a doc + value = randomVector(dimension); + } else { + value = null; + } + add(iw, fieldName, i, value, VectorSimilarityFunction.EUCLIDEAN); + } + iw.forceMerge(1); + + // randomly delete some documents + for (int i = 0; i < 30; i++) { + int idToDelete = random().nextInt(numDoc); + iw.deleteDocuments(new Term("id", Integer.toString(idToDelete))); + } + + try (IndexReader reader = DirectoryReader.open(iw)) { + for (LeafReaderContext ctx : reader.leaves()) { + Bits liveDocs = ctx.reader().getLiveDocs(); + VectorValues vectorValues = ctx.reader().getVectorValues(fieldName); + if (vectorValues == null) { + continue; + } + + // check the limit is hit when it's very small + int k = 5 + random().nextInt(45); + int visitedLimit = k + random().nextInt(5); + TopDocs results = + ctx.reader() + .searchNearestVectors( + fieldName, randomVector(dimension), k, liveDocs, visitedLimit); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation); + assertEquals(visitedLimit, results.totalHits.value); + + // check the limit is not hit when it clearly exceeds the number of vectors + k = vectorValues.size(); + visitedLimit = k + 30; + results = + ctx.reader() + .searchNearestVectors( + fieldName, randomVector(dimension), k, liveDocs, visitedLimit); + assertEquals(TotalHits.Relation.EQUAL_TO, results.totalHits.relation); + assertTrue(results.totalHits.value <= visitedLimit); + } + } + } + } + /** * Index random vectors, sometimes skipping documents, sometimes updating a document, sometimes * merging, sometimes sorting the index, using an HNSW similarity function so as to also produce a