mirror of https://github.com/apache/lucene.git
Prevent extra similarity computation for single-level graphs (#12866)
### Description

[`#findBestEntryPoint`](4bc7850465/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java#L151) is used to determine the entry point for the last level of HNSW search. It finds the single best-scoring node from [all upper levels](4bc7850465/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java#L159), but performs an [unnecessary computation](4bc7850465/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java#L157) (along with [recording one visited node](4bc7850465/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java#L154)) when the graph has just 1 level (so the entry node is simply the overall graph's entry node).

Also added a test to demonstrate this (it fails without the changes in this PR): we visit `graph.size() + 1` nodes when the `topK` is high, whereas the maximum should be `graph.size()`.

---------

Co-authored-by: Kaival Parikh <kaivalp2000@gmail.com>
This commit is contained in:
parent
0e96b9cd8c
commit
65d30ca1af
|
@ -140,6 +140,33 @@ Other
|
||||||
|
|
||||||
* GITHUB#9049: Fixing bug in UnescapedCharSequence#toStringEscaped() (Jakub Slowinski)
|
* GITHUB#9049: Fixing bug in UnescapedCharSequence#toStringEscaped() (Jakub Slowinski)
|
||||||
|
|
||||||
|
======================== Lucene 9.10.0 =======================
|
||||||
|
|
||||||
|
API Changes
|
||||||
|
---------------------
|
||||||
|
(No changes)
|
||||||
|
|
||||||
|
New Features
|
||||||
|
---------------------
|
||||||
|
(No changes)
|
||||||
|
|
||||||
|
Improvements
|
||||||
|
---------------------
|
||||||
|
(No changes)
|
||||||
|
|
||||||
|
Optimizations
|
||||||
|
---------------------
|
||||||
|
(No changes)
|
||||||
|
|
||||||
|
Bug Fixes
|
||||||
|
---------------------
|
||||||
|
|
||||||
|
* GITHUB#12866: Prevent extra similarity computation for single-level HNSW graphs. (Kaival Parikh)
|
||||||
|
|
||||||
|
Other
|
||||||
|
---------------------
|
||||||
|
(No changes)
|
||||||
|
|
||||||
======================== Lucene 9.9.0 =======================
|
======================== Lucene 9.9.0 =======================
|
||||||
|
|
||||||
API Changes
|
API Changes
|
||||||
|
|
|
@ -112,10 +112,17 @@ public final class Version {
|
||||||
/**
|
/**
|
||||||
* Match settings and bugs in Lucene's 9.9.0 release.
|
* Match settings and bugs in Lucene's 9.9.0 release.
|
||||||
*
|
*
|
||||||
* @deprecated Use latest
|
* @deprecated (9.10.0) Use latest
|
||||||
*/
|
*/
|
||||||
@Deprecated public static final Version LUCENE_9_9_0 = new Version(9, 9, 0);
|
@Deprecated public static final Version LUCENE_9_9_0 = new Version(9, 9, 0);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Match settings and bugs in Lucene's 9.10.0 release.
|
||||||
|
*
|
||||||
|
* @deprecated Use latest
|
||||||
|
*/
|
||||||
|
@Deprecated public static final Version LUCENE_9_10_0 = new Version(9, 10, 0);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Match settings and bugs in Lucene's 10.0.0 release.
|
* Match settings and bugs in Lucene's 10.0.0 release.
|
||||||
*
|
*
|
||||||
|
|
|
@ -100,20 +100,11 @@ public class HnswGraphSearcher {
|
||||||
HnswGraphSearcher graphSearcher,
|
HnswGraphSearcher graphSearcher,
|
||||||
Bits acceptOrds)
|
Bits acceptOrds)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
int initialEp = graph.entryNode();
|
int ep = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector);
|
||||||
if (initialEp == -1) {
|
if (ep != -1) {
|
||||||
return;
|
|
||||||
}
|
|
||||||
int[] epAndVisited = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector.visitLimit());
|
|
||||||
int numVisited = epAndVisited[1];
|
|
||||||
int ep = epAndVisited[0];
|
|
||||||
if (ep == -1) {
|
|
||||||
knnCollector.incVisitedCount(numVisited);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
knnCollector.incVisitedCount(numVisited);
|
|
||||||
graphSearcher.searchLevel(knnCollector, scorer, 0, new int[] {ep}, graph, acceptOrds);
|
graphSearcher.searchLevel(knnCollector, scorer, 0, new int[] {ep}, graph, acceptOrds);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Searches for the nearest neighbors of a query vector in a given level.
|
* Searches for the nearest neighbors of a query vector in a given level.
|
||||||
|
@ -143,18 +134,21 @@ public class HnswGraphSearcher {
|
||||||
*
|
*
|
||||||
* @param scorer the scorer to compare the query with the nodes
|
* @param scorer the scorer to compare the query with the nodes
|
||||||
* @param graph the HNSWGraph
|
* @param graph the HNSWGraph
|
||||||
* @param visitLimit How many vectors are allowed to be visited
|
* @param collector the knn result collector
|
||||||
* @return An integer array whose first element is the best entry point, and second is the number
|
* @return the best entry point, `-1` indicates graph entry node not set, or visitation limit
|
||||||
* of candidates visited. Entry point of `-1` indicates visitation limit exceed
|
* exceeded
|
||||||
* @throws IOException When accessing the vector fails
|
* @throws IOException When accessing the vector fails
|
||||||
*/
|
*/
|
||||||
private int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, long visitLimit)
|
private int findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, KnnCollector collector)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
int size = getGraphSize(graph);
|
|
||||||
int visitedCount = 1;
|
|
||||||
prepareScratchState(size);
|
|
||||||
int currentEp = graph.entryNode();
|
int currentEp = graph.entryNode();
|
||||||
|
if (currentEp == -1 || graph.numLevels() == 1) {
|
||||||
|
return currentEp;
|
||||||
|
}
|
||||||
|
int size = getGraphSize(graph);
|
||||||
|
prepareScratchState(size);
|
||||||
float currentScore = scorer.score(currentEp);
|
float currentScore = scorer.score(currentEp);
|
||||||
|
collector.incVisitedCount(1);
|
||||||
boolean foundBetter;
|
boolean foundBetter;
|
||||||
for (int level = graph.numLevels() - 1; level >= 1; level--) {
|
for (int level = graph.numLevels() - 1; level >= 1; level--) {
|
||||||
foundBetter = true;
|
foundBetter = true;
|
||||||
|
@ -169,11 +163,11 @@ public class HnswGraphSearcher {
|
||||||
if (visited.getAndSet(friendOrd)) {
|
if (visited.getAndSet(friendOrd)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (visitedCount >= visitLimit) {
|
if (collector.earlyTerminated()) {
|
||||||
return new int[] {-1, visitedCount};
|
return -1;
|
||||||
}
|
}
|
||||||
float friendSimilarity = scorer.score(friendOrd);
|
float friendSimilarity = scorer.score(friendOrd);
|
||||||
visitedCount++;
|
collector.incVisitedCount(1);
|
||||||
if (friendSimilarity > currentScore) {
|
if (friendSimilarity > currentScore) {
|
||||||
currentScore = friendSimilarity;
|
currentScore = friendSimilarity;
|
||||||
currentEp = friendOrd;
|
currentEp = friendOrd;
|
||||||
|
@ -182,7 +176,7 @@ public class HnswGraphSearcher {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return new int[] {currentEp, visitedCount};
|
return collector.earlyTerminated() ? -1 : currentEp;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -70,6 +70,7 @@ import org.apache.lucene.search.Sort;
|
||||||
import org.apache.lucene.search.SortField;
|
import org.apache.lucene.search.SortField;
|
||||||
import org.apache.lucene.search.TaskExecutor;
|
import org.apache.lucene.search.TaskExecutor;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
|
import org.apache.lucene.search.TopKnnCollector;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
import org.apache.lucene.tests.util.TestUtil;
|
import org.apache.lucene.tests.util.TestUtil;
|
||||||
|
@ -1026,6 +1027,37 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testAllNodesVisitedInSingleLevel() throws IOException {
|
||||||
|
int size = atLeast(100);
|
||||||
|
int dim = atLeast(50);
|
||||||
|
|
||||||
|
// Search for a large number of results
|
||||||
|
int topK = size - 1;
|
||||||
|
|
||||||
|
AbstractMockVectorValues<T> docVectors = vectorValues(size, dim);
|
||||||
|
HnswGraph graph =
|
||||||
|
HnswGraphBuilder.create(buildScorerSupplier(docVectors), 10, 30, random().nextLong())
|
||||||
|
.build(size);
|
||||||
|
|
||||||
|
HnswGraph singleLevelGraph =
|
||||||
|
new DelegateHnswGraph(graph) {
|
||||||
|
@Override
|
||||||
|
public int numLevels() {
|
||||||
|
// Only retain the last level
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
AbstractMockVectorValues<T> queryVectors = vectorValues(1, dim);
|
||||||
|
RandomVectorScorer queryScorer = buildScorer(docVectors, queryVectors.vectorValue(0));
|
||||||
|
|
||||||
|
KnnCollector collector = new TopKnnCollector(topK, Integer.MAX_VALUE);
|
||||||
|
HnswGraphSearcher.search(queryScorer, collector, singleLevelGraph, null);
|
||||||
|
|
||||||
|
// Check that we visit all nodes
|
||||||
|
assertEquals(graph.size(), collector.visitedCount());
|
||||||
|
}
|
||||||
|
|
||||||
private int computeOverlap(int[] a, int[] b) {
|
private int computeOverlap(int[] a, int[] b) {
|
||||||
Arrays.sort(a);
|
Arrays.sort(a);
|
||||||
Arrays.sort(b);
|
Arrays.sort(b);
|
||||||
|
@ -1297,4 +1329,42 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
||||||
|
|
||||||
return sb.toString();
|
return sb.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static class DelegateHnswGraph extends HnswGraph {
|
||||||
|
final HnswGraph delegate;
|
||||||
|
|
||||||
|
DelegateHnswGraph(HnswGraph delegate) {
|
||||||
|
this.delegate = delegate;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void seek(int level, int target) throws IOException {
|
||||||
|
delegate.seek(level, target);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return delegate.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextNeighbor() throws IOException {
|
||||||
|
return delegate.nextNeighbor();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int numLevels() throws IOException {
|
||||||
|
return delegate.numLevels();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int entryNode() throws IOException {
|
||||||
|
return delegate.entryNode();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NodesIterator getNodesOnLevel(int level) throws IOException {
|
||||||
|
return delegate.getNodesOnLevel(level);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue