Create a task executor when executor is not provided (#12606)

As we introduce more places that add concurrency (there are currently
three), a common pattern emerges: check whether an executor was
provided, then either run sequentially on the caller thread or in
parallel, relying on the executor.

This can be improved by internally creating a TaskExecutor backed by an
executor that runs tasks on the caller thread. The task executor is then
never null, so the common conditional is no longer needed: the
concurrent path that uses the task executor becomes the default and only
choice for operations that can be parallelized.
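The core pattern, sketched below as a minimal standalone example (an illustration of the approach, not the actual Lucene TaskExecutor): wrap a possibly-null Executor so that, when none is supplied, tasks run inline on the caller thread via Runnable::run, and a single invokeAll-based code path covers both the sequential and the parallel case.

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.FutureTask;

// Sketch of the idea: a task executor that is never null. When no Executor
// is supplied, Runnable::run executes each task inline on the caller thread.
final class CallerThreadTaskExecutor {
  private final Executor executor;

  CallerThreadTaskExecutor(Executor executor) {
    this.executor = executor == null ? Runnable::run : executor;
  }

  <T> List<T> invokeAll(Collection<Callable<T>> callables) {
    List<FutureTask<T>> futures = new ArrayList<>(callables.size());
    for (Callable<T> callable : callables) {
      FutureTask<T> future = new FutureTask<>(callable);
      futures.add(future);
      executor.execute(future); // runs inline when the fallback executor is used
    }
    List<T> results = new ArrayList<>(futures.size());
    for (FutureTask<T> future : futures) {
      try {
        results.add(future.get());
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      } catch (ExecutionException e) {
        throw new RuntimeException(e.getCause());
      }
    }
    return results;
  }
}

With Runnable::run as the fallback, execute(future) runs each task to completion before the loop advances, so the old sequential behavior is preserved without a dedicated branch.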
Luca Cavanna 2023-10-03 09:13:45 +02:00 committed by GitHub
parent 1dd05c89b0
commit 2106bf5172
6 changed files with 99 additions and 98 deletions

CHANGES.txt

@@ -153,6 +153,9 @@ Improvements
* GITHUB#12603: Simplify the TaskExecutor API by exposing a single invokeAll method that takes a
collection of callables, executes them and returns their results (Luca Cavanna)
* GITHUB#12606: Create a TaskExecutor when an executor is not provided to the IndexSearcher, in
order to simplify consumers' code (Luca Cavanna)
Optimizations
---------------------
* GITHUB#12183: Make TermStates#build concurrent. (Shubham Chaudhary)

TermStates.java

@@ -98,37 +98,26 @@ public final class TermStates {
final TermStates perReaderTermState = new TermStates(needsStats ? null : term, context);
if (needsStats) {
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
if (taskExecutor != null) {
// build the term states concurrently
List<Callable<TermStateInfo>> tasks = new ArrayList<>(context.leaves().size());
for (LeafReaderContext ctx : context.leaves()) {
tasks.add(
() -> {
TermsEnum termsEnum = loadTermsEnum(ctx, term);
return termsEnum == null
? null
: new TermStateInfo(
termsEnum.termState(),
ctx.ord,
termsEnum.docFreq(),
termsEnum.totalTermFreq());
});
}
List<TermStateInfo> resultInfos = taskExecutor.invokeAll(tasks);
for (TermStateInfo info : resultInfos) {
if (info != null) {
perReaderTermState.register(
info.getState(), info.getOrdinal(), info.getDocFreq(), info.getTotalTermFreq());
}
}
} else {
// build the term states sequentially
for (final LeafReaderContext ctx : context.leaves()) {
TermsEnum termsEnum = loadTermsEnum(ctx, term);
if (termsEnum != null) {
perReaderTermState.register(
termsEnum.termState(), ctx.ord, termsEnum.docFreq(), termsEnum.totalTermFreq());
}
// build the term states concurrently
List<Callable<TermStateInfo>> tasks = new ArrayList<>(context.leaves().size());
for (LeafReaderContext ctx : context.leaves()) {
tasks.add(
() -> {
TermsEnum termsEnum = loadTermsEnum(ctx, term);
return termsEnum == null
? null
: new TermStateInfo(
termsEnum.termState(),
ctx.ord,
termsEnum.docFreq(),
termsEnum.totalTermFreq());
});
}
List<TermStateInfo> resultInfos = taskExecutor.invokeAll(tasks);
for (TermStateInfo info : resultInfos) {
if (info != null) {
perReaderTermState.register(
info.getState(), info.getOrdinal(), info.getDocFreq(), info.getTotalTermFreq());
}
}
}
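With the task executor guaranteed to be non-null, the method body collapses to what used to be the concurrent branch; when no Executor is configured, the same invokeAll call simply runs every task inline. A hedged usage sketch (assuming TermStates.build takes the IndexSearcher, a Term, and a needsStats flag, as suggested by the indexSearcher.getTaskExecutor() call in this hunk):

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.IndexSearcher;

class TermStatesExample {
  // Illustrative only: build() always fans out through the searcher's task
  // executor, whether or not an Executor was provided at construction time.
  static TermStates buildStates(IndexReader reader) throws Exception {
    IndexSearcher searcher = new IndexSearcher(reader); // no executor given
    return TermStates.build(searcher, new Term("field", "value"), true);
  }
}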

AbstractKnnVectorQuery.java

@@ -80,10 +80,12 @@ abstract class AbstractKnnVectorQuery extends Query {
}
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
TopDocs[] perLeafResults =
(taskExecutor == null)
? sequentialSearch(reader.leaves(), filterWeight)
: parallelSearch(reader.leaves(), filterWeight, taskExecutor);
List<LeafReaderContext> leafReaderContexts = reader.leaves();
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
for (LeafReaderContext context : leafReaderContexts) {
tasks.add(() -> searchLeaf(context, filterWeight));
}
TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
// Merge sort the results
TopDocs topK = TopDocs.merge(k, perLeafResults);
@@ -93,25 +95,6 @@ abstract class AbstractKnnVectorQuery extends Query {
return createRewrittenQuery(reader, topK);
}
private TopDocs[] sequentialSearch(
List<LeafReaderContext> leafReaderContexts, Weight filterWeight) throws IOException {
TopDocs[] perLeafResults = new TopDocs[leafReaderContexts.size()];
for (LeafReaderContext ctx : leafReaderContexts) {
perLeafResults[ctx.ord] = searchLeaf(ctx, filterWeight);
}
return perLeafResults;
}
private TopDocs[] parallelSearch(
List<LeafReaderContext> leafReaderContexts, Weight filterWeight, TaskExecutor taskExecutor)
throws IOException {
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
for (LeafReaderContext context : leafReaderContexts) {
tasks.add(() -> searchLeaf(context, filterWeight));
}
return taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
}
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
TopDocs results = getLeafResults(ctx, filterWeight);
if (ctx.docBase > 0) {

IndexSearcher.java

@@ -119,10 +119,7 @@ public class IndexSearcher {
* method from constructor, which is a bad practice. This is {@code null} if no executor is
* provided
*/
private final CachingLeafSlicesSupplier leafSlicesSupplier;
// These are only used for multi-threaded search
private final Executor executor;
private final Supplier<LeafSlice[]> leafSlicesSupplier;
// Used internally for load balancing threads executing for the query
private final TaskExecutor taskExecutor;
@@ -227,12 +224,18 @@ public class IndexSearcher {
assert context.isTopLevel
: "IndexSearcher's ReaderContext must be topLevel for reader" + context.reader();
reader = context.reader();
this.executor = executor;
this.taskExecutor = executor == null ? null : new TaskExecutor(executor);
this.taskExecutor =
executor == null ? new TaskExecutor(Runnable::run) : new TaskExecutor(executor);
this.readerContext = context;
leafContexts = context.leaves();
leafSlicesSupplier =
(executor == null) ? null : new CachingLeafSlicesSupplier(this::slices, leafContexts);
Function<List<LeafReaderContext>, LeafSlice[]> slicesProvider =
executor == null
? leaves ->
leaves.size() == 0
? new LeafSlice[0]
: new LeafSlice[] {new LeafSlice(new ArrayList<>(leaves))}
: this::slices;
leafSlicesSupplier = new CachingLeafSlicesSupplier(slicesProvider, leafContexts);
}
/**
@@ -421,13 +424,12 @@ public class IndexSearcher {
}
/**
* Returns the leaf slices used for concurrent searching, or null if no {@code Executor} was
* passed to the constructor.
* Returns the leaf slices used for concurrent searching
*
* @lucene.experimental
*/
public LeafSlice[] getSlices() {
return (executor == null) ? null : leafSlicesSupplier.get();
return leafSlicesSupplier.get();
}
/**
@@ -457,12 +459,12 @@ public class IndexSearcher {
new CollectorManager<TopScoreDocCollector, TopDocs>() {
private final HitsThresholdChecker hitsThresholdChecker =
(leafSlices == null || leafSlices.length <= 1)
leafSlices.length <= 1
? HitsThresholdChecker.create(Math.max(TOTAL_HITS_THRESHOLD, numHits))
: HitsThresholdChecker.createShared(Math.max(TOTAL_HITS_THRESHOLD, numHits));
private final MaxScoreAccumulator minScoreAcc =
(leafSlices == null || leafSlices.length <= 1) ? null : new MaxScoreAccumulator();
leafSlices.length <= 1 ? null : new MaxScoreAccumulator();
@Override
public TopScoreDocCollector newCollector() throws IOException {
@@ -602,12 +604,12 @@ public class IndexSearcher {
new CollectorManager<>() {
private final HitsThresholdChecker hitsThresholdChecker =
(leafSlices == null || leafSlices.length <= 1)
leafSlices.length <= 1
? HitsThresholdChecker.create(Math.max(TOTAL_HITS_THRESHOLD, numHits))
: HitsThresholdChecker.createShared(Math.max(TOTAL_HITS_THRESHOLD, numHits));
private final MaxScoreAccumulator minScoreAcc =
(leafSlices == null || leafSlices.length <= 1) ? null : new MaxScoreAccumulator();
leafSlices.length <= 1 ? null : new MaxScoreAccumulator();
@Override
public TopFieldCollector newCollector() throws IOException {
@@ -653,8 +655,10 @@ public class IndexSearcher {
private <C extends Collector, T> T search(
Weight weight, CollectorManager<C, T> collectorManager, C firstCollector) throws IOException {
final LeafSlice[] leafSlices = getSlices();
if (leafSlices == null || leafSlices.length == 0) {
search(leafContexts, weight, firstCollector);
if (leafSlices.length == 0) {
// there are no segments, nothing to offload to the executor, but we do need to call reduce to
// create some kind of empty result
assert leafContexts.size() == 0;
return collectorManager.reduce(Collections.singletonList(firstCollector));
} else {
final List<C> collectors = new ArrayList<>(leafSlices.length);
@@ -893,13 +897,7 @@ public class IndexSearcher {
@Override
public String toString() {
return "IndexSearcher("
+ reader
+ "; executor="
+ executor
+ "; sliceExecutionControlPlane "
+ taskExecutor
+ ")";
return "IndexSearcher(" + reader + "; taskExecutor=" + taskExecutor + ")";
}
/**
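For consumers of IndexSearcher, the net effect is that the null check disappears. An illustrative sketch (the searchAll method and its task list are hypothetical, not code from this commit):

import java.util.List;
import java.util.concurrent.Callable;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TopDocs;

class ConsumerExample {
  // Before this commit, callers had to branch on a possibly-null
  // getTaskExecutor() result; now it always returns a TaskExecutor, so the
  // single concurrent path below covers both the sequential and parallel case.
  static List<TopDocs> searchAll(IndexSearcher searcher, List<Callable<TopDocs>> tasks)
      throws Exception {
    return searcher.getTaskExecutor().invokeAll(tasks);
  }
}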

TaskExecutor.java

@@ -110,4 +110,9 @@ public final class TaskExecutor {
}
}
}
@Override
public String toString() {
return "TaskExecutor(" + "executor=" + executor + ')';
}
}

TestIndexSearcher.java

@@ -42,7 +42,6 @@ import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.NamedThreadFactory;
import org.junit.Test;
public class TestIndexSearcher extends LuceneTestCase {
Directory dir;
@@ -115,7 +114,6 @@ public class TestIndexSearcher extends LuceneTestCase {
TestUtil.shutdownExecutorService(service);
}
@Test
public void testSearchAfterPassedMaxDoc() throws Exception {
// LUCENE-5128: ensure we get a meaningful message if searchAfter exceeds maxDoc
Directory dir = newDirectory();
@@ -221,30 +219,50 @@ public class TestIndexSearcher extends LuceneTestCase {
assertEquals(dummyPolicy, searcher.getQueryCachingPolicy());
}
public void testGetSlices() throws Exception {
assertNull(new IndexSearcher(new MultiReader()).getSlices());
public void testGetSlicesNoLeavesNoExecutor() throws IOException {
IndexSearcher.LeafSlice[] slices = new IndexSearcher(new MultiReader()).getSlices();
assertEquals(0, slices.length);
}
public void testGetSlicesNoLeavesWithExecutor() throws IOException {
IndexSearcher.LeafSlice[] slices =
new IndexSearcher(new MultiReader(), Runnable::run).getSlices();
assertEquals(0, slices.length);
}
public void testGetSlices() throws Exception {
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
w.addDocument(new Document());
for (int i = 0; i < 10; i++) {
w.addDocument(new Document());
// manually flush, so we get to create multiple segments almost all the times, as well as
// multiple slices
w.flush();
}
IndexReader r = w.getReader();
w.close();
ExecutorService service =
new ThreadPoolExecutor(
4,
4,
0L,
TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<Runnable>(),
new NamedThreadFactory("TestIndexSearcher"));
IndexSearcher s = new IndexSearcher(r, service);
IndexSearcher.LeafSlice[] slices = s.getSlices();
assertNotNull(slices);
assertEquals(1, slices.length);
assertEquals(1, slices[0].leaves.length);
assertTrue(slices[0].leaves[0] == r.leaves().get(0));
service.shutdown();
{
// without executor
IndexSearcher.LeafSlice[] slices = new IndexSearcher(r).getSlices();
assertEquals(1, slices.length);
assertEquals(r.leaves().size(), slices[0].leaves.length);
}
{
// force creation of multiple slices, and provide an executor
IndexSearcher searcher =
new IndexSearcher(r, Runnable::run) {
@Override
protected LeafSlice[] slices(List<LeafReaderContext> leaves) {
return slices(leaves, 1, 1);
}
};
IndexSearcher.LeafSlice[] slices = searcher.getSlices();
for (IndexSearcher.LeafSlice slice : slices) {
assertEquals(1, slice.leaves.length);
}
assertEquals(r.leaves().size(), slices.length);
}
IOUtils.close(r, dir);
}
@@ -270,4 +288,9 @@ public class TestIndexSearcher extends LuceneTestCase {
searcher.search(new MatchAllDocsQuery(), 10);
assertEquals(leaves.size(), numExecutions.get());
}
public void testNullExecutorNonNullTaskExecutor() {
IndexSearcher indexSearcher = new IndexSearcher(reader);
assertNotNull(indexSearcher.getTaskExecutor());
}
}