mirror of https://github.com/apache/lucene.git
Prevent extra similarity computation for single-level graphs (#12866)
### Description

[`#findBestEntryPoint`](4bc7850465/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java#L151) is used to determine the entry point for the last level of HNSW search. It finds the single best-scoring node from [all upper levels](4bc7850465/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java#L159), but performs an [unnecessary computation](4bc7850465/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java#L157) (along with [recording one visited node](4bc7850465/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java#L154)) when the graph has just 1 level (in which case the entry node is simply the overall graph's entry node).

Also added a test to demonstrate this (it fails without the changes in this PR): we visit `graph.size() + 1` nodes when `topK` is high, whereas the maximum should be `graph.size()`.

---------

Co-authored-by: Kaival Parikh <kaivalp2000@gmail.com>
This commit is contained in:
parent
0e96b9cd8c
commit
65d30ca1af
|
@ -140,6 +140,33 @@ Other
|
|||
|
||||
* GITHUB#9049: Fixing bug in UnescapedCharSequence#toStringEscaped() (Jakub Slowinski)
|
||||
|
||||
======================== Lucene 9.10.0 =======================
|
||||
|
||||
API Changes
|
||||
---------------------
|
||||
(No changes)
|
||||
|
||||
New Features
|
||||
---------------------
|
||||
(No changes)
|
||||
|
||||
Improvements
|
||||
---------------------
|
||||
(No changes)
|
||||
|
||||
Optimizations
|
||||
---------------------
|
||||
(No changes)
|
||||
|
||||
Bug Fixes
|
||||
---------------------
|
||||
|
||||
* GITHUB#12866: Prevent extra similarity computation for single-level HNSW graphs. (Kaival Parikh)
|
||||
|
||||
Other
|
||||
---------------------
|
||||
(No changes)
|
||||
|
||||
======================== Lucene 9.9.0 =======================
|
||||
|
||||
API Changes
|
||||
|
|
|
@ -112,10 +112,17 @@ public final class Version {
|
|||
/**
|
||||
* Match settings and bugs in Lucene's 9.9.0 release.
|
||||
*
|
||||
* @deprecated Use latest
|
||||
* @deprecated (9.10.0) Use latest
|
||||
*/
|
||||
@Deprecated public static final Version LUCENE_9_9_0 = new Version(9, 9, 0);
|
||||
|
||||
/**
|
||||
* Match settings and bugs in Lucene's 9.10.0 release.
|
||||
*
|
||||
* @deprecated Use latest
|
||||
*/
|
||||
@Deprecated public static final Version LUCENE_9_10_0 = new Version(9, 10, 0);
|
||||
|
||||
/**
|
||||
* Match settings and bugs in Lucene's 10.0.0 release.
|
||||
*
|
||||
|
|
|
@ -100,20 +100,11 @@ public class HnswGraphSearcher {
|
|||
HnswGraphSearcher graphSearcher,
|
||||
Bits acceptOrds)
|
||||
throws IOException {
|
||||
int initialEp = graph.entryNode();
|
||||
if (initialEp == -1) {
|
||||
return;
|
||||
}
|
||||
int[] epAndVisited = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector.visitLimit());
|
||||
int numVisited = epAndVisited[1];
|
||||
int ep = epAndVisited[0];
|
||||
if (ep == -1) {
|
||||
knnCollector.incVisitedCount(numVisited);
|
||||
return;
|
||||
}
|
||||
knnCollector.incVisitedCount(numVisited);
|
||||
int ep = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector);
|
||||
if (ep != -1) {
|
||||
graphSearcher.searchLevel(knnCollector, scorer, 0, new int[] {ep}, graph, acceptOrds);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Searches for the nearest neighbors of a query vector in a given level.
|
||||
|
@ -143,18 +134,21 @@ public class HnswGraphSearcher {
|
|||
*
|
||||
* @param scorer the scorer to compare the query with the nodes
|
||||
* @param graph the HNSWGraph
|
||||
* @param visitLimit How many vectors are allowed to be visited
|
||||
* @return An integer array whose first element is the best entry point, and second is the number
|
||||
* of candidates visited. Entry point of `-1` indicates visitation limit exceed
|
||||
* @param collector the knn result collector
|
||||
* @return the best entry point, `-1` indicates graph entry node not set, or visitation limit
|
||||
* exceeded
|
||||
* @throws IOException When accessing the vector fails
|
||||
*/
|
||||
private int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, long visitLimit)
|
||||
private int findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, KnnCollector collector)
|
||||
throws IOException {
|
||||
int size = getGraphSize(graph);
|
||||
int visitedCount = 1;
|
||||
prepareScratchState(size);
|
||||
int currentEp = graph.entryNode();
|
||||
if (currentEp == -1 || graph.numLevels() == 1) {
|
||||
return currentEp;
|
||||
}
|
||||
int size = getGraphSize(graph);
|
||||
prepareScratchState(size);
|
||||
float currentScore = scorer.score(currentEp);
|
||||
collector.incVisitedCount(1);
|
||||
boolean foundBetter;
|
||||
for (int level = graph.numLevels() - 1; level >= 1; level--) {
|
||||
foundBetter = true;
|
||||
|
@ -169,11 +163,11 @@ public class HnswGraphSearcher {
|
|||
if (visited.getAndSet(friendOrd)) {
|
||||
continue;
|
||||
}
|
||||
if (visitedCount >= visitLimit) {
|
||||
return new int[] {-1, visitedCount};
|
||||
if (collector.earlyTerminated()) {
|
||||
return -1;
|
||||
}
|
||||
float friendSimilarity = scorer.score(friendOrd);
|
||||
visitedCount++;
|
||||
collector.incVisitedCount(1);
|
||||
if (friendSimilarity > currentScore) {
|
||||
currentScore = friendSimilarity;
|
||||
currentEp = friendOrd;
|
||||
|
@ -182,7 +176,7 @@ public class HnswGraphSearcher {
|
|||
}
|
||||
}
|
||||
}
|
||||
return new int[] {currentEp, visitedCount};
|
||||
return collector.earlyTerminated() ? -1 : currentEp;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -70,6 +70,7 @@ import org.apache.lucene.search.Sort;
|
|||
import org.apache.lucene.search.SortField;
|
||||
import org.apache.lucene.search.TaskExecutor;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TopKnnCollector;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.tests.util.TestUtil;
|
||||
|
@ -1026,6 +1027,37 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testAllNodesVisitedInSingleLevel() throws IOException {
|
||||
int size = atLeast(100);
|
||||
int dim = atLeast(50);
|
||||
|
||||
// Search for a large number of results
|
||||
int topK = size - 1;
|
||||
|
||||
AbstractMockVectorValues<T> docVectors = vectorValues(size, dim);
|
||||
HnswGraph graph =
|
||||
HnswGraphBuilder.create(buildScorerSupplier(docVectors), 10, 30, random().nextLong())
|
||||
.build(size);
|
||||
|
||||
HnswGraph singleLevelGraph =
|
||||
new DelegateHnswGraph(graph) {
|
||||
@Override
|
||||
public int numLevels() {
|
||||
// Only retain the last level
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
AbstractMockVectorValues<T> queryVectors = vectorValues(1, dim);
|
||||
RandomVectorScorer queryScorer = buildScorer(docVectors, queryVectors.vectorValue(0));
|
||||
|
||||
KnnCollector collector = new TopKnnCollector(topK, Integer.MAX_VALUE);
|
||||
HnswGraphSearcher.search(queryScorer, collector, singleLevelGraph, null);
|
||||
|
||||
// Check that we visit all nodes
|
||||
assertEquals(graph.size(), collector.visitedCount());
|
||||
}
|
||||
|
||||
private int computeOverlap(int[] a, int[] b) {
|
||||
Arrays.sort(a);
|
||||
Arrays.sort(b);
|
||||
|
@ -1297,4 +1329,42 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
private static class DelegateHnswGraph extends HnswGraph {
|
||||
final HnswGraph delegate;
|
||||
|
||||
DelegateHnswGraph(HnswGraph delegate) {
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void seek(int level, int target) throws IOException {
|
||||
delegate.seek(level, target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return delegate.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextNeighbor() throws IOException {
|
||||
return delegate.nextNeighbor();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numLevels() throws IOException {
|
||||
return delegate.numLevels();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int entryNode() throws IOException {
|
||||
return delegate.entryNode();
|
||||
}
|
||||
|
||||
@Override
|
||||
public NodesIterator getNodesOnLevel(int level) throws IOException {
|
||||
return delegate.getNodesOnLevel(level);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue