Reduce TaskExecutor overhead (#13861)

The `TaskGroup` class is redundant, the futures list can be a local variable
shared by the tasks (this also removes any need for making it read-only).
This commit is contained in:
Armin Braun 2024-10-08 11:04:39 +02:00
parent e10e9d136d
commit 22638ec8a2
1 changed files with 104 additions and 118 deletions

View File

@ -20,7 +20,6 @@ package org.apache.lucene.search;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
@ -73,15 +72,68 @@ public final class TaskExecutor {
/** /**
* Execute all the callables provided as an argument, wait for them to complete and return the * Execute all the callables provided as an argument, wait for them to complete and return the
* obtained results. If an exception is thrown by more than one callable, the subsequent ones will * obtained results. If an exception is thrown by more than one callable, the subsequent ones will
* be added as suppressed exceptions to the first one that was caught. * be added as suppressed exceptions to the first one that was caught. Additionally, if one task
* throws an exception, all other tasks from the same group are cancelled, to avoid needless
* computation as their results would not be exposed anyways.
* *
* @param callables the callables to execute * @param callables the callables to execute
* @return a list containing the results from the tasks execution * @return a list containing the results from the tasks execution
* @param <T> the return type of the task execution * @param <T> the return type of the task execution
*/ */
public <T> List<T> invokeAll(Collection<Callable<T>> callables) throws IOException { public <T> List<T> invokeAll(Collection<Callable<T>> callables) throws IOException {
TaskGroup<T> taskGroup = new TaskGroup<>(callables); List<RunnableFuture<T>> futures = new ArrayList<>(callables.size());
return taskGroup.invokeAll(executor); for (Callable<T> callable : callables) {
futures.add(new Task<>(callable, futures));
}
final int count = futures.size();
// taskId provides the first index of an un-executed task in #futures
final AtomicInteger taskId = new AtomicInteger(0);
// we fork execution count - 1 tasks to execute at least one task on the current thread to
// minimize needless forking and blocking of the current thread
if (count > 1) {
final Runnable work =
() -> {
int id = taskId.getAndIncrement();
if (id < count) {
futures.get(id).run();
}
};
for (int j = 0; j < count - 1; j++) {
executor.execute(work);
}
}
// try to execute as many tasks as possible on the current thread to minimize context
// switching in case of long running concurrent
// tasks as well as dead-locking if the current thread is part of #executor for executors that
// have limited or no parallelism
int id;
while ((id = taskId.getAndIncrement()) < count) {
futures.get(id).run();
if (id >= count - 1) {
// save redundant CAS in case this was the last task
break;
}
}
return collectResults(futures);
}
private static <T> List<T> collectResults(List<RunnableFuture<T>> futures) throws IOException {
Throwable exc = null;
List<T> results = new ArrayList<>(futures.size());
for (Future<T> future : futures) {
try {
results.add(future.get());
} catch (InterruptedException e) {
exc = IOUtils.useOrSuppress(exc, new ThreadInterruptedException(e));
} catch (ExecutionException e) {
exc = IOUtils.useOrSuppress(exc, e.getCause());
}
}
assert assertAllFuturesCompleted(futures) : "Some tasks are still running?";
if (exc != null) {
throw IOUtils.rethrowAlways(exc);
}
return results;
} }
@Override @Override
@ -89,128 +141,62 @@ public final class TaskExecutor {
return "TaskExecutor(" + "executor=" + executor + ')'; return "TaskExecutor(" + "executor=" + executor + ')';
} }
/** private static boolean assertAllFuturesCompleted(Collection<? extends Future<?>> futures) {
* Holds all the sub-tasks that a certain operation gets split into as it gets parallelized and for (Future<?> future : futures) {
* exposes the ability to invoke such tasks and wait for them all to complete their execution and if (future.isDone() == false) {
* provide their results. Additionally, if one task throws an exception, all other tasks from the return false;
* same group are cancelled, to avoid needless computation as their results would not be exposed
* anyways. Creates one {@link FutureTask} for each {@link Callable} provided
*
* @param <T> the return type of all the callables
*/
private static final class TaskGroup<T> {
private final List<RunnableFuture<T>> futures;
TaskGroup(Collection<Callable<T>> callables) {
List<RunnableFuture<T>> tasks = new ArrayList<>(callables.size());
for (Callable<T> callable : callables) {
tasks.add(createTask(callable));
} }
this.futures = Collections.unmodifiableList(tasks); }
return true;
}
private static <T> void cancelAll(Collection<? extends Future<T>> futures) {
for (Future<?> future : futures) {
future.cancel(false);
}
}
private static class Task<T> extends FutureTask<T> {
private final AtomicBoolean startedOrCancelled = new AtomicBoolean(false);
private final Collection<? extends Future<T>> futures;
public Task(Callable<T> callable, Collection<? extends Future<T>> futures) {
super(callable);
this.futures = futures;
} }
RunnableFuture<T> createTask(Callable<T> callable) { @Override
return new FutureTask<>(callable) { public void run() {
if (startedOrCancelled.compareAndSet(false, true)) {
private final AtomicBoolean startedOrCancelled = new AtomicBoolean(false); super.run();
}
@Override
public void run() {
if (startedOrCancelled.compareAndSet(false, true)) {
super.run();
}
}
@Override
protected void setException(Throwable t) {
super.setException(t);
cancelAll();
}
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
assert mayInterruptIfRunning == false
: "cancelling tasks that are running is not supported";
/*
Future#get (called in invokeAll) throws CancellationException when invoked against a running task that has been cancelled but
leaves the task running. We rather want to make sure that invokeAll does not leave any running tasks behind when it returns.
Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will
wait for them to finish instead of throwing CancellationException. A cleaner way would have been to override FutureTask#get and
make it wait for cancelled tasks, but FutureTask#awaitDone is private. Tasks that are cancelled before they are started will be no-op.
*/
if (startedOrCancelled.compareAndSet(false, true)) {
// task is cancelled hence it has no results to return. That's fine: they would be
// ignored anyway.
set(null);
return true;
}
return false;
}
};
} }
List<T> invokeAll(Executor executor) throws IOException { @Override
final int count = futures.size(); protected void setException(Throwable t) {
// taskId provides the first index of an un-executed task in #futures super.setException(t);
final AtomicInteger taskId = new AtomicInteger(0); cancelAll(futures);
// we fork execution count - 1 tasks to execute at least one task on the current thread to
// minimize needless forking and blocking of the current thread
if (count > 1) {
final Runnable work =
() -> {
int id = taskId.getAndIncrement();
if (id < count) {
futures.get(id).run();
}
};
for (int j = 0; j < count - 1; j++) {
executor.execute(work);
}
}
// try to execute as many tasks as possible on the current thread to minimize context
// switching in case of long running concurrent
// tasks as well as dead-locking if the current thread is part of #executor for executors that
// have limited or no parallelism
int id;
while ((id = taskId.getAndIncrement()) < count) {
futures.get(id).run();
if (id >= count - 1) {
// save redundant CAS in case this was the last task
break;
}
}
Throwable exc = null;
List<T> results = new ArrayList<>(count);
for (int i = 0; i < count; i++) {
Future<T> future = futures.get(i);
try {
results.add(future.get());
} catch (InterruptedException e) {
exc = IOUtils.useOrSuppress(exc, new ThreadInterruptedException(e));
} catch (ExecutionException e) {
exc = IOUtils.useOrSuppress(exc, e.getCause());
}
}
assert assertAllFuturesCompleted() : "Some tasks are still running?";
if (exc != null) {
throw IOUtils.rethrowAlways(exc);
}
return results;
} }
private boolean assertAllFuturesCompleted() { @Override
for (RunnableFuture<T> future : futures) { public boolean cancel(boolean mayInterruptIfRunning) {
if (future.isDone() == false) { assert mayInterruptIfRunning == false : "cancelling tasks that are running is not supported";
return false; /*
} Future#get (called in #collectResults) throws CancellationException when invoked against a running task that has been cancelled but
} leaves the task running. We rather want to make sure that invokeAll does not leave any running tasks behind when it returns.
return true; Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will
} wait for them to finish instead of throwing CancellationException. A cleaner way would have been to override FutureTask#get and
make it wait for cancelled tasks, but FutureTask#awaitDone is private. Tasks that are cancelled before they are started will be no-op.
private void cancelAll() { */
for (Future<T> future : futures) { if (startedOrCancelled.compareAndSet(false, true)) {
future.cancel(false); // task is cancelled hence it has no results to return. That's fine: they would be
// ignored anyway.
set(null);
return true;
} }
return false;
} }
} }
} }