Reduce TaskExecutor overhead (#13861)

The `TaskGroup` class is redundant, the futures list can be a local variable
shared by the tasks (this also removes any need for making it read-only).
This commit is contained in:
Armin Braun 2024-10-08 11:04:39 +02:00
parent e10e9d136d
commit 22638ec8a2
1 changed files with 104 additions and 118 deletions

View File

@ -20,7 +20,6 @@ package org.apache.lucene.search;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
@ -73,83 +72,19 @@ public final class TaskExecutor {
/** /**
* Execute all the callables provided as an argument, wait for them to complete and return the * Execute all the callables provided as an argument, wait for them to complete and return the
* obtained results. If an exception is thrown by more than one callable, the subsequent ones will * obtained results. If an exception is thrown by more than one callable, the subsequent ones will
* be added as suppressed exceptions to the first one that was caught. * be added as suppressed exceptions to the first one that was caught. Additionally, if one task
* throws an exception, all other tasks from the same group are cancelled, to avoid needless
* computation as their results would not be exposed anyways.
* *
* @param callables the callables to execute * @param callables the callables to execute
* @return a list containing the results from the tasks execution * @return a list containing the results from the tasks execution
* @param <T> the return type of the task execution * @param <T> the return type of the task execution
*/ */
public <T> List<T> invokeAll(Collection<Callable<T>> callables) throws IOException { public <T> List<T> invokeAll(Collection<Callable<T>> callables) throws IOException {
TaskGroup<T> taskGroup = new TaskGroup<>(callables); List<RunnableFuture<T>> futures = new ArrayList<>(callables.size());
return taskGroup.invokeAll(executor);
}
@Override
public String toString() {
return "TaskExecutor(" + "executor=" + executor + ')';
}
/**
* Holds all the sub-tasks that a certain operation gets split into as it gets parallelized and
* exposes the ability to invoke such tasks and wait for them all to complete their execution and
* provide their results. Additionally, if one task throws an exception, all other tasks from the
* same group are cancelled, to avoid needless computation as their results would not be exposed
* anyways. Creates one {@link FutureTask} for each {@link Callable} provided
*
* @param <T> the return type of all the callables
*/
private static final class TaskGroup<T> {
private final List<RunnableFuture<T>> futures;
TaskGroup(Collection<Callable<T>> callables) {
List<RunnableFuture<T>> tasks = new ArrayList<>(callables.size());
for (Callable<T> callable : callables) { for (Callable<T> callable : callables) {
tasks.add(createTask(callable)); futures.add(new Task<>(callable, futures));
} }
this.futures = Collections.unmodifiableList(tasks);
}
RunnableFuture<T> createTask(Callable<T> callable) {
return new FutureTask<>(callable) {
private final AtomicBoolean startedOrCancelled = new AtomicBoolean(false);
@Override
public void run() {
if (startedOrCancelled.compareAndSet(false, true)) {
super.run();
}
}
@Override
protected void setException(Throwable t) {
super.setException(t);
cancelAll();
}
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
assert mayInterruptIfRunning == false
: "cancelling tasks that are running is not supported";
/*
Future#get (called in invokeAll) throws CancellationException when invoked against a running task that has been cancelled but
leaves the task running. We rather want to make sure that invokeAll does not leave any running tasks behind when it returns.
Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will
wait for them to finish instead of throwing CancellationException. A cleaner way would have been to override FutureTask#get and
make it wait for cancelled tasks, but FutureTask#awaitDone is private. Tasks that are cancelled before they are started will be no-op.
*/
if (startedOrCancelled.compareAndSet(false, true)) {
// task is cancelled hence it has no results to return. That's fine: they would be
// ignored anyway.
set(null);
return true;
}
return false;
}
};
}
List<T> invokeAll(Executor executor) throws IOException {
final int count = futures.size(); final int count = futures.size();
// taskId provides the first index of an un-executed task in #futures // taskId provides the first index of an un-executed task in #futures
final AtomicInteger taskId = new AtomicInteger(0); final AtomicInteger taskId = new AtomicInteger(0);
@ -179,10 +114,13 @@ public final class TaskExecutor {
break; break;
} }
} }
return collectResults(futures);
}
private static <T> List<T> collectResults(List<RunnableFuture<T>> futures) throws IOException {
Throwable exc = null; Throwable exc = null;
List<T> results = new ArrayList<>(count); List<T> results = new ArrayList<>(futures.size());
for (int i = 0; i < count; i++) { for (Future<T> future : futures) {
Future<T> future = futures.get(i);
try { try {
results.add(future.get()); results.add(future.get());
} catch (InterruptedException e) { } catch (InterruptedException e) {
@ -191,15 +129,20 @@ public final class TaskExecutor {
exc = IOUtils.useOrSuppress(exc, e.getCause()); exc = IOUtils.useOrSuppress(exc, e.getCause());
} }
} }
assert assertAllFuturesCompleted() : "Some tasks are still running?"; assert assertAllFuturesCompleted(futures) : "Some tasks are still running?";
if (exc != null) { if (exc != null) {
throw IOUtils.rethrowAlways(exc); throw IOUtils.rethrowAlways(exc);
} }
return results; return results;
} }
private boolean assertAllFuturesCompleted() { @Override
for (RunnableFuture<T> future : futures) { public String toString() {
return "TaskExecutor(" + "executor=" + executor + ')';
}
private static boolean assertAllFuturesCompleted(Collection<? extends Future<?>> futures) {
for (Future<?> future : futures) {
if (future.isDone() == false) { if (future.isDone() == false) {
return false; return false;
} }
@ -207,10 +150,53 @@ public final class TaskExecutor {
return true; return true;
} }
private void cancelAll() { private static <T> void cancelAll(Collection<? extends Future<T>> futures) {
for (Future<T> future : futures) { for (Future<?> future : futures) {
future.cancel(false); future.cancel(false);
} }
} }
private static class Task<T> extends FutureTask<T> {
private final AtomicBoolean startedOrCancelled = new AtomicBoolean(false);
private final Collection<? extends Future<T>> futures;
public Task(Callable<T> callable, Collection<? extends Future<T>> futures) {
super(callable);
this.futures = futures;
}
@Override
public void run() {
if (startedOrCancelled.compareAndSet(false, true)) {
super.run();
}
}
@Override
protected void setException(Throwable t) {
super.setException(t);
cancelAll(futures);
}
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
assert mayInterruptIfRunning == false : "cancelling tasks that are running is not supported";
/*
Future#get (called in #collectResults) throws CancellationException when invoked against a running task that has been cancelled but
leaves the task running. We rather want to make sure that invokeAll does not leave any running tasks behind when it returns.
Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will
wait for them to finish instead of throwing CancellationException. A cleaner way would have been to override FutureTask#get and
make it wait for cancelled tasks, but FutureTask#awaitDone is private. Tasks that are cancelled before they are started will be no-op.
*/
if (startedOrCancelled.compareAndSet(false, true)) {
// task is cancelled hence it has no results to return. That's fine: they would be
// ignored anyway.
set(null);
return true;
}
return false;
}
} }
} }