mirror of https://github.com/apache/lucene.git
TaskExecutor to cancel all tasks on exception (#12689)
When operations are parallelized, like query rewrite, or search, or createWeight, one of the tasks may throw an exception. In that case we wait for all tasks to be completed before re-throwing the exception that were caught. Tasks that were not started when the exception is captured though can be safely skipped. Ideally we would also cancel ongoing tasks but I left that for another time.
This commit is contained in:
parent
71c4ea74ba
commit
1200ecce3a
|
@ -20,6 +20,7 @@ package org.apache.lucene.search;
|
|||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.Callable;
|
||||
|
@ -27,6 +28,8 @@ import java.util.concurrent.ExecutionException;
|
|||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.FutureTask;
|
||||
import java.util.concurrent.RunnableFuture;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.ThreadInterruptedException;
|
||||
|
||||
|
@ -64,21 +67,90 @@ public final class TaskExecutor {
|
|||
* @param <T> the return type of the task execution
|
||||
*/
|
||||
public <T> List<T> invokeAll(Collection<Callable<T>> callables) throws IOException {
|
||||
List<Task<T>> tasks = new ArrayList<>(callables.size());
|
||||
boolean runOnCallerThread = numberOfRunningTasksInCurrentThread.get() > 0;
|
||||
for (Callable<T> callable : callables) {
|
||||
Task<T> task = new Task<>(callable);
|
||||
tasks.add(task);
|
||||
if (runOnCallerThread) {
|
||||
task.run();
|
||||
} else {
|
||||
executor.execute(task);
|
||||
}
|
||||
TaskGroup<T> taskGroup = new TaskGroup<>(callables);
|
||||
return taskGroup.invokeAll(executor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "TaskExecutor(" + "executor=" + executor + ')';
|
||||
}
|
||||
|
||||
/**
|
||||
* Holds all the sub-tasks that a certain operation gets split into as it gets parallelized and
|
||||
* exposes the ability to invoke such tasks and wait for them all to complete their execution and
|
||||
* provide their results. Ensures that each task does not get parallelized further: this is
|
||||
* important to avoid a deadlock in situations where one executor thread waits on other executor
|
||||
* threads to complete before it can progress. This happens in situations where for instance
|
||||
* {@link Query#createWeight(IndexSearcher, ScoreMode, float)} is called as part of searching each
|
||||
* slice, like {@link TopFieldCollector#populateScores(ScoreDoc[], IndexSearcher, Query)} does.
|
||||
* Additionally, if one task throws an exception, all other tasks from the same group are
|
||||
* cancelled, to avoid needless computation as their results would not be exposed anyways. Creates
|
||||
* one {@link FutureTask} for each {@link Callable} provided
|
||||
*
|
||||
* @param <T> the return type of all the callables
|
||||
*/
|
||||
private static final class TaskGroup<T> {
|
||||
private final Collection<RunnableFuture<T>> futures;
|
||||
|
||||
TaskGroup(Collection<Callable<T>> callables) {
|
||||
List<RunnableFuture<T>> tasks = new ArrayList<>(callables.size());
|
||||
for (Callable<T> callable : callables) {
|
||||
tasks.add(createTask(callable));
|
||||
}
|
||||
this.futures = Collections.unmodifiableCollection(tasks);
|
||||
}
|
||||
|
||||
RunnableFuture<T> createTask(Callable<T> callable) {
|
||||
// -1: cancelled; 0: not yet started; 1: started
|
||||
AtomicBoolean startedOrCancelled = new AtomicBoolean(false);
|
||||
return new FutureTask<>(
|
||||
() -> {
|
||||
if (startedOrCancelled.compareAndSet(false, true)) {
|
||||
try {
|
||||
Integer counter = numberOfRunningTasksInCurrentThread.get();
|
||||
numberOfRunningTasksInCurrentThread.set(counter + 1);
|
||||
return callable.call();
|
||||
} catch (Throwable t) {
|
||||
cancelAll();
|
||||
throw t;
|
||||
} finally {
|
||||
Integer counter = numberOfRunningTasksInCurrentThread.get();
|
||||
numberOfRunningTasksInCurrentThread.set(counter - 1);
|
||||
}
|
||||
}
|
||||
// task is cancelled hence it has no results to return. That's fine: they would be
|
||||
// ignored anyway.
|
||||
return null;
|
||||
}) {
|
||||
@Override
|
||||
public boolean cancel(boolean mayInterruptIfRunning) {
|
||||
assert mayInterruptIfRunning == false
|
||||
: "cancelling tasks that are running is not supported";
|
||||
/*
|
||||
Future#get (called in invokeAll) throws CancellationException when invoked against a running task that has been cancelled but
|
||||
leaves the task running. We rather want to make sure that invokeAll does not leave any running tasks behind when it returns.
|
||||
Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will
|
||||
wait for them to finish instead of throwing CancellationException. A cleaner way would have been to override FutureTask#get and
|
||||
make it wait for cancelled tasks, but FutureTask#awaitDone is private. Tasks that are cancelled before they are started will be no-op.
|
||||
*/
|
||||
return startedOrCancelled.compareAndSet(false, true);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
List<T> invokeAll(Executor executor) throws IOException {
|
||||
boolean runOnCallerThread = numberOfRunningTasksInCurrentThread.get() > 0;
|
||||
for (Runnable runnable : futures) {
|
||||
if (runOnCallerThread) {
|
||||
runnable.run();
|
||||
} else {
|
||||
executor.execute(runnable);
|
||||
}
|
||||
}
|
||||
Throwable exc = null;
|
||||
final List<T> results = new ArrayList<>();
|
||||
for (Future<T> future : tasks) {
|
||||
List<T> results = new ArrayList<>(futures.size());
|
||||
for (Future<T> future : futures) {
|
||||
try {
|
||||
results.add(future.get());
|
||||
} catch (InterruptedException e) {
|
||||
|
@ -96,38 +168,26 @@ public final class TaskExecutor {
|
|||
}
|
||||
}
|
||||
}
|
||||
assert assertAllFuturesCompleted() : "Some tasks are still running?";
|
||||
if (exc != null) {
|
||||
throw IOUtils.rethrowAlways(exc);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extension of {@link FutureTask} that tracks the number of tasks that are running in each
|
||||
* thread.
|
||||
*
|
||||
* @param <V> the return type of the task
|
||||
*/
|
||||
private static final class Task<V> extends FutureTask<V> {
|
||||
private Task(Callable<V> callable) {
|
||||
super(callable);
|
||||
private boolean assertAllFuturesCompleted() {
|
||||
for (RunnableFuture<T> future : futures) {
|
||||
if (future.isDone() == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
Integer counter = numberOfRunningTasksInCurrentThread.get();
|
||||
numberOfRunningTasksInCurrentThread.set(counter + 1);
|
||||
super.run();
|
||||
} finally {
|
||||
Integer counter = numberOfRunningTasksInCurrentThread.get();
|
||||
numberOfRunningTasksInCurrentThread.set(counter - 1);
|
||||
private void cancelAll() {
|
||||
for (Future<T> future : futures) {
|
||||
future.cancel(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "TaskExecutor(" + "executor=" + executor + ')';
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,8 +22,10 @@ import java.util.Collection;
|
|||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.ThreadPoolExecutor;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
|
@ -32,6 +34,8 @@ import org.apache.lucene.store.Directory;
|
|||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.NamedThreadFactory;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
import org.hamcrest.Matchers;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
|
||||
|
@ -43,7 +47,8 @@ public class TestTaskExecutor extends LuceneTestCase {
|
|||
public static void createExecutor() {
|
||||
executorService =
|
||||
Executors.newFixedThreadPool(
|
||||
1, new NamedThreadFactory(TestTaskExecutor.class.getSimpleName()));
|
||||
random().nextBoolean() ? 1 : 2,
|
||||
new NamedThreadFactory(TestTaskExecutor.class.getSimpleName()));
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
|
@ -228,11 +233,21 @@ public class TestTaskExecutor extends LuceneTestCase {
|
|||
}
|
||||
|
||||
public void testInvokeAllDoesNotLeaveTasksBehind() {
|
||||
TaskExecutor taskExecutor = new TaskExecutor(executorService);
|
||||
AtomicInteger tasksStarted = new AtomicInteger(0);
|
||||
TaskExecutor taskExecutor =
|
||||
new TaskExecutor(
|
||||
command -> {
|
||||
executorService.execute(
|
||||
() -> {
|
||||
tasksStarted.incrementAndGet();
|
||||
command.run();
|
||||
});
|
||||
});
|
||||
AtomicInteger tasksExecuted = new AtomicInteger(0);
|
||||
List<Callable<Void>> callables = new ArrayList<>();
|
||||
callables.add(
|
||||
() -> {
|
||||
tasksExecuted.incrementAndGet();
|
||||
throw new RuntimeException();
|
||||
});
|
||||
int tasksWithNormalExit = 99;
|
||||
|
@ -244,7 +259,14 @@ public class TestTaskExecutor extends LuceneTestCase {
|
|||
});
|
||||
}
|
||||
expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(callables));
|
||||
assertEquals(tasksWithNormalExit, tasksExecuted.get());
|
||||
int maximumPoolSize = ((ThreadPoolExecutor) executorService).getMaximumPoolSize();
|
||||
if (maximumPoolSize == 1) {
|
||||
assertEquals(1, tasksExecuted.get());
|
||||
} else {
|
||||
MatcherAssert.assertThat(tasksExecuted.get(), Matchers.greaterThanOrEqualTo(1));
|
||||
}
|
||||
// the callables are technically all run, but the cancelled ones will be no-op
|
||||
assertEquals(100, tasksStarted.get());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -253,28 +275,36 @@ public class TestTaskExecutor extends LuceneTestCase {
|
|||
*/
|
||||
public void testInvokeAllCatchesMultipleExceptions() {
|
||||
TaskExecutor taskExecutor = new TaskExecutor(executorService);
|
||||
AtomicInteger tasksExecuted = new AtomicInteger(0);
|
||||
List<Callable<Void>> callables = new ArrayList<>();
|
||||
int maximumPoolSize = ((ThreadPoolExecutor) executorService).getMaximumPoolSize();
|
||||
// if we have multiple threads, make sure both are started before an exception is thrown,
|
||||
// otherwise there may or may not be a suppressed exception
|
||||
CountDownLatch latchA = new CountDownLatch(1);
|
||||
CountDownLatch latchB = new CountDownLatch(1);
|
||||
callables.add(
|
||||
() -> {
|
||||
if (maximumPoolSize > 1) {
|
||||
latchA.countDown();
|
||||
latchB.await();
|
||||
}
|
||||
throw new RuntimeException("exception A");
|
||||
});
|
||||
int tasksWithNormalExit = 50;
|
||||
for (int i = 0; i < tasksWithNormalExit; i++) {
|
||||
callables.add(
|
||||
() -> {
|
||||
tasksExecuted.incrementAndGet();
|
||||
return null;
|
||||
});
|
||||
if (maximumPoolSize > 1) {
|
||||
latchB.countDown();
|
||||
latchA.await();
|
||||
}
|
||||
callables.add(
|
||||
() -> {
|
||||
throw new IllegalStateException("exception B");
|
||||
});
|
||||
|
||||
RuntimeException exc =
|
||||
expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(callables));
|
||||
Throwable[] suppressed = exc.getSuppressed();
|
||||
|
||||
if (maximumPoolSize == 1) {
|
||||
assertEquals(0, suppressed.length);
|
||||
} else {
|
||||
assertEquals(1, suppressed.length);
|
||||
if (exc.getMessage().equals("exception A")) {
|
||||
assertEquals("exception B", suppressed[0].getMessage());
|
||||
|
@ -282,7 +312,46 @@ public class TestTaskExecutor extends LuceneTestCase {
|
|||
assertEquals("exception A", suppressed[0].getMessage());
|
||||
assertEquals("exception B", exc.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(tasksWithNormalExit, tasksExecuted.get());
|
||||
public void testCancelTasksOnException() {
|
||||
TaskExecutor taskExecutor = new TaskExecutor(executorService);
|
||||
int maximumPoolSize = ((ThreadPoolExecutor) executorService).getMaximumPoolSize();
|
||||
final int numTasks = random().nextInt(10, 50);
|
||||
final int throwingTask = random().nextInt(numTasks);
|
||||
boolean error = random().nextBoolean();
|
||||
List<Callable<Void>> tasks = new ArrayList<>(numTasks);
|
||||
AtomicInteger executedTasks = new AtomicInteger(0);
|
||||
for (int i = 0; i < numTasks; i++) {
|
||||
final int index = i;
|
||||
tasks.add(
|
||||
() -> {
|
||||
if (index == throwingTask) {
|
||||
if (error) {
|
||||
throw new OutOfMemoryError();
|
||||
} else {
|
||||
throw new RuntimeException();
|
||||
}
|
||||
}
|
||||
if (index > throwingTask && maximumPoolSize == 1) {
|
||||
throw new AssertionError("task should not have started");
|
||||
}
|
||||
executedTasks.incrementAndGet();
|
||||
return null;
|
||||
});
|
||||
}
|
||||
Throwable throwable;
|
||||
if (error) {
|
||||
throwable = expectThrows(OutOfMemoryError.class, () -> taskExecutor.invokeAll(tasks));
|
||||
} else {
|
||||
throwable = expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(tasks));
|
||||
}
|
||||
assertEquals(0, throwable.getSuppressed().length);
|
||||
if (maximumPoolSize == 1) {
|
||||
assertEquals(throwingTask, executedTasks.get());
|
||||
} else {
|
||||
MatcherAssert.assertThat(executedTasks.get(), Matchers.greaterThanOrEqualTo(throwingTask));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue