diff --git a/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java b/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java index 331d692a854..6c89c267a52 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java +++ b/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java @@ -20,7 +20,6 @@ package org.apache.lucene.search; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.concurrent.Callable; @@ -73,15 +72,68 @@ public final class TaskExecutor { /** * Execute all the callables provided as an argument, wait for them to complete and return the * obtained results. If an exception is thrown by more than one callable, the subsequent ones will - * be added as suppressed exceptions to the first one that was caught. + * be added as suppressed exceptions to the first one that was caught. Additionally, if one task + * throws an exception, all other tasks from the same group are cancelled, to avoid needless + * computation as their results would not be exposed anyways. * * @param callables the callables to execute * @return a list containing the results from the tasks execution * @param the return type of the task execution */ public List invokeAll(Collection> callables) throws IOException { - TaskGroup taskGroup = new TaskGroup<>(callables); - return taskGroup.invokeAll(executor); + List> futures = new ArrayList<>(callables.size()); + for (Callable callable : callables) { + futures.add(new Task<>(callable, futures)); + } + final int count = futures.size(); + // taskId provides the first index of an un-executed task in #futures + final AtomicInteger taskId = new AtomicInteger(0); + // we fork execution count - 1 tasks to execute at least one task on the current thread to + // minimize needless forking and blocking of the current thread + if (count > 1) { + final Runnable work = + () -> { + int id = taskId.getAndIncrement(); + if (id < count) { + futures.get(id).run(); + } + }; + for (int j = 0; j < count - 1; j++) { + executor.execute(work); + } + } + // try to execute as many tasks as possible on the current thread to minimize context + // switching in case of long running concurrent + // tasks as well as dead-locking if the current thread is part of #executor for executors that + // have limited or no parallelism + int id; + while ((id = taskId.getAndIncrement()) < count) { + futures.get(id).run(); + if (id >= count - 1) { + // save redundant CAS in case this was the last task + break; + } + } + return collectResults(futures); + } + + private static List collectResults(List> futures) throws IOException { + Throwable exc = null; + List results = new ArrayList<>(futures.size()); + for (Future future : futures) { + try { + results.add(future.get()); + } catch (InterruptedException e) { + exc = IOUtils.useOrSuppress(exc, new ThreadInterruptedException(e)); + } catch (ExecutionException e) { + exc = IOUtils.useOrSuppress(exc, e.getCause()); + } + } + assert assertAllFuturesCompleted(futures) : "Some tasks are still running?"; + if (exc != null) { + throw IOUtils.rethrowAlways(exc); + } + return results; } @Override @@ -89,128 +141,62 @@ public final class TaskExecutor { return "TaskExecutor(" + "executor=" + executor + ')'; } - /** - * Holds all the sub-tasks that a certain operation gets split into as it gets parallelized and - * exposes the ability to invoke such tasks and wait for them all to complete their execution and - * provide their results. Additionally, if one task throws an exception, all other tasks from the - * same group are cancelled, to avoid needless computation as their results would not be exposed - * anyways. Creates one {@link FutureTask} for each {@link Callable} provided - * - * @param the return type of all the callables - */ - private static final class TaskGroup { - private final List> futures; - - TaskGroup(Collection> callables) { - List> tasks = new ArrayList<>(callables.size()); - for (Callable callable : callables) { - tasks.add(createTask(callable)); + private static boolean assertAllFuturesCompleted(Collection> futures) { + for (Future future : futures) { + if (future.isDone() == false) { + return false; } - this.futures = Collections.unmodifiableList(tasks); + } + return true; + } + + private static void cancelAll(Collection> futures) { + for (Future future : futures) { + future.cancel(false); + } + } + + private static class Task extends FutureTask { + + private final AtomicBoolean startedOrCancelled = new AtomicBoolean(false); + + private final Collection> futures; + + public Task(Callable callable, Collection> futures) { + super(callable); + this.futures = futures; } - RunnableFuture createTask(Callable callable) { - return new FutureTask<>(callable) { - - private final AtomicBoolean startedOrCancelled = new AtomicBoolean(false); - - @Override - public void run() { - if (startedOrCancelled.compareAndSet(false, true)) { - super.run(); - } - } - - @Override - protected void setException(Throwable t) { - super.setException(t); - cancelAll(); - } - - @Override - public boolean cancel(boolean mayInterruptIfRunning) { - assert mayInterruptIfRunning == false - : "cancelling tasks that are running is not supported"; - /* - Future#get (called in invokeAll) throws CancellationException when invoked against a running task that has been cancelled but - leaves the task running. We rather want to make sure that invokeAll does not leave any running tasks behind when it returns. - Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will - wait for them to finish instead of throwing CancellationException. A cleaner way would have been to override FutureTask#get and - make it wait for cancelled tasks, but FutureTask#awaitDone is private. Tasks that are cancelled before they are started will be no-op. - */ - if (startedOrCancelled.compareAndSet(false, true)) { - // task is cancelled hence it has no results to return. That's fine: they would be - // ignored anyway. - set(null); - return true; - } - return false; - } - }; + @Override + public void run() { + if (startedOrCancelled.compareAndSet(false, true)) { + super.run(); + } } - List invokeAll(Executor executor) throws IOException { - final int count = futures.size(); - // taskId provides the first index of an un-executed task in #futures - final AtomicInteger taskId = new AtomicInteger(0); - // we fork execution count - 1 tasks to execute at least one task on the current thread to - // minimize needless forking and blocking of the current thread - if (count > 1) { - final Runnable work = - () -> { - int id = taskId.getAndIncrement(); - if (id < count) { - futures.get(id).run(); - } - }; - for (int j = 0; j < count - 1; j++) { - executor.execute(work); - } - } - // try to execute as many tasks as possible on the current thread to minimize context - // switching in case of long running concurrent - // tasks as well as dead-locking if the current thread is part of #executor for executors that - // have limited or no parallelism - int id; - while ((id = taskId.getAndIncrement()) < count) { - futures.get(id).run(); - if (id >= count - 1) { - // save redundant CAS in case this was the last task - break; - } - } - Throwable exc = null; - List results = new ArrayList<>(count); - for (int i = 0; i < count; i++) { - Future future = futures.get(i); - try { - results.add(future.get()); - } catch (InterruptedException e) { - exc = IOUtils.useOrSuppress(exc, new ThreadInterruptedException(e)); - } catch (ExecutionException e) { - exc = IOUtils.useOrSuppress(exc, e.getCause()); - } - } - assert assertAllFuturesCompleted() : "Some tasks are still running?"; - if (exc != null) { - throw IOUtils.rethrowAlways(exc); - } - return results; + @Override + protected void setException(Throwable t) { + super.setException(t); + cancelAll(futures); } - private boolean assertAllFuturesCompleted() { - for (RunnableFuture future : futures) { - if (future.isDone() == false) { - return false; - } - } - return true; - } - - private void cancelAll() { - for (Future future : futures) { - future.cancel(false); + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + assert mayInterruptIfRunning == false : "cancelling tasks that are running is not supported"; + /* + Future#get (called in #collectResults) throws CancellationException when invoked against a running task that has been cancelled but + leaves the task running. We rather want to make sure that invokeAll does not leave any running tasks behind when it returns. + Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will + wait for them to finish instead of throwing CancellationException. A cleaner way would have been to override FutureTask#get and + make it wait for cancelled tasks, but FutureTask#awaitDone is private. Tasks that are cancelled before they are started will be no-op. + */ + if (startedOrCancelled.compareAndSet(false, true)) { + // task is cancelled hence it has no results to return. That's fine: they would be + // ignored anyway. + set(null); + return true; } + return false; } } }