TaskExecutor to cancel all tasks on exception (#12689)

When operations are parallelized, like query rewrite, or search, or
createWeight, one of the tasks may throw an exception. In that case we
wait for all tasks to be completed before re-throwing the exception that
was caught. However, tasks that have not yet started when the exception
is caught can be safely skipped. Ideally we would also cancel ongoing
tasks, but I left that for another time.
This commit is contained in:
Luca Cavanna 2023-10-24 17:38:25 +02:00 committed by GitHub
parent 71c4ea74ba
commit 1200ecce3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 209 additions and 80 deletions

View File

@ -20,6 +20,7 @@ package org.apache.lucene.search;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
@ -27,6 +28,8 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.concurrent.FutureTask; import java.util.concurrent.FutureTask;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.ThreadInterruptedException; import org.apache.lucene.util.ThreadInterruptedException;
@ -64,21 +67,90 @@ public final class TaskExecutor {
* @param <T> the return type of the task execution * @param <T> the return type of the task execution
*/ */
public <T> List<T> invokeAll(Collection<Callable<T>> callables) throws IOException { public <T> List<T> invokeAll(Collection<Callable<T>> callables) throws IOException {
List<Task<T>> tasks = new ArrayList<>(callables.size()); TaskGroup<T> taskGroup = new TaskGroup<>(callables);
boolean runOnCallerThread = numberOfRunningTasksInCurrentThread.get() > 0; return taskGroup.invokeAll(executor);
for (Callable<T> callable : callables) {
Task<T> task = new Task<>(callable);
tasks.add(task);
if (runOnCallerThread) {
task.run();
} else {
executor.execute(task);
}
} }
@Override
public String toString() {
  // Renders as TaskExecutor(executor=<wrapped executor's own toString>)
  final StringBuilder description = new StringBuilder("TaskExecutor(");
  description.append("executor=").append(executor).append(')');
  return description.toString();
}
/**
* Holds all the sub-tasks that a certain operation gets split into as it gets parallelized and
* exposes the ability to invoke such tasks and wait for them all to complete their execution and
* provide their results. Ensures that each task does not get parallelized further: this is
* important to avoid a deadlock in situations where one executor thread waits on other executor
* threads to complete before it can progress. This happens in situations where for instance
* {@link Query#createWeight(IndexSearcher, ScoreMode, float)} is called as part of searching each
* slice, like {@link TopFieldCollector#populateScores(ScoreDoc[], IndexSearcher, Query)} does.
* Additionally, if one task throws an exception, all other tasks from the same group that have
* not started yet are cancelled, to avoid needless computation as their results would not be
* exposed anyway. Tasks that are already running are left to complete. Creates one {@link
* FutureTask} for each {@link Callable} provided.
*
* @param <T> the return type of all the callables
*/
private static final class TaskGroup<T> {
private final Collection<RunnableFuture<T>> futures;
TaskGroup(Collection<Callable<T>> callables) {
List<RunnableFuture<T>> tasks = new ArrayList<>(callables.size());
for (Callable<T> callable : callables) {
tasks.add(createTask(callable));
}
this.futures = Collections.unmodifiableCollection(tasks);
}
/**
 * Wraps the given callable into a {@link FutureTask} that runs it at most once, and only if the
 * task has not been cancelled beforehand. While the callable runs, the per-thread counter of
 * running tasks is incremented, which is what {@code invokeAll} checks to decide whether tasks
 * are executed on the calling thread. If the callable throws, {@code cancelAll()} is invoked
 * before the throwable is re-thrown.
 *
 * @param callable the computation to wrap
 * @return a future that yields the callable's result, or {@code null} if the task was cancelled
 *     before it started
 */
RunnableFuture<T> createTask(Callable<T> callable) {
// false: neither started nor cancelled yet; true: started or cancelled. The two outcomes are
// deliberately not told apart: whoever flips the flag first (run or cancel) wins.
AtomicBoolean startedOrCancelled = new AtomicBoolean(false);
return new FutureTask<>(
() -> {
// only run the callable if cancel has not flipped the flag already
if (startedOrCancelled.compareAndSet(false, true)) {
try {
Integer counter = numberOfRunningTasksInCurrentThread.get();
numberOfRunningTasksInCurrentThread.set(counter + 1);
return callable.call();
} catch (Throwable t) {
// a failure makes the results of the other tasks in the group irrelevant
cancelAll();
throw t;
} finally {
// always undo the increment, whether the callable returned or threw
Integer counter = numberOfRunningTasksInCurrentThread.get();
numberOfRunningTasksInCurrentThread.set(counter - 1);
}
}
// task is cancelled hence it has no results to return. That's fine: they would be
// ignored anyway.
return null;
}) {
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
assert mayInterruptIfRunning == false
: "cancelling tasks that are running is not supported";
/*
Future#get (called in invokeAll) throws CancellationException when invoked against a running task that has been cancelled but
leaves the task running. We rather want to make sure that invokeAll does not leave any running tasks behind when it returns.
Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will
wait for them to finish instead of throwing CancellationException. A cleaner way would have been to override FutureTask#get and
make it wait for cancelled tasks, but FutureTask#awaitDone is private. Tasks that are cancelled before they are started will be no-op.
Note that super.cancel is intentionally not called: cancellation is recorded only through the flag below.
*/
// returns true only when this call is the one that prevented the task from starting
return startedOrCancelled.compareAndSet(false, true);
}
};
}
List<T> invokeAll(Executor executor) throws IOException {
boolean runOnCallerThread = numberOfRunningTasksInCurrentThread.get() > 0;
for (Runnable runnable : futures) {
if (runOnCallerThread) {
runnable.run();
} else {
executor.execute(runnable);
}
}
Throwable exc = null; Throwable exc = null;
final List<T> results = new ArrayList<>(); List<T> results = new ArrayList<>(futures.size());
for (Future<T> future : tasks) { for (Future<T> future : futures) {
try { try {
results.add(future.get()); results.add(future.get());
} catch (InterruptedException e) { } catch (InterruptedException e) {
@ -96,38 +168,26 @@ public final class TaskExecutor {
} }
} }
} }
assert assertAllFuturesCompleted() : "Some tasks are still running?";
if (exc != null) { if (exc != null) {
throw IOUtils.rethrowAlways(exc); throw IOUtils.rethrowAlways(exc);
} }
return results; return results;
} }
/** private boolean assertAllFuturesCompleted() {
* Extension of {@link FutureTask} that tracks the number of tasks that are running in each for (RunnableFuture<T> future : futures) {
* thread. if (future.isDone() == false) {
* return false;
* @param <V> the return type of the task }
*/ }
private static final class Task<V> extends FutureTask<V> { return true;
private Task(Callable<V> callable) {
super(callable);
} }
@Override private void cancelAll() {
public void run() { for (Future<T> future : futures) {
try { future.cancel(false);
Integer counter = numberOfRunningTasksInCurrentThread.get();
numberOfRunningTasksInCurrentThread.set(counter + 1);
super.run();
} finally {
Integer counter = numberOfRunningTasksInCurrentThread.get();
numberOfRunningTasksInCurrentThread.set(counter - 1);
} }
} }
} }
@Override
public String toString() {
return "TaskExecutor(" + "executor=" + executor + ')';
}
} }

