From 6df6cb093cca7f93075bad131fbc4ad6a8ce5fef Mon Sep 17 00:00:00 2001 From: Kaival Parikh <46070017+kaivalnp@users.noreply.github.com> Date: Fri, 17 Jun 2022 00:31:54 +0530 Subject: [PATCH] LUCENE-10611: Fix Heap Error in HnswGraphSearcher (#958) The HNSW graph search does not consider that visitedLimit may be reached in the upper levels of graph search itself This occurs when the pre-filter is too restrictive (and its count sets the visitedLimit). So instead of switching over to exactSearch, it tries to pop from an empty heap and throws an error. We can check if results are incomplete after searching in upper levels, and break out accordingly. This way it won't throw heap errors, and gracefully switch to exactSearch instead --- lucene/CHANGES.txt | 2 ++ .../lucene/util/hnsw/HnswGraphSearcher.java | 7 ++++++- .../lucene/search/TestKnnVectorQuery.java | 17 +++++++++++++---- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 0e0dc71d86b..53c1ac5133e 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -110,6 +110,8 @@ Bug Fixes * LUCENE-10600: SortedSetDocValues#docValueCount should be an int, not long (Lu Xugang) +* LUCENE-10611: Fix Heap Error in HnswGraphSearcher (Kaival Parikh) + Other --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index b1a2436166f..ba88995bd3b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -87,10 +87,15 @@ public final class HnswGraphSearcher { int numVisited = 0; for (int level = graph.numLevels() - 1; level >= 1; level--) { results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit); - eps[0] = results.pop(); numVisited += results.visitedCount(); visitedLimit -= results.visitedCount(); + + if (results.incomplete()) { + results.setVisitedCount(numVisited); + return results; + } + eps[0] = results.pop(); } results = graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java index 857f140aa65..ba9e6b5b5a7 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java @@ -521,7 +521,7 @@ public class TestKnnVectorQuery extends LuceneTestCase { /** Tests with random vectors and a random filter. Uses RandomIndexWriter. */ public void testRandomWithFilter() throws IOException { - int numDocs = 200; + int numDocs = 1000; int dimension = atLeast(5); int numIters = atLeast(10); try (Directory d = newDirectory()) { @@ -543,7 +543,7 @@ public class TestKnnVectorQuery extends LuceneTestCase { try (DirectoryReader reader = DirectoryReader.open(d)) { IndexSearcher searcher = newSearcher(reader); for (int i = 0; i < numIters; i++) { - int lower = random().nextInt(50); + int lower = random().nextInt(500); // Test a filter with cost less than k and check we use exact search Query filter1 = IntPoint.newRangeQuery("tag", lower, lower + 8); @@ -574,7 +574,7 @@ public class TestKnnVectorQuery extends LuceneTestCase { numDocs)); // Test an unrestrictive filter and check we use approximate search - Query filter3 = IntPoint.newRangeQuery("tag", lower, 200); + Query filter3 = IntPoint.newRangeQuery("tag", lower, numDocs); results = searcher.search( new ThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter3), @@ -588,8 +588,17 @@ public class TestKnnVectorQuery extends LuceneTestCase { assertEquals(1, fieldDoc.fields.length); int tag = (int) fieldDoc.fields[0]; - assertTrue(lower <= tag && tag <= 200); + assertTrue(lower <= tag && tag <= numDocs); } + + // Test a filter that exhausts visitedLimit in upper levels, and switches to exact search + Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2); + expectThrows( + UnsupportedOperationException.class, + () -> + searcher.search( + new ThrowingKnnVectorQuery("field", randomVector(dimension), 1, filter4), + numDocs)); } } }