Share concurrent execution code into TaskExecutor (#12398)

Lucene has a non-public SliceExecutor abstraction that handles the execution of tasks when search
is executed concurrently across leaf slices. Knn query vector rewrite has similar code that runs
tasks concurrently and waits for them to be completed and handles
eventual exceptions.

This commit shares code among these two scenarios, to reduce code
duplicate as well as to ensure that furhter improvements can be shared among them.
This commit is contained in:
Luca Cavanna 2023-06-28 13:52:01 +02:00 committed by GitHub
parent 4029cc37a7
commit f44cc45cf8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 55 additions and 63 deletions

View File

@ -19,13 +19,13 @@ package org.apache.lucene.search;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.FutureTask;
import java.util.concurrent.RunnableFuture;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
@ -33,7 +33,6 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.ThreadInterruptedException;
/**
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
@ -81,11 +80,11 @@ abstract class AbstractKnnVectorQuery extends Query {
filterWeight = null;
}
Executor executor = indexSearcher.getExecutor();
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
TopDocs[] perLeafResults =
(executor == null)
(taskExecutor == null)
? sequentialSearch(reader.leaves(), filterWeight)
: parallelSearch(reader.leaves(), filterWeight, executor);
: parallelSearch(reader.leaves(), filterWeight, taskExecutor);
// Merge sort the results
TopDocs topK = TopDocs.merge(k, perLeafResults);
@ -109,27 +108,12 @@ abstract class AbstractKnnVectorQuery extends Query {
}
private TopDocs[] parallelSearch(
List<LeafReaderContext> leafReaderContexts, Weight filterWeight, Executor executor) {
List<FutureTask<TopDocs>> tasks =
leafReaderContexts.stream()
.map(ctx -> new FutureTask<>(() -> searchLeaf(ctx, filterWeight)))
.toList();
SliceExecutor sliceExecutor = new SliceExecutor(executor);
sliceExecutor.invokeAll(tasks);
return tasks.stream()
.map(
task -> {
try {
return task.get();
} catch (ExecutionException e) {
throw new RuntimeException(e.getCause());
} catch (InterruptedException e) {
throw new ThreadInterruptedException(e);
}
})
.toArray(TopDocs[]::new);
List<LeafReaderContext> leafReaderContexts, Weight filterWeight, TaskExecutor taskExecutor) {
List<RunnableFuture<TopDocs>> tasks = new ArrayList<>();
for (LeafReaderContext context : leafReaderContexts) {
tasks.add(new FutureTask<>(() -> searchLeaf(context, filterWeight)));
}
return taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
}
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {

View File

@ -24,10 +24,9 @@ import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.function.Supplier;
import org.apache.lucene.index.DirectoryReader;
@ -44,7 +43,6 @@ import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.store.NIOFSDirectory;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.ThreadInterruptedException;
import org.apache.lucene.util.automaton.ByteRunAutomaton;
/**
@ -123,7 +121,7 @@ public class IndexSearcher {
private final Executor executor;
// Used internally for load balancing threads executing for the query
private final SliceExecutor sliceExecutor;
private final TaskExecutor taskExecutor;
// the default Similarity
private static final Similarity defaultSimilarity = new BM25Similarity();
@ -226,14 +224,14 @@ public class IndexSearcher {
}
// Package private for testing
IndexSearcher(IndexReaderContext context, Executor executor, SliceExecutor sliceExecutor) {
IndexSearcher(IndexReaderContext context, Executor executor, TaskExecutor taskExecutor) {
assert context.isTopLevel
: "IndexSearcher's ReaderContext must be topLevel for reader" + context.reader();
assert (sliceExecutor == null) == (executor == null);
assert (taskExecutor == null) == (executor == null);
reader = context.reader();
this.executor = executor;
this.sliceExecutor = sliceExecutor;
this.taskExecutor = taskExecutor;
this.readerContext = context;
leafContexts = context.leaves();
this.leafSlices = executor == null ? null : slices(leafContexts);
@ -669,7 +667,7 @@ public class IndexSearcher {
"CollectorManager does not always produce collectors with the same score mode");
}
}
final List<FutureTask<C>> listTasks = new ArrayList<>();
final List<RunnableFuture<C>> listTasks = new ArrayList<>();
for (int i = 0; i < leafSlices.length; ++i) {
final LeafReaderContext[] leaves = leafSlices[i].leaves;
final C collector = collectors.get(i);
@ -682,19 +680,8 @@ public class IndexSearcher {
listTasks.add(task);
}
sliceExecutor.invokeAll(listTasks);
final List<C> collectedCollectors = new ArrayList<>();
for (Future<C> future : listTasks) {
try {
collectedCollectors.add(future.get());
} catch (InterruptedException e) {
throw new ThreadInterruptedException(e);
} catch (ExecutionException e) {
throw new RuntimeException(e);
}
}
return collectorManager.reduce(collectedCollectors);
List<C> results = taskExecutor.invokeAll(listTasks);
return collectorManager.reduce(results);
}
}
@ -910,7 +897,7 @@ public class IndexSearcher {
+ "; executor="
+ executor
+ "; sliceExecutionControlPlane "
+ sliceExecutor
+ taskExecutor
+ ")";
}
@ -962,6 +949,10 @@ public class IndexSearcher {
return executor;
}
TaskExecutor getTaskExecutor() {
return taskExecutor;
}
/**
* Thrown when an attempt is made to add more than {@link #getMaxClauseCount()} clauses. This
* typically happens if a PrefixQuery, FuzzyQuery, WildcardQuery, or TermRangeQuery is expanded to
@ -999,7 +990,7 @@ public class IndexSearcher {
}
/** Return the SliceExecutionControlPlane instance to be used for this IndexSearcher instance */
private static SliceExecutor getSliceExecutionControlPlane(Executor executor) {
private static TaskExecutor getSliceExecutionControlPlane(Executor executor) {
if (executor == null) {
return null;
}
@ -1008,6 +999,6 @@ public class IndexSearcher {
return new QueueSizeBasedExecutor((ThreadPoolExecutor) executor);
}
return new SliceExecutor(executor);
return new TaskExecutor(executor);
}
}

View File

@ -20,11 +20,11 @@ package org.apache.lucene.search;
import java.util.concurrent.ThreadPoolExecutor;
/**
* Derivative of SliceExecutor that controls the number of active threads that are used for a single
* Derivative of TaskExecutor that controls the number of active threads that are used for a single
* query. At any point, no more than (maximum pool size of the executor * LIMITING_FACTOR) tasks
* should be active. If the limit is exceeded, further segments are searched on the caller thread
*/
class QueueSizeBasedExecutor extends SliceExecutor {
class QueueSizeBasedExecutor extends TaskExecutor {
private static final double LIMITING_FACTOR = 1.5;
private final ThreadPoolExecutor threadPoolExecutor;

View File

@ -17,23 +17,29 @@
package org.apache.lucene.search;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.RunnableFuture;
import org.apache.lucene.util.ThreadInterruptedException;
/**
* Executor which is responsible for execution of slices based on the current status of the system
* and current system load
* Executor wrapper responsible for the execution of concurrent tasks. Used to parallelize search
* across segments as well as query rewrite in some cases.
*/
class SliceExecutor {
class TaskExecutor {
private final Executor executor;
SliceExecutor(Executor executor) {
TaskExecutor(Executor executor) {
this.executor = Objects.requireNonNull(executor, "Executor is null");
}
final void invokeAll(Collection<? extends Runnable> tasks) {
final <T> List<T> invokeAll(Collection<RunnableFuture<T>> tasks) {
int i = 0;
for (Runnable task : tasks) {
if (shouldExecuteOnCallerThread(i, tasks.size())) {
@ -49,6 +55,17 @@ class SliceExecutor {
}
++i;
}
final List<T> results = new ArrayList<>();
for (Future<T> future : tasks) {
try {
results.add(future.get());
} catch (InterruptedException e) {
throw new ThreadInterruptedException(e);
} catch (ExecutionException e) {
throw new RuntimeException(e.getCause());
}
}
return results;
}
boolean shouldExecuteOnCallerThread(int index, int numTasks) {

View File

@ -412,12 +412,12 @@ public class TestIndexSearcher extends LuceneTestCase {
private void runSliceExecutorTest(ThreadPoolExecutor service, boolean useRandomSliceExecutor)
throws Exception {
SliceExecutor sliceExecutor =
TaskExecutor taskExecutor =
useRandomSliceExecutor == true
? new RandomBlockingSliceExecutor(service)
? new RandomBlockingTaskExecutor(service)
: new QueueSizeBasedExecutor(service);
IndexSearcher searcher = new IndexSearcher(reader.getContext(), service, sliceExecutor);
IndexSearcher searcher = new IndexSearcher(reader.getContext(), service, taskExecutor);
Query[] queries = new Query[] {new MatchAllDocsQuery(), new TermQuery(new Term("field", "1"))};
Sort[] sorts = new Sort[] {null, new Sort(new SortField("field2", SortField.Type.STRING))};
@ -453,9 +453,9 @@ public class TestIndexSearcher extends LuceneTestCase {
}
}
private static class RandomBlockingSliceExecutor extends SliceExecutor {
private static class RandomBlockingTaskExecutor extends TaskExecutor {
RandomBlockingSliceExecutor(Executor executor) {
RandomBlockingTaskExecutor(Executor executor) {
super(executor);
}