diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index eb3a748384e..4d24255ab07 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -140,6 +140,33 @@ Other * GITHUB#9049: Fixing bug in UnescapedCharSequence#toStringEscaped() (Jakub Slowinski) +======================== Lucene 9.10.0 ======================= + +API Changes +--------------------- +(No changes) + +New Features +--------------------- +(No changes) + +Improvements +--------------------- +(No changes) + +Optimizations +--------------------- +(No changes) + +Bug Fixes +--------------------- + +* GITHUB#12866: Prevent extra similarity computation for single-level HNSW graphs. (Kaival Parikh) + +Other +--------------------- +(No changes) + ======================== Lucene 9.9.0 ======================= API Changes diff --git a/lucene/core/src/java/org/apache/lucene/util/Version.java b/lucene/core/src/java/org/apache/lucene/util/Version.java index f7c0ad42013..8c124888eb7 100644 --- a/lucene/core/src/java/org/apache/lucene/util/Version.java +++ b/lucene/core/src/java/org/apache/lucene/util/Version.java @@ -112,10 +112,17 @@ public final class Version { /** * Match settings and bugs in Lucene's 9.9.0 release. * - * @deprecated Use latest + * @deprecated (9.10.0) Use latest */ @Deprecated public static final Version LUCENE_9_9_0 = new Version(9, 9, 0); + /** + * Match settings and bugs in Lucene's 9.10.0 release. + * + * @deprecated Use latest + */ + @Deprecated public static final Version LUCENE_9_10_0 = new Version(9, 10, 0); + /** * Match settings and bugs in Lucene's 10.0.0 release. 
* diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 0135fc5a411..2aa5389ff58 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -100,19 +100,10 @@ public class HnswGraphSearcher { HnswGraphSearcher graphSearcher, Bits acceptOrds) throws IOException { - int initialEp = graph.entryNode(); - if (initialEp == -1) { - return; + int ep = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector); + if (ep != -1) { + graphSearcher.searchLevel(knnCollector, scorer, 0, new int[] {ep}, graph, acceptOrds); } - int[] epAndVisited = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector.visitLimit()); - int numVisited = epAndVisited[1]; - int ep = epAndVisited[0]; - if (ep == -1) { - knnCollector.incVisitedCount(numVisited); - return; - } - knnCollector.incVisitedCount(numVisited); - graphSearcher.searchLevel(knnCollector, scorer, 0, new int[] {ep}, graph, acceptOrds); } /** @@ -143,18 +134,21 @@ public class HnswGraphSearcher { * * @param scorer the scorer to compare the query with the nodes * @param graph the HNSWGraph - * @param visitLimit How many vectors are allowed to be visited - * @return An integer array whose first element is the best entry point, and second is the number - * of candidates visited. 
Entry point of `-1` indicates visitation limit exceed + * @param collector the knn result collector + * @return the best entry point, or {@code -1} if the graph entry node is not set or the visit + * limit was exceeded * @throws IOException When accessing the vector fails */ - private int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, long visitLimit) + private int findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, KnnCollector collector) throws IOException { - int size = getGraphSize(graph); - int visitedCount = 1; - prepareScratchState(size); int currentEp = graph.entryNode(); + if (currentEp == -1 || graph.numLevels() == 1) { + return currentEp; + } + int size = getGraphSize(graph); + prepareScratchState(size); float currentScore = scorer.score(currentEp); + collector.incVisitedCount(1); boolean foundBetter; for (int level = graph.numLevels() - 1; level >= 1; level--) { foundBetter = true; @@ -169,11 +163,11 @@ public class HnswGraphSearcher { if (visited.getAndSet(friendOrd)) { continue; } - if (visitedCount >= visitLimit) { - return new int[] {-1, visitedCount}; + if (collector.earlyTerminated()) { + return -1; } float friendSimilarity = scorer.score(friendOrd); - visitedCount++; + collector.incVisitedCount(1); if (friendSimilarity > currentScore) { currentScore = friendSimilarity; currentEp = friendOrd; @@ -182,7 +176,7 @@ public class HnswGraphSearcher { } } } - return new int[] {currentEp, visitedCount}; + return collector.earlyTerminated() ? 
-1 : currentEp; } /** diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index b943d3a2d50..0cc100712cf 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -70,6 +70,7 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; @@ -1026,6 +1027,37 @@ abstract class HnswGraphTestCase extends LuceneTestCase { } } + public void testAllNodesVisitedInSingleLevel() throws IOException { + int size = atLeast(100); + int dim = atLeast(50); + + // Search for a large number of results + int topK = size - 1; + + AbstractMockVectorValues docVectors = vectorValues(size, dim); + HnswGraph graph = + HnswGraphBuilder.create(buildScorerSupplier(docVectors), 10, 30, random().nextLong()) + .build(size); + + HnswGraph singleLevelGraph = + new DelegateHnswGraph(graph) { + @Override + public int numLevels() { + // Report a single level so the search only sees the bottom layer (level 0) + return 1; + } + }; + + AbstractMockVectorValues queryVectors = vectorValues(1, dim); + RandomVectorScorer queryScorer = buildScorer(docVectors, queryVectors.vectorValue(0)); + + KnnCollector collector = new TopKnnCollector(topK, Integer.MAX_VALUE); + HnswGraphSearcher.search(queryScorer, collector, singleLevelGraph, null); + + // Check that we visit all nodes + assertEquals(graph.size(), collector.visitedCount()); + } + private int computeOverlap(int[] a, int[] b) { Arrays.sort(a); Arrays.sort(b); @@ -1297,4 +1329,42 @@ abstract class HnswGraphTestCase extends LuceneTestCase { return sb.toString(); } + + private static class 
DelegateHnswGraph extends HnswGraph { + final HnswGraph delegate; + + DelegateHnswGraph(HnswGraph delegate) { + this.delegate = delegate; + } + + @Override + public void seek(int level, int target) throws IOException { + delegate.seek(level, target); + } + + @Override + public int size() { + return delegate.size(); + } + + @Override + public int nextNeighbor() throws IOException { + return delegate.nextNeighbor(); + } + + @Override + public int numLevels() throws IOException { + return delegate.numLevels(); + } + + @Override + public int entryNode() throws IOException { + return delegate.entryNode(); + } + + @Override + public NodesIterator getNodesOnLevel(int level) throws IOException { + return delegate.getNodesOnLevel(level); + } + } }