mirror of https://github.com/apache/lucene.git
Concurrent rewrite for KnnVectorQuery (#12160)
- Reduce overhead of non-concurrent search by preserving original execution
- Improve readability by factoring into separate functions

Co-authored-by: Kaival Parikh <kaivalp2000@gmail.com>
This commit is contained in:
parent
569533bd76
commit
e0d92eef98
|
@ -70,6 +70,8 @@ Optimizations
|
|||
|
||||
* GITHUB#11857, GITHUB#11859, GITHUB#11893, GITHUB#11909: Hunspell: improved suggestion performance (Peter Gromov)
|
||||
|
||||
* GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh)
|
||||
|
||||
Bug Fixes
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -21,7 +21,11 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
|||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.FutureTask;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
|
@ -29,6 +33,7 @@ import org.apache.lucene.index.LeafReaderContext;
|
|||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.BitSetIterator;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.ThreadInterruptedException;
|
||||
|
||||
/**
|
||||
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
|
||||
|
@ -62,9 +67,8 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
@Override
|
||||
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
|
||||
IndexReader reader = indexSearcher.getIndexReader();
|
||||
TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
|
||||
|
||||
Weight filterWeight = null;
|
||||
final Weight filterWeight;
|
||||
if (filter != null) {
|
||||
BooleanQuery booleanQuery =
|
||||
new BooleanQuery.Builder()
|
||||
|
@ -73,17 +77,16 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
.build();
|
||||
Query rewritten = indexSearcher.rewrite(booleanQuery);
|
||||
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
|
||||
} else {
|
||||
filterWeight = null;
|
||||
}
|
||||
|
||||
for (LeafReaderContext ctx : reader.leaves()) {
|
||||
TopDocs results = searchLeaf(ctx, filterWeight);
|
||||
if (ctx.docBase > 0) {
|
||||
for (ScoreDoc scoreDoc : results.scoreDocs) {
|
||||
scoreDoc.doc += ctx.docBase;
|
||||
}
|
||||
}
|
||||
perLeafResults[ctx.ord] = results;
|
||||
}
|
||||
Executor executor = indexSearcher.getExecutor();
|
||||
TopDocs[] perLeafResults =
|
||||
(executor == null)
|
||||
? sequentialSearch(reader.leaves(), filterWeight)
|
||||
: parallelSearch(reader.leaves(), filterWeight, executor);
|
||||
|
||||
// Merge sort the results
|
||||
TopDocs topK = TopDocs.merge(k, perLeafResults);
|
||||
if (topK.scoreDocs.length == 0) {
|
||||
|
@ -92,7 +95,54 @@ abstract class AbstractKnnVectorQuery extends Query {
|
|||
return createRewrittenQuery(reader, topK);
|
||||
}
|
||||
|
||||
private TopDocs[] sequentialSearch(
|
||||
List<LeafReaderContext> leafReaderContexts, Weight filterWeight) {
|
||||
try {
|
||||
TopDocs[] perLeafResults = new TopDocs[leafReaderContexts.size()];
|
||||
for (LeafReaderContext ctx : leafReaderContexts) {
|
||||
perLeafResults[ctx.ord] = searchLeaf(ctx, filterWeight);
|
||||
}
|
||||
return perLeafResults;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private TopDocs[] parallelSearch(
|
||||
List<LeafReaderContext> leafReaderContexts, Weight filterWeight, Executor executor) {
|
||||
List<FutureTask<TopDocs>> tasks =
|
||||
leafReaderContexts.stream()
|
||||
.map(ctx -> new FutureTask<>(() -> searchLeaf(ctx, filterWeight)))
|
||||
.toList();
|
||||
|
||||
SliceExecutor sliceExecutor = new SliceExecutor(executor);
|
||||
sliceExecutor.invokeAll(tasks);
|
||||
|
||||
return tasks.stream()
|
||||
.map(
|
||||
task -> {
|
||||
try {
|
||||
return task.get();
|
||||
} catch (ExecutionException e) {
|
||||
throw new RuntimeException(e.getCause());
|
||||
} catch (InterruptedException e) {
|
||||
throw new ThreadInterruptedException(e);
|
||||
}
|
||||
})
|
||||
.toArray(TopDocs[]::new);
|
||||
}
|
||||
|
||||
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
|
||||
TopDocs results = getLeafResults(ctx, filterWeight);
|
||||
if (ctx.docBase > 0) {
|
||||
for (ScoreDoc scoreDoc : results.scoreDocs) {
|
||||
scoreDoc.doc += ctx.docBase;
|
||||
}
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
|
||||
Bits liveDocs = ctx.reader().getLiveDocs();
|
||||
int maxDoc = ctx.reader().maxDoc();
|
||||
|
||||
|
|
|
@ -210,7 +210,10 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
IndexSearcher searcher = newSearcher(reader);
|
||||
AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
|
||||
IllegalArgumentException e =
|
||||
expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
|
||||
expectThrows(
|
||||
RuntimeException.class,
|
||||
IllegalArgumentException.class,
|
||||
() -> searcher.search(kvq, 10));
|
||||
assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
|
||||
}
|
||||
}
|
||||
|
@ -495,6 +498,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
assertEquals(9, results.totalHits.value);
|
||||
assertEquals(results.totalHits.value, results.scoreDocs.length);
|
||||
expectThrows(
|
||||
RuntimeException.class,
|
||||
UnsupportedOperationException.class,
|
||||
() ->
|
||||
searcher.search(
|
||||
|
@ -509,6 +513,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
assertEquals(5, results.totalHits.value);
|
||||
assertEquals(results.totalHits.value, results.scoreDocs.length);
|
||||
expectThrows(
|
||||
RuntimeException.class,
|
||||
UnsupportedOperationException.class,
|
||||
() ->
|
||||
searcher.search(
|
||||
|
@ -536,6 +541,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
// Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
|
||||
Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
|
||||
expectThrows(
|
||||
RuntimeException.class,
|
||||
UnsupportedOperationException.class,
|
||||
() ->
|
||||
searcher.search(
|
||||
|
@ -708,6 +714,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
|
||||
Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
|
||||
expectThrows(
|
||||
RuntimeException.class,
|
||||
UnsupportedOperationException.class,
|
||||
() ->
|
||||
searcher.search(
|
||||
|
|
Loading…
Reference in New Issue