Prevent extra similarity computation for single-level graphs (#12866)

### Description

[`#findBestEntryPoint`](4bc7850465/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java (L151)) is used to determine the entry point for the last level of HNSW search

It finds the single best-scoring node from [all upper levels](4bc7850465/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java (L159)) - but performs an [unnecessary computation](4bc7850465/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java (L157)) (along with [recording one visited node](4bc7850465/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java (L154))) when the graph just has 1 level (so the entry node is just the overall graph's entry node)

Also added a test to demonstrate this (fails without the changes in PR) -- where we visit `graph.size() + 1` nodes when the `topK` is high (should be a maximum of `graph.size()`)

---------

Co-authored-by: Kaival Parikh <kaivalp2000@gmail.com>
This commit is contained in:
Kaival Parikh 2023-12-01 23:32:00 +05:30 committed by GitHub
parent 0e96b9cd8c
commit 65d30ca1af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 122 additions and 24 deletions

View File

@ -140,6 +140,33 @@ Other
* GITHUB#9049: Fixing bug in UnescapedCharSequence#toStringEscaped() (Jakub Slowinski) * GITHUB#9049: Fixing bug in UnescapedCharSequence#toStringEscaped() (Jakub Slowinski)
======================== Lucene 9.10.0 =======================
API Changes
---------------------
(No changes)
New Features
---------------------
(No changes)
Improvements
---------------------
(No changes)
Optimizations
---------------------
(No changes)
Bug Fixes
---------------------
* GITHUB#12866: Prevent extra similarity computation for single-level HNSW graphs. (Kaival Parikh)
Other
---------------------
(No changes)
======================== Lucene 9.9.0 ======================= ======================== Lucene 9.9.0 =======================
API Changes API Changes

View File

@ -112,10 +112,17 @@ public final class Version {
/** /**
* Match settings and bugs in Lucene's 9.9.0 release. * Match settings and bugs in Lucene's 9.9.0 release.
* *
* @deprecated Use latest * @deprecated (9.10.0) Use latest
*/ */
@Deprecated public static final Version LUCENE_9_9_0 = new Version(9, 9, 0); @Deprecated public static final Version LUCENE_9_9_0 = new Version(9, 9, 0);
/**
* Match settings and bugs in Lucene's 9.10.0 release.
*
* @deprecated Use latest
*/
@Deprecated public static final Version LUCENE_9_10_0 = new Version(9, 10, 0);
/** /**
* Match settings and bugs in Lucene's 10.0.0 release. * Match settings and bugs in Lucene's 10.0.0 release.
* *

View File

@ -100,20 +100,11 @@ public class HnswGraphSearcher {
HnswGraphSearcher graphSearcher, HnswGraphSearcher graphSearcher,
Bits acceptOrds) Bits acceptOrds)
throws IOException { throws IOException {
int initialEp = graph.entryNode(); int ep = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector);
if (initialEp == -1) { if (ep != -1) {
return;
}
int[] epAndVisited = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector.visitLimit());
int numVisited = epAndVisited[1];
int ep = epAndVisited[0];
if (ep == -1) {
knnCollector.incVisitedCount(numVisited);
return;
}
knnCollector.incVisitedCount(numVisited);
graphSearcher.searchLevel(knnCollector, scorer, 0, new int[] {ep}, graph, acceptOrds); graphSearcher.searchLevel(knnCollector, scorer, 0, new int[] {ep}, graph, acceptOrds);
} }
}
/** /**
* Searches for the nearest neighbors of a query vector in a given level. * Searches for the nearest neighbors of a query vector in a given level.
@ -143,18 +134,21 @@ public class HnswGraphSearcher {
* *
* @param scorer the scorer to compare the query with the nodes * @param scorer the scorer to compare the query with the nodes
* @param graph the HNSWGraph * @param graph the HNSWGraph
* @param visitLimit How many vectors are allowed to be visited * @param collector the knn result collector
* @return An integer array whose first element is the best entry point, and second is the number * @return the best entry point, `-1` indicates graph entry node not set, or visitation limit
* of candidates visited. Entry point of `-1` indicates visitation limit exceed * exceeded
* @throws IOException When accessing the vector fails * @throws IOException When accessing the vector fails
*/ */
private int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, long visitLimit) private int findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, KnnCollector collector)
throws IOException { throws IOException {
int size = getGraphSize(graph);
int visitedCount = 1;
prepareScratchState(size);
int currentEp = graph.entryNode(); int currentEp = graph.entryNode();
if (currentEp == -1 || graph.numLevels() == 1) {
return currentEp;
}
int size = getGraphSize(graph);
prepareScratchState(size);
float currentScore = scorer.score(currentEp); float currentScore = scorer.score(currentEp);
collector.incVisitedCount(1);
boolean foundBetter; boolean foundBetter;
for (int level = graph.numLevels() - 1; level >= 1; level--) { for (int level = graph.numLevels() - 1; level >= 1; level--) {
foundBetter = true; foundBetter = true;
@ -169,11 +163,11 @@ public class HnswGraphSearcher {
if (visited.getAndSet(friendOrd)) { if (visited.getAndSet(friendOrd)) {
continue; continue;
} }
if (visitedCount >= visitLimit) { if (collector.earlyTerminated()) {
return new int[] {-1, visitedCount}; return -1;
} }
float friendSimilarity = scorer.score(friendOrd); float friendSimilarity = scorer.score(friendOrd);
visitedCount++; collector.incVisitedCount(1);
if (friendSimilarity > currentScore) { if (friendSimilarity > currentScore) {
currentScore = friendSimilarity; currentScore = friendSimilarity;
currentEp = friendOrd; currentEp = friendOrd;
@ -182,7 +176,7 @@ public class HnswGraphSearcher {
} }
} }
} }
return new int[] {currentEp, visitedCount}; return collector.earlyTerminated() ? -1 : currentEp;
} }
/** /**

View File

@ -70,6 +70,7 @@ import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField; import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.tests.util.TestUtil;
@ -1026,6 +1027,37 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
} }
} }
public void testAllNodesVisitedInSingleLevel() throws IOException {
int size = atLeast(100);
int dim = atLeast(50);
// Search for a large number of results
int topK = size - 1;
AbstractMockVectorValues<T> docVectors = vectorValues(size, dim);
HnswGraph graph =
HnswGraphBuilder.create(buildScorerSupplier(docVectors), 10, 30, random().nextLong())
.build(size);
HnswGraph singleLevelGraph =
new DelegateHnswGraph(graph) {
@Override
public int numLevels() {
// Only retain the last level
return 1;
}
};
AbstractMockVectorValues<T> queryVectors = vectorValues(1, dim);
RandomVectorScorer queryScorer = buildScorer(docVectors, queryVectors.vectorValue(0));
KnnCollector collector = new TopKnnCollector(topK, Integer.MAX_VALUE);
HnswGraphSearcher.search(queryScorer, collector, singleLevelGraph, null);
// Check that we visit all nodes
assertEquals(graph.size(), collector.visitedCount());
}
private int computeOverlap(int[] a, int[] b) { private int computeOverlap(int[] a, int[] b) {
Arrays.sort(a); Arrays.sort(a);
Arrays.sort(b); Arrays.sort(b);
@ -1297,4 +1329,42 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
return sb.toString(); return sb.toString();
} }
private static class DelegateHnswGraph extends HnswGraph {
final HnswGraph delegate;
DelegateHnswGraph(HnswGraph delegate) {
this.delegate = delegate;
}
@Override
public void seek(int level, int target) throws IOException {
delegate.seek(level, target);
}
@Override
public int size() {
return delegate.size();
}
@Override
public int nextNeighbor() throws IOException {
return delegate.nextNeighbor();
}
@Override
public int numLevels() throws IOException {
return delegate.numLevels();
}
@Override
public int entryNode() throws IOException {
return delegate.entryNode();
}
@Override
public NodesIterator getNodesOnLevel(int level) throws IOException {
return delegate.getNodesOnLevel(level);
}
}
} }