Concurrent rewrite for KnnVectorQuery (#12160)

- Reduce overhead of non-concurrent search by preserving original execution
- Improve readability by factoring into separate functions

---------

Co-authored-by: Kaival Parikh <kaivalp2000@gmail.com>
Author: Kaival Parikh, 2023-03-04 14:42:11 +05:30 (committed by GitHub)
Parent: 569533bd76
Commit: e0d92eef98
3 changed files with 71 additions and 12 deletions
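
The concurrent path is only taken when the IndexSearcher is built with an executor; with none, rewrite() falls back to the original single-threaded loop, which is how the first bullet above keeps non-concurrent search overhead-free. A minimal usage sketch, assuming the float variant KnnFloatVectorQuery (a concrete subclass of AbstractKnnVectorQuery) and an application-managed pool; reader and queryVector (a float[]) are placeholders:

    ExecutorService pool = Executors.newFixedThreadPool(4);
    // A non-null executor is what the rewritten rewrite() sees via
    // IndexSearcher.getExecutor(), selecting parallelSearch over sequentialSearch.
    IndexSearcher searcher = new IndexSearcher(reader, pool);
    TopDocs topDocs = searcher.search(new KnnFloatVectorQuery("field", queryVector, 10), 10);
    pool.shutdown();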

lucene/CHANGES.txt

@@ -70,6 +70,8 @@ Optimizations
 * GITHUB#11857, GITHUB#11859, GITHUB#11893, GITHUB#11909: Hunspell: improved suggestion performance (Peter Gromov)
 
+* GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh)
+
 Bug Fixes
 ---------------------

lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java

@@ -21,7 +21,11 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.Comparator;
+import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
+import java.util.concurrent.FutureTask;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexReader;
@@ -29,6 +33,7 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.BitSetIterator;
 import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.ThreadInterruptedException;
 
 /**
  * Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
@@ -62,9 +67,8 @@ abstract class AbstractKnnVectorQuery extends Query {
   @Override
   public Query rewrite(IndexSearcher indexSearcher) throws IOException {
     IndexReader reader = indexSearcher.getIndexReader();
-    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
 
-    Weight filterWeight = null;
+    final Weight filterWeight;
     if (filter != null) {
       BooleanQuery booleanQuery =
           new BooleanQuery.Builder()
@@ -73,17 +77,16 @@ abstract class AbstractKnnVectorQuery extends Query {
               .build();
       Query rewritten = indexSearcher.rewrite(booleanQuery);
       filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
+    } else {
+      filterWeight = null;
     }
 
-    for (LeafReaderContext ctx : reader.leaves()) {
-      TopDocs results = searchLeaf(ctx, filterWeight);
-      if (ctx.docBase > 0) {
-        for (ScoreDoc scoreDoc : results.scoreDocs) {
-          scoreDoc.doc += ctx.docBase;
-        }
-      }
-      perLeafResults[ctx.ord] = results;
-    }
+    Executor executor = indexSearcher.getExecutor();
+    TopDocs[] perLeafResults =
+        (executor == null)
+            ? sequentialSearch(reader.leaves(), filterWeight)
+            : parallelSearch(reader.leaves(), filterWeight, executor);
+
     // Merge sort the results
     TopDocs topK = TopDocs.merge(k, perLeafResults);
     if (topK.scoreDocs.length == 0) {
@@ -92,7 +95,54 @@ abstract class AbstractKnnVectorQuery extends Query {
     return createRewrittenQuery(reader, topK);
   }
 
+  private TopDocs[] sequentialSearch(
+      List<LeafReaderContext> leafReaderContexts, Weight filterWeight) {
+    try {
+      TopDocs[] perLeafResults = new TopDocs[leafReaderContexts.size()];
+      for (LeafReaderContext ctx : leafReaderContexts) {
+        perLeafResults[ctx.ord] = searchLeaf(ctx, filterWeight);
+      }
+      return perLeafResults;
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private TopDocs[] parallelSearch(
+      List<LeafReaderContext> leafReaderContexts, Weight filterWeight, Executor executor) {
+    List<FutureTask<TopDocs>> tasks =
+        leafReaderContexts.stream()
+            .map(ctx -> new FutureTask<>(() -> searchLeaf(ctx, filterWeight)))
+            .toList();
+
+    SliceExecutor sliceExecutor = new SliceExecutor(executor);
+    sliceExecutor.invokeAll(tasks);
+
+    return tasks.stream()
+        .map(
+            task -> {
+              try {
+                return task.get();
+              } catch (ExecutionException e) {
+                throw new RuntimeException(e.getCause());
+              } catch (InterruptedException e) {
+                throw new ThreadInterruptedException(e);
+              }
+            })
+        .toArray(TopDocs[]::new);
+  }
+
   private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
+    TopDocs results = getLeafResults(ctx, filterWeight);
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
     Bits liveDocs = ctx.reader().getLiveDocs();
     int maxDoc = ctx.reader().maxDoc();
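
Two details of the new methods are worth noting. searchLeaf declares the checked IOException, which cannot escape a Callable or a stream lambda, so both paths rethrow unchecked: sequentialSearch wraps any failure in a RuntimeException, while parallelSearch unwraps the ExecutionException and rethrows its cause, turning interrupts into Lucene's unchecked ThreadInterruptedException. SliceExecutor is the package-private helper IndexSearcher already uses for concurrent slice search; it runs the queued tasks on the given executor. A dependency-free sketch of the same fan-out/join shape, where fanOut is a hypothetical helper and not Lucene code:

    static <T> List<T> fanOut(Executor executor, List<Callable<T>> jobs) {
      // Submit everything first so all leaves are searched in parallel ...
      List<FutureTask<T>> tasks = jobs.stream().map(FutureTask::new).toList();
      tasks.forEach(executor::execute);
      // ... then join in submission order, keeping results aligned with input order.
      return tasks.stream()
          .map(
              task -> {
                try {
                  return task.get();
                } catch (ExecutionException e) {
                  throw new RuntimeException(e.getCause()); // surface the task's failure unchecked
                } catch (InterruptedException e) {
                  Thread.currentThread().interrupt();
                  throw new RuntimeException(e);
                }
              })
          .toList();
    }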

lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java

@@ -210,7 +210,10 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
       IndexSearcher searcher = newSearcher(reader);
       AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
       IllegalArgumentException e =
-          expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
+          expectThrows(
+              RuntimeException.class,
+              IllegalArgumentException.class,
+              () -> searcher.search(kvq, 10));
       assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
     }
   }
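
Because per-leaf failures now reach the caller wrapped in a RuntimeException, the tests switch to the expectThrows overload that takes both the expected outer type and the expected cause type; it asserts the outer exception, then asserts and returns its cause, so the existing message checks keep working. Roughly equivalent to this sketch (not the actual LuceneTestCase implementation):

    RuntimeException outer =
        expectThrows(RuntimeException.class, () -> searcher.search(kvq, 10));
    assertTrue(outer.getCause() instanceof IllegalArgumentException);
    IllegalArgumentException e = (IllegalArgumentException) outer.getCause();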
@@ -495,6 +498,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
       assertEquals(9, results.totalHits.value);
       assertEquals(results.totalHits.value, results.scoreDocs.length);
       expectThrows(
+          RuntimeException.class,
           UnsupportedOperationException.class,
           () ->
               searcher.search(
@@ -509,6 +513,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
       assertEquals(5, results.totalHits.value);
       assertEquals(results.totalHits.value, results.scoreDocs.length);
       expectThrows(
+          RuntimeException.class,
           UnsupportedOperationException.class,
           () ->
               searcher.search(
@@ -536,6 +541,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
       // Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
       Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
       expectThrows(
+          RuntimeException.class,
           UnsupportedOperationException.class,
           () ->
               searcher.search(
@@ -708,6 +714,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
       Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
       expectThrows(
+          RuntimeException.class,
           UnsupportedOperationException.class,
           () ->
               searcher.search(