Prevent concurrent tasks from parallelizing further (#12569)

Concurrent execution is currently applied at most once per search call: either when
the search itself runs across slices, or when a query is rewritten concurrently. These
generally don't nest within one another. We are, however, about to introduce
parallelism in places where multiple inner levels of parallelism could be requested,
because each task may try to parallelize further. In those cases, with certain executor
implementations such as ThreadPoolExecutor, we may deadlock: the caller waits for all
tasks to complete while the tasks themselves wait for threads to free up so they can run.
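
As an illustration of the failure mode (not part of this change), here is a minimal
standalone Java sketch: a single-threaded pool runs an outer task that submits an inner
task to the same pool and blocks on it, so the inner task can never get a thread. All
names are hypothetical.

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class NestedSubmitDeadlockSketch {
  public static void main(String[] args) throws Exception {
    ExecutorService pool = Executors.newFixedThreadPool(1);
    Future<String> outer =
        pool.submit(
            () -> {
              // the only pool thread is busy running this outer task...
              Future<String> inner = pool.submit(() -> "inner result");
              // ...so waiting for the inner task here blocks forever
              return inner.get();
            });
    try {
      // without a top-level-only safeguard, this times out instead of completing
      System.out.println(outer.get(2, TimeUnit.SECONDS));
    } catch (TimeoutException e) {
      System.out.println("deadlocked: the inner task never got a thread");
    } finally {
      pool.shutdownNow();
    }
  }
}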

This commit introduces a simple safeguard that ensures we only parallelize via the
executor at the top-level invokeAll call. When a task attempts to parallelize further,
its sub-tasks are executed directly on the calling thread instead of being submitted
to the executor.
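
For readers skimming just the commit message, the following is a simplified,
self-contained sketch of that safeguard pattern; the class and field names are
illustrative only, and the actual implementation is the TaskExecutor change in the
diff below. A thread-local counter records whether the calling thread is already
running one of our tasks; if it is, a nested invokeAll runs its tasks inline instead
of handing them back to the executor.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.Executor;
import java.util.concurrent.FutureTask;

final class NestedAwareInvoker {
  // counts how many of our tasks the current thread is executing right now
  private static final ThreadLocal<Integer> RUNNING = ThreadLocal.withInitial(() -> 0);

  private final Executor executor;

  NestedAwareInvoker(Executor executor) {
    this.executor = executor;
  }

  <T> List<T> invokeAll(List<Callable<T>> callables) throws Exception {
    List<FutureTask<T>> tasks = new ArrayList<>();
    for (Callable<T> callable : callables) {
      tasks.add(
          new FutureTask<>(
              () -> {
                RUNNING.set(RUNNING.get() + 1); // mark this thread as running a task
                try {
                  return callable.call();
                } finally {
                  RUNNING.set(RUNNING.get() - 1);
                }
              }));
    }
    boolean nested = RUNNING.get() > 0; // are we already inside one of our tasks?
    for (FutureTask<T> task : tasks) {
      if (nested) {
        task.run();             // nested call: execute directly on the caller thread
      } else {
        executor.execute(task); // top-level call: fan out to the executor
      }
    }
    List<T> results = new ArrayList<>();
    for (FutureTask<T> task : tasks) {
      results.add(task.get()); // blocks until each task has completed
    }
    return results;
  }
}

A counter (rather than a boolean) is used so that the guard keeps working when several
searchers backed by the same executor are involved, as noted in the comment in the
TaskExecutor diff below.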

Co-authored-by: Adrien Grand <jpountz@gmail.com>
Luca Cavanna 2023-09-20 12:00:13 +02:00 committed by GitHub
parent 51ade888f3
commit 937ebd4296
5 changed files with 191 additions and 24 deletions

CHANGES.txt

@@ -239,6 +239,9 @@ Changes in runtime behavior
* GITHUB#12515: Offload sequential search execution to the executor that's optionally provided to the IndexSearcher
(Luca Cavanna)
* GITHUB#12569: Prevent concurrent tasks from parallelizing execution further which could cause deadlock
(Luca Cavanna)
Bug Fixes
---------------------

AbstractKnnVectorQuery.java

@@ -24,8 +24,6 @@ import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.FutureTask;
import java.util.concurrent.RunnableFuture;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
@@ -106,9 +104,9 @@ abstract class AbstractKnnVectorQuery extends Query {
private TopDocs[] parallelSearch(
List<LeafReaderContext> leafReaderContexts, Weight filterWeight, TaskExecutor taskExecutor)
throws IOException {
List<RunnableFuture<TopDocs>> tasks = new ArrayList<>();
List<TaskExecutor.Task<TopDocs>> tasks = new ArrayList<>();
for (LeafReaderContext context : leafReaderContexts) {
tasks.add(new FutureTask<>(() -> searchLeaf(context, filterWeight)));
tasks.add(taskExecutor.createTask(() -> searchLeaf(context, filterWeight)));
}
return taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
}

IndexSearcher.java

@@ -25,8 +25,6 @@ import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.concurrent.FutureTask;
import java.util.concurrent.RunnableFuture;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.lucene.index.DirectoryReader;
@@ -669,17 +667,16 @@ public class IndexSearcher {
"CollectorManager does not always produce collectors with the same score mode");
}
}
final List<RunnableFuture<C>> listTasks = new ArrayList<>();
final List<TaskExecutor.Task<C>> listTasks = new ArrayList<>();
for (int i = 0; i < leafSlices.length; ++i) {
final LeafReaderContext[] leaves = leafSlices[i].leaves;
final C collector = collectors.get(i);
FutureTask<C> task =
new FutureTask<>(
TaskExecutor.Task<C> task =
taskExecutor.createTask(
() -> {
search(Arrays.asList(leaves), weight, collector);
return collector;
});
listTasks.add(task);
}
List<C> results = taskExecutor.invokeAll(listTasks);

TaskExecutor.java

@@ -22,18 +22,31 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.FutureTask;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.ThreadInterruptedException;
/**
* Executor wrapper responsible for the execution of concurrent tasks. Used to parallelize search
* across segments as well as query rewrite in some cases.
* across segments as well as query rewrite in some cases. Exposes a {@link #createTask(Callable)}
* method to create tasks given a {@link Callable}, as well as the {@link #invokeAll(Collection)}
* method to execute a set of tasks concurrently. Once all tasks are submitted to the executor, it
blocks and waits for all tasks to be completed, and then returns a list with the obtained results.
* Ensures that the underlying executor is only used for top-level {@link #invokeAll(Collection)}
* calls, and not for potential {@link #invokeAll(Collection)} calls made from one of the tasks.
* This is to prevent deadlock with certain types of pool based executors (e.g. {@link
* java.util.concurrent.ThreadPoolExecutor}).
*/
class TaskExecutor {
// a static thread local is ok as long as we use a counter, which accounts for multiple
// searchers holding a different TaskExecutor all backed by the same executor
private static final ThreadLocal<Integer> numberOfRunningTasksInCurrentThread =
ThreadLocal.withInitial(() -> 0);
private final Executor executor;
TaskExecutor(Executor executor) {
@@ -48,10 +61,17 @@ class TaskExecutor {
* @return a list containing the results from the tasks execution
* @param <T> the return type of the task execution
*/
final <T> List<T> invokeAll(Collection<RunnableFuture<T>> tasks) throws IOException {
for (Runnable task : tasks) {
executor.execute(task);
final <T> List<T> invokeAll(Collection<Task<T>> tasks) throws IOException {
if (numberOfRunningTasksInCurrentThread.get() > 0) {
for (Task<T> task : tasks) {
task.run();
}
} else {
for (Runnable task : tasks) {
executor.execute(task);
}
}
final List<T> results = new ArrayList<>();
for (Future<T> future : tasks) {
try {
@@ -64,4 +84,26 @@
}
return results;
}
final <C> Task<C> createTask(Callable<C> callable) {
return new Task<>(callable);
}
static class Task<V> extends FutureTask<V> {
private Task(Callable<V> callable) {
super(callable);
}
@Override
public void run() {
try {
Integer counter = numberOfRunningTasksInCurrentThread.get();
numberOfRunningTasksInCurrentThread.set(counter + 1);
super.run();
} finally {
Integer counter = numberOfRunningTasksInCurrentThread.get();
numberOfRunningTasksInCurrentThread.set(counter - 1);
}
}
}
}

TestTaskExecutor.java

@@ -17,10 +17,16 @@
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.NamedThreadFactory;
import org.junit.AfterClass;
@@ -44,8 +50,8 @@ public class TestTaskExecutor extends LuceneTestCase {
public void testUnwrapIOExceptionFromExecutionException() {
TaskExecutor taskExecutor = new TaskExecutor(executorService);
FutureTask<?> task =
new FutureTask<>(
TaskExecutor.Task<?> task =
taskExecutor.createTask(
() -> {
throw new IOException("io exception");
});
@@ -57,8 +63,8 @@ public class TestTaskExecutor extends LuceneTestCase {
public void testUnwrapRuntimeExceptionFromExecutionException() {
TaskExecutor taskExecutor = new TaskExecutor(executorService);
FutureTask<?> task =
new FutureTask<>(
TaskExecutor.Task<?> task =
taskExecutor.createTask(
() -> {
throw new RuntimeException("runtime");
});
@@ -71,8 +77,8 @@ public class TestTaskExecutor extends LuceneTestCase {
public void testUnwrapErrorFromExecutionException() {
TaskExecutor taskExecutor = new TaskExecutor(executorService);
FutureTask<?> task =
new FutureTask<>(
TaskExecutor.Task<?> task =
taskExecutor.createTask(
() -> {
throw new OutOfMemoryError("oom");
});
@@ -85,8 +91,8 @@ public class TestTaskExecutor extends LuceneTestCase {
public void testUnwrappedExceptions() {
TaskExecutor taskExecutor = new TaskExecutor(executorService);
FutureTask<?> task =
new FutureTask<>(
TaskExecutor.Task<?> task =
taskExecutor.createTask(
() -> {
throw new Exception("exc");
});
@@ -95,4 +101,125 @@ public class TestTaskExecutor extends LuceneTestCase {
RuntimeException.class, () -> taskExecutor.invokeAll(Collections.singletonList(task)));
assertEquals("exc", runtimeException.getCause().getMessage());
}
public void testInvokeAllFromTaskDoesNotDeadlockSameSearcher() throws IOException {
try (Directory dir = newDirectory();
RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) {
for (int i = 0; i < 500; i++) {
iw.addDocument(new Document());
}
try (DirectoryReader reader = iw.getReader()) {
IndexSearcher searcher =
new IndexSearcher(reader, executorService) {
@Override
protected LeafSlice[] slices(List<LeafReaderContext> leaves) {
return slices(leaves, 1, 1);
}
};
searcher.search(
new MatchAllDocsQuery(),
new CollectorManager<Collector, Void>() {
@Override
public Collector newCollector() {
return new Collector() {
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) {
return new LeafCollector() {
@Override
public void setScorer(Scorable scorer) throws IOException {
TaskExecutor.Task<Void> task =
searcher
.getTaskExecutor()
.createTask(
() -> {
// make sure that we don't miss disabling concurrency one
// level deeper
TaskExecutor.Task<Object> anotherTask =
searcher.getTaskExecutor().createTask(() -> null);
searcher
.getTaskExecutor()
.invokeAll(Collections.singletonList(anotherTask));
return null;
});
searcher.getTaskExecutor().invokeAll(Collections.singletonList(task));
}
@Override
public void collect(int doc) {}
};
}
@Override
public ScoreMode scoreMode() {
return ScoreMode.COMPLETE;
}
};
}
@Override
public Void reduce(Collection<Collector> collectors) {
return null;
}
});
}
}
}
public void testInvokeAllFromTaskDoesNotDeadlockMultipleSearchers() throws IOException {
try (Directory dir = newDirectory();
RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) {
for (int i = 0; i < 500; i++) {
iw.addDocument(new Document());
}
try (DirectoryReader reader = iw.getReader()) {
IndexSearcher searcher =
new IndexSearcher(reader, executorService) {
@Override
protected LeafSlice[] slices(List<LeafReaderContext> leaves) {
return slices(leaves, 1, 1);
}
};
searcher.search(
new MatchAllDocsQuery(),
new CollectorManager<Collector, Void>() {
@Override
public Collector newCollector() {
return new Collector() {
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) {
return new LeafCollector() {
@Override
public void setScorer(Scorable scorer) throws IOException {
// the thread local used to prevent deadlock is static, so while each
// searcher has its own
// TaskExecutor, the safeguard is shared among all the searchers that get
// the same executor
IndexSearcher indexSearcher = new IndexSearcher(reader, executorService);
TaskExecutor.Task<Void> task =
indexSearcher.getTaskExecutor().createTask(() -> null);
searcher.getTaskExecutor().invokeAll(Collections.singletonList(task));
}
@Override
public void collect(int doc) {}
};
}
@Override
public ScoreMode scoreMode() {
return ScoreMode.COMPLETE;
}
};
}
@Override
public Void reduce(Collection<Collector> collectors) {
return null;
}
});
}
}
}
}