diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 799b4e8e4d2..4e86d5aeff7 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -70,6 +70,8 @@ Optimizations * GITHUB#11857, GITHUB#11859, GITHUB#11893, GITHUB#11909: Hunspell: improved suggestion performance (Peter Gromov) +* GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh) + Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 39c2d34ffd2..04309561cbd 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -21,7 +21,11 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; import java.util.Arrays; import java.util.Comparator; +import java.util.List; import java.util.Objects; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.FutureTask; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexReader; @@ -29,6 +33,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.ThreadInterruptedException; /** * Uses {@link KnnVectorsReader#search} to perform nearest neighbour search. 
@@ -62,9 +67,8 @@ abstract class AbstractKnnVectorQuery extends Query { @Override public Query rewrite(IndexSearcher indexSearcher) throws IOException { IndexReader reader = indexSearcher.getIndexReader(); - TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()]; - Weight filterWeight = null; + final Weight filterWeight; if (filter != null) { BooleanQuery booleanQuery = new BooleanQuery.Builder() @@ -73,17 +77,16 @@ abstract class AbstractKnnVectorQuery extends Query { .build(); Query rewritten = indexSearcher.rewrite(booleanQuery); filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f); + } else { + filterWeight = null; } - for (LeafReaderContext ctx : reader.leaves()) { - TopDocs results = searchLeaf(ctx, filterWeight); - if (ctx.docBase > 0) { - for (ScoreDoc scoreDoc : results.scoreDocs) { - scoreDoc.doc += ctx.docBase; - } - } - perLeafResults[ctx.ord] = results; - } + Executor executor = indexSearcher.getExecutor(); + TopDocs[] perLeafResults = + (executor == null) + ? 
sequentialSearch(reader.leaves(), filterWeight) + : parallelSearch(reader.leaves(), filterWeight, executor); + // Merge sort the results TopDocs topK = TopDocs.merge(k, perLeafResults); if (topK.scoreDocs.length == 0) { @@ -92,7 +95,54 @@ abstract class AbstractKnnVectorQuery extends Query { return createRewrittenQuery(reader, topK); } + private TopDocs[] sequentialSearch( + List<LeafReaderContext> leafReaderContexts, Weight filterWeight) { + try { + TopDocs[] perLeafResults = new TopDocs[leafReaderContexts.size()]; + for (LeafReaderContext ctx : leafReaderContexts) { + perLeafResults[ctx.ord] = searchLeaf(ctx, filterWeight); + } + return perLeafResults; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private TopDocs[] parallelSearch( + List<LeafReaderContext> leafReaderContexts, Weight filterWeight, Executor executor) { + List<FutureTask<TopDocs>> tasks = + leafReaderContexts.stream() + .map(ctx -> new FutureTask<>(() -> searchLeaf(ctx, filterWeight))) + .toList(); + + SliceExecutor sliceExecutor = new SliceExecutor(executor); + sliceExecutor.invokeAll(tasks); + + return tasks.stream() + .map( + task -> { + try { + return task.get(); + } catch (ExecutionException e) { + throw new RuntimeException(e.getCause()); + } catch (InterruptedException e) { + throw new ThreadInterruptedException(e); + } + }) + .toArray(TopDocs[]::new); + } + private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException { + TopDocs results = getLeafResults(ctx, filterWeight); + if (ctx.docBase > 0) { + for (ScoreDoc scoreDoc : results.scoreDocs) { + scoreDoc.doc += ctx.docBase; + } + } + return results; + } + + private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException { Bits liveDocs = ctx.reader().getLiveDocs(); int maxDoc = ctx.reader().maxDoc(); diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index dbb13d9f058..f9356ae0bc1 100644 --- 
a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -210,7 +210,10 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { IndexSearcher searcher = newSearcher(reader); AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10); IllegalArgumentException e = - expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10)); + expectThrows( + RuntimeException.class, + IllegalArgumentException.class, + () -> searcher.search(kvq, 10)); assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage()); } } @@ -495,6 +498,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { assertEquals(9, results.totalHits.value); assertEquals(results.totalHits.value, results.scoreDocs.length); expectThrows( + RuntimeException.class, UnsupportedOperationException.class, () -> searcher.search( @@ -509,6 +513,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { assertEquals(5, results.totalHits.value); assertEquals(results.totalHits.value, results.scoreDocs.length); expectThrows( + RuntimeException.class, UnsupportedOperationException.class, () -> searcher.search( @@ -536,6 +541,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { // Test a filter that exhausts visitedLimit in upper levels, and switches to exact search Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2); expectThrows( + RuntimeException.class, UnsupportedOperationException.class, () -> searcher.search( @@ -708,6 +714,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs)); expectThrows( + RuntimeException.class, UnsupportedOperationException.class, () -> searcher.search(