Reduce TaskExecutor overhead (#13861)

The `TaskGroup` class is redundant, the futures list can be a local variable
shared by the tasks (this also removes any need for making it read-only).
This commit is contained in:
Armin Braun 2024-10-08 11:04:39 +02:00
parent e10e9d136d
commit 22638ec8a2
1 changed files with 104 additions and 118 deletions

View File

@ -20,7 +20,6 @@ package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
@ -73,15 +72,68 @@ public final class TaskExecutor {
/**
* Execute all the callables provided as an argument, wait for them to complete and return the
* obtained results. If an exception is thrown by more than one callable, the subsequent ones will
* be added as suppressed exceptions to the first one that was caught.
* be added as suppressed exceptions to the first one that was caught. Additionally, if one task
* throws an exception, all other tasks from the same group are cancelled, to avoid needless
* computation as their results would not be exposed anyways.
*
* @param callables the callables to execute
* @return a list containing the results from the tasks execution
* @param <T> the return type of the task execution
*/
public <T> List<T> invokeAll(Collection<Callable<T>> callables) throws IOException {
TaskGroup<T> taskGroup = new TaskGroup<>(callables);
return taskGroup.invokeAll(executor);
List<RunnableFuture<T>> futures = new ArrayList<>(callables.size());
for (Callable<T> callable : callables) {
futures.add(new Task<>(callable, futures));
}
final int count = futures.size();
// taskId provides the first index of an un-executed task in #futures
final AtomicInteger taskId = new AtomicInteger(0);
// we fork execution count - 1 tasks to execute at least one task on the current thread to
// minimize needless forking and blocking of the current thread
if (count > 1) {
final Runnable work =
() -> {
int id = taskId.getAndIncrement();
if (id < count) {
futures.get(id).run();
}
};
for (int j = 0; j < count - 1; j++) {
executor.execute(work);
}
}
// try to execute as many tasks as possible on the current thread to minimize context
// switching in case of long running concurrent
// tasks as well as dead-locking if the current thread is part of #executor for executors that
// have limited or no parallelism
int id;
while ((id = taskId.getAndIncrement()) < count) {
futures.get(id).run();
if (id >= count - 1) {
// save redundant CAS in case this was the last task
break;
}
}
return collectResults(futures);
}
private static <T> List<T> collectResults(List<RunnableFuture<T>> futures) throws IOException {
Throwable exc = null;
List<T> results = new ArrayList<>(futures.size());
for (Future<T> future : futures) {
try {
results.add(future.get());
} catch (InterruptedException e) {
exc = IOUtils.useOrSuppress(exc, new ThreadInterruptedException(e));
} catch (ExecutionException e) {
exc = IOUtils.useOrSuppress(exc, e.getCause());
}
}
assert assertAllFuturesCompleted(futures) : "Some tasks are still running?";
if (exc != null) {
throw IOUtils.rethrowAlways(exc);
}
return results;
}
@Override
@ -89,128 +141,62 @@ public final class TaskExecutor {
return "TaskExecutor(" + "executor=" + executor + ')';
}
/**
* Holds all the sub-tasks that a certain operation gets split into as it gets parallelized and
* exposes the ability to invoke such tasks and wait for them all to complete their execution and
* provide their results. Additionally, if one task throws an exception, all other tasks from the
* same group are cancelled, to avoid needless computation as their results would not be exposed
* anyways. Creates one {@link FutureTask} for each {@link Callable} provided
*
* @param <T> the return type of all the callables
*/
private static final class TaskGroup<T> {
private final List<RunnableFuture<T>> futures;
TaskGroup(Collection<Callable<T>> callables) {
List<RunnableFuture<T>> tasks = new ArrayList<>(callables.size());
for (Callable<T> callable : callables) {
tasks.add(createTask(callable));
private static boolean assertAllFuturesCompleted(Collection<? extends Future<?>> futures) {
for (Future<?> future : futures) {
if (future.isDone() == false) {
return false;
}
this.futures = Collections.unmodifiableList(tasks);
}
return true;
}
private static <T> void cancelAll(Collection<? extends Future<T>> futures) {
for (Future<?> future : futures) {
future.cancel(false);
}
}
private static class Task<T> extends FutureTask<T> {
private final AtomicBoolean startedOrCancelled = new AtomicBoolean(false);
private final Collection<? extends Future<T>> futures;
public Task(Callable<T> callable, Collection<? extends Future<T>> futures) {
super(callable);
this.futures = futures;
}
RunnableFuture<T> createTask(Callable<T> callable) {
return new FutureTask<>(callable) {
private final AtomicBoolean startedOrCancelled = new AtomicBoolean(false);
@Override
public void run() {
if (startedOrCancelled.compareAndSet(false, true)) {
super.run();
}
}
@Override
protected void setException(Throwable t) {
super.setException(t);
cancelAll();
}
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
assert mayInterruptIfRunning == false
: "cancelling tasks that are running is not supported";
/*
Future#get (called in invokeAll) throws CancellationException when invoked against a running task that has been cancelled but
leaves the task running. We rather want to make sure that invokeAll does not leave any running tasks behind when it returns.
Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will
wait for them to finish instead of throwing CancellationException. A cleaner way would have been to override FutureTask#get and
make it wait for cancelled tasks, but FutureTask#awaitDone is private. Tasks that are cancelled before they are started will be no-op.
*/
if (startedOrCancelled.compareAndSet(false, true)) {
// task is cancelled hence it has no results to return. That's fine: they would be
// ignored anyway.
set(null);
return true;
}
return false;
}
};
@Override
public void run() {
if (startedOrCancelled.compareAndSet(false, true)) {
super.run();
}
}
List<T> invokeAll(Executor executor) throws IOException {
final int count = futures.size();
// taskId provides the first index of an un-executed task in #futures
final AtomicInteger taskId = new AtomicInteger(0);
// we fork execution count - 1 tasks to execute at least one task on the current thread to
// minimize needless forking and blocking of the current thread
if (count > 1) {
final Runnable work =
() -> {
int id = taskId.getAndIncrement();
if (id < count) {
futures.get(id).run();
}
};
for (int j = 0; j < count - 1; j++) {
executor.execute(work);
}
}
// try to execute as many tasks as possible on the current thread to minimize context
// switching in case of long running concurrent
// tasks as well as dead-locking if the current thread is part of #executor for executors that
// have limited or no parallelism
int id;
while ((id = taskId.getAndIncrement()) < count) {
futures.get(id).run();
if (id >= count - 1) {
// save redundant CAS in case this was the last task
break;
}
}
Throwable exc = null;
List<T> results = new ArrayList<>(count);
for (int i = 0; i < count; i++) {
Future<T> future = futures.get(i);
try {
results.add(future.get());
} catch (InterruptedException e) {
exc = IOUtils.useOrSuppress(exc, new ThreadInterruptedException(e));
} catch (ExecutionException e) {
exc = IOUtils.useOrSuppress(exc, e.getCause());
}
}
assert assertAllFuturesCompleted() : "Some tasks are still running?";
if (exc != null) {
throw IOUtils.rethrowAlways(exc);
}
return results;
@Override
protected void setException(Throwable t) {
super.setException(t);
cancelAll(futures);
}
private boolean assertAllFuturesCompleted() {
for (RunnableFuture<T> future : futures) {
if (future.isDone() == false) {
return false;
}
}
return true;
}
private void cancelAll() {
for (Future<T> future : futures) {
future.cancel(false);
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
assert mayInterruptIfRunning == false : "cancelling tasks that are running is not supported";
/*
Future#get (called in #collectResults) throws CancellationException when invoked against a running task that has been cancelled but
leaves the task running. We rather want to make sure that invokeAll does not leave any running tasks behind when it returns.
Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will
wait for them to finish instead of throwing CancellationException. A cleaner way would have been to override FutureTask#get and
make it wait for cancelled tasks, but FutureTask#awaitDone is private. Tasks that are cancelled before they are started will be no-op.
*/
if (startedOrCancelled.compareAndSet(false, true)) {
// task is cancelled hence it has no results to return. That's fine: they would be
// ignored anyway.
set(null);
return true;
}
return false;
}
}
}