LUCENE-10611: Fix Heap Error in HnswGraphSearcher (#958)

The HNSW graph search does not consider that visitedLimit may already be
reached while searching the upper levels of the graph.

This occurs when the pre-filter is too restrictive (and its count sets the
visitedLimit). So instead of switching over to exactSearch, it tries to pop
from an empty heap and throws an error.

We can check if results are incomplete after searching in upper levels, and
break out accordingly. This way it won't throw heap errors, and gracefully
switches to exactSearch instead.
This commit is contained in:
Kaival Parikh 2022-06-17 00:31:54 +05:30 committed by GitHub
parent 78b7b17f93
commit 6df6cb093c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 5 deletions

View File

@ -110,6 +110,8 @@ Bug Fixes
* LUCENE-10600: SortedSetDocValues#docValueCount should be an int, not long (Lu Xugang) * LUCENE-10600: SortedSetDocValues#docValueCount should be an int, not long (Lu Xugang)
* LUCENE-10611: Fix Heap Error in HnswGraphSearcher (Kaival Parikh)
Other Other
--------------------- ---------------------

View File

@ -87,10 +87,15 @@ public final class HnswGraphSearcher {
int numVisited = 0; int numVisited = 0;
for (int level = graph.numLevels() - 1; level >= 1; level--) { for (int level = graph.numLevels() - 1; level >= 1; level--) {
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit); results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
eps[0] = results.pop();
numVisited += results.visitedCount(); numVisited += results.visitedCount();
visitedLimit -= results.visitedCount(); visitedLimit -= results.visitedCount();
if (results.incomplete()) {
results.setVisitedCount(numVisited);
return results;
}
eps[0] = results.pop();
} }
results = results =
graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);

View File

@ -521,7 +521,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
/** Tests with random vectors and a random filter. Uses RandomIndexWriter. */ /** Tests with random vectors and a random filter. Uses RandomIndexWriter. */
public void testRandomWithFilter() throws IOException { public void testRandomWithFilter() throws IOException {
int numDocs = 200; int numDocs = 1000;
int dimension = atLeast(5); int dimension = atLeast(5);
int numIters = atLeast(10); int numIters = atLeast(10);
try (Directory d = newDirectory()) { try (Directory d = newDirectory()) {
@ -543,7 +543,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
try (DirectoryReader reader = DirectoryReader.open(d)) { try (DirectoryReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader); IndexSearcher searcher = newSearcher(reader);
for (int i = 0; i < numIters; i++) { for (int i = 0; i < numIters; i++) {
int lower = random().nextInt(50); int lower = random().nextInt(500);
// Test a filter with cost less than k and check we use exact search // Test a filter with cost less than k and check we use exact search
Query filter1 = IntPoint.newRangeQuery("tag", lower, lower + 8); Query filter1 = IntPoint.newRangeQuery("tag", lower, lower + 8);
@ -574,7 +574,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
numDocs)); numDocs));
// Test an unrestrictive filter and check we use approximate search // Test an unrestrictive filter and check we use approximate search
Query filter3 = IntPoint.newRangeQuery("tag", lower, 200); Query filter3 = IntPoint.newRangeQuery("tag", lower, numDocs);
results = results =
searcher.search( searcher.search(
new ThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter3), new ThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter3),
@ -588,8 +588,17 @@ public class TestKnnVectorQuery extends LuceneTestCase {
assertEquals(1, fieldDoc.fields.length); assertEquals(1, fieldDoc.fields.length);
int tag = (int) fieldDoc.fields[0]; int tag = (int) fieldDoc.fields[0];
assertTrue(lower <= tag && tag <= 200); assertTrue(lower <= tag && tag <= numDocs);
} }
// Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
expectThrows(
UnsupportedOperationException.class,
() ->
searcher.search(
new ThrowingKnnVectorQuery("field", randomVector(dimension), 1, filter4),
numDocs));
} }
} }
} }