View File

@ -22,8 +22,10 @@ import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DirectoryReader;
@ -32,6 +34,8 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.NamedThreadFactory; import org.apache.lucene.util.NamedThreadFactory;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.AfterClass; import org.junit.AfterClass;
import org.junit.BeforeClass; import org.junit.BeforeClass;
@ -43,7 +47,8 @@ public class TestTaskExecutor extends LuceneTestCase {
public static void createExecutor() { public static void createExecutor() {
executorService = executorService =
Executors.newFixedThreadPool( Executors.newFixedThreadPool(
1, new NamedThreadFactory(TestTaskExecutor.class.getSimpleName())); random().nextBoolean() ? 1 : 2,
new NamedThreadFactory(TestTaskExecutor.class.getSimpleName()));
} }
@AfterClass @AfterClass
@ -228,11 +233,21 @@ public class TestTaskExecutor extends LuceneTestCase {
} }
public void testInvokeAllDoesNotLeaveTasksBehind() { public void testInvokeAllDoesNotLeaveTasksBehind() {
TaskExecutor taskExecutor = new TaskExecutor(executorService); AtomicInteger tasksStarted = new AtomicInteger(0);
TaskExecutor taskExecutor =
new TaskExecutor(
command -> {
executorService.execute(
() -> {
tasksStarted.incrementAndGet();
command.run();
});
});
AtomicInteger tasksExecuted = new AtomicInteger(0); AtomicInteger tasksExecuted = new AtomicInteger(0);
List<Callable<Void>> callables = new ArrayList<>(); List<Callable<Void>> callables = new ArrayList<>();
callables.add( callables.add(
() -> { () -> {
tasksExecuted.incrementAndGet();
throw new RuntimeException(); throw new RuntimeException();
}); });
int tasksWithNormalExit = 99; int tasksWithNormalExit = 99;
@ -244,7 +259,14 @@ public class TestTaskExecutor extends LuceneTestCase {
}); });
} }
expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(callables)); expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(callables));
assertEquals(tasksWithNormalExit, tasksExecuted.get()); int maximumPoolSize = ((ThreadPoolExecutor) executorService).getMaximumPoolSize();
if (maximumPoolSize == 1) {
assertEquals(1, tasksExecuted.get());
} else {
MatcherAssert.assertThat(tasksExecuted.get(), Matchers.greaterThanOrEqualTo(1));
}
// the callables are technically all run, but the cancelled ones will be no-op
assertEquals(100, tasksStarted.get());
} }
/** /**
@ -253,28 +275,36 @@ public class TestTaskExecutor extends LuceneTestCase {
*/ */
public void testInvokeAllCatchesMultipleExceptions() { public void testInvokeAllCatchesMultipleExceptions() {
TaskExecutor taskExecutor = new TaskExecutor(executorService); TaskExecutor taskExecutor = new TaskExecutor(executorService);
AtomicInteger tasksExecuted = new AtomicInteger(0);
List<Callable<Void>> callables = new ArrayList<>(); List<Callable<Void>> callables = new ArrayList<>();
int maximumPoolSize = ((ThreadPoolExecutor) executorService).getMaximumPoolSize();
// if we have multiple threads, make sure both are started before an exception is thrown,
// otherwise there may or may not be a suppressed exception
CountDownLatch latchA = new CountDownLatch(1);
CountDownLatch latchB = new CountDownLatch(1);
callables.add( callables.add(
() -> { () -> {
if (maximumPoolSize > 1) {
latchA.countDown();
latchB.await();
}
throw new RuntimeException("exception A"); throw new RuntimeException("exception A");
}); });
int tasksWithNormalExit = 50;
for (int i = 0; i < tasksWithNormalExit; i++) {
callables.add( callables.add(
() -> { () -> {
tasksExecuted.incrementAndGet(); if (maximumPoolSize > 1) {
return null; latchB.countDown();
}); latchA.await();
} }
callables.add(
() -> {
throw new IllegalStateException("exception B"); throw new IllegalStateException("exception B");
}); });
RuntimeException exc = RuntimeException exc =
expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(callables)); expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(callables));
Throwable[] suppressed = exc.getSuppressed(); Throwable[] suppressed = exc.getSuppressed();
if (maximumPoolSize == 1) {
assertEquals(0, suppressed.length);
} else {
assertEquals(1, suppressed.length); assertEquals(1, suppressed.length);
if (exc.getMessage().equals("exception A")) { if (exc.getMessage().equals("exception A")) {
assertEquals("exception B", suppressed[0].getMessage()); assertEquals("exception B", suppressed[0].getMessage());
@ -282,7 +312,46 @@ public class TestTaskExecutor extends LuceneTestCase {
assertEquals("exception A", suppressed[0].getMessage()); assertEquals("exception A", suppressed[0].getMessage());
assertEquals("exception B", exc.getMessage()); assertEquals("exception B", exc.getMessage());
} }
}
}
assertEquals(tasksWithNormalExit, tasksExecuted.get()); public void testCancelTasksOnException() {
TaskExecutor taskExecutor = new TaskExecutor(executorService);
int maximumPoolSize = ((ThreadPoolExecutor) executorService).getMaximumPoolSize();
final int numTasks = random().nextInt(10, 50);
final int throwingTask = random().nextInt(numTasks);
boolean error = random().nextBoolean();
List<Callable<Void>> tasks = new ArrayList<>(numTasks);
AtomicInteger executedTasks = new AtomicInteger(0);
for (int i = 0; i < numTasks; i++) {
final int index = i;
tasks.add(
() -> {
if (index == throwingTask) {
if (error) {
throw new OutOfMemoryError();
} else {
throw new RuntimeException();
}
}
if (index > throwingTask && maximumPoolSize == 1) {
throw new AssertionError("task should not have started");
}
executedTasks.incrementAndGet();
return null;
});
}
Throwable throwable;
if (error) {
throwable = expectThrows(OutOfMemoryError.class, () -> taskExecutor.invokeAll(tasks));
} else {
throwable = expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(tasks));
}
assertEquals(0, throwable.getSuppressed().length);
if (maximumPoolSize == 1) {
assertEquals(throwingTask, executedTasks.get());
} else {
MatcherAssert.assertThat(executedTasks.get(), Matchers.greaterThanOrEqualTo(throwingTask));
}
} }
} }