diff --git a/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java b/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java index 0a383a8e815..f2be51206e4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java +++ b/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java @@ -20,6 +20,7 @@ package org.apache.lucene.search; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.concurrent.Callable; @@ -27,6 +28,8 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Future; import java.util.concurrent.FutureTask; +import java.util.concurrent.RunnableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.ThreadInterruptedException; @@ -64,70 +67,127 @@ public final class TaskExecutor { * @param the return type of the task execution */ public List invokeAll(Collection> callables) throws IOException { - List> tasks = new ArrayList<>(callables.size()); - boolean runOnCallerThread = numberOfRunningTasksInCurrentThread.get() > 0; - for (Callable callable : callables) { - Task task = new Task<>(callable); - tasks.add(task); - if (runOnCallerThread) { - task.run(); - } else { - executor.execute(task); - } - } - - Throwable exc = null; - final List results = new ArrayList<>(); - for (Future future : tasks) { - try { - results.add(future.get()); - } catch (InterruptedException e) { - var newException = new ThreadInterruptedException(e); - if (exc == null) { - exc = newException; - } else { - exc.addSuppressed(newException); - } - } catch (ExecutionException e) { - if (exc == null) { - exc = e.getCause(); - } else { - exc.addSuppressed(e.getCause()); - } - } - } - if (exc != null) { - throw IOUtils.rethrowAlways(exc); - } - return results; - } - - /** - * Extension of 
{@link FutureTask} that tracks the number of tasks that are running in each - * thread. - * - * @param the return type of the task - */ - private static final class Task extends FutureTask { - private Task(Callable callable) { - super(callable); - } - - @Override - public void run() { - try { - Integer counter = numberOfRunningTasksInCurrentThread.get(); - numberOfRunningTasksInCurrentThread.set(counter + 1); - super.run(); - } finally { - Integer counter = numberOfRunningTasksInCurrentThread.get(); - numberOfRunningTasksInCurrentThread.set(counter - 1); - } - } + TaskGroup taskGroup = new TaskGroup<>(callables); + return taskGroup.invokeAll(executor); } @Override public String toString() { return "TaskExecutor(" + "executor=" + executor + ')'; } + + /** + * Holds all the sub-tasks that a certain operation gets split into as it gets parallelized and + * exposes the ability to invoke such tasks and wait for them all to complete their execution and + * provide their results. Ensures that each task does not get parallelized further: this is + * important to avoid a deadlock in situations where one executor thread waits on other executor + * threads to complete before it can progress. This happens in situations where for instance + * {@link Query#createWeight(IndexSearcher, ScoreMode, float)} is called as part of searching each + * slice, like {@link TopFieldCollector#populateScores(ScoreDoc[], IndexSearcher, Query)} does. + * Additionally, if one task throws an exception, all other tasks from the same group are + * cancelled, to avoid needless computation as their results would not be exposed anyways. 
Creates + * one {@link FutureTask} for each {@link Callable} provided. + * + * @param the return type of all the callables + */ + private static final class TaskGroup { + private final Collection> futures; + + TaskGroup(Collection> callables) { + List> tasks = new ArrayList<>(callables.size()); + for (Callable callable : callables) { + tasks.add(createTask(callable)); + } + this.futures = Collections.unmodifiableCollection(tasks); + } + + RunnableFuture createTask(Callable callable) { + // false: not yet started; true: started or cancelled + AtomicBoolean startedOrCancelled = new AtomicBoolean(false); + return new FutureTask<>( + () -> { + if (startedOrCancelled.compareAndSet(false, true)) { + try { + Integer counter = numberOfRunningTasksInCurrentThread.get(); + numberOfRunningTasksInCurrentThread.set(counter + 1); + return callable.call(); + } catch (Throwable t) { + cancelAll(); + throw t; + } finally { + Integer counter = numberOfRunningTasksInCurrentThread.get(); + numberOfRunningTasksInCurrentThread.set(counter - 1); + } + } + // task is cancelled hence it has no results to return. That's fine: they would be + // ignored anyway. + return null; + }) { + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + assert mayInterruptIfRunning == false + : "cancelling tasks that are running is not supported"; + /* + Future#get (called in invokeAll) throws CancellationException when invoked against a running task that has been cancelled but + leaves the task running. Instead, we want to make sure that invokeAll does not leave any running tasks behind when it returns. + Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will + wait for them to finish instead of throwing CancellationException. A cleaner way would have been to override FutureTask#get and + make it wait for cancelled tasks, but FutureTask#awaitDone is private. Tasks that are cancelled before they are started will be no-op.
+ */ + return startedOrCancelled.compareAndSet(false, true); + } + }; + } + + List invokeAll(Executor executor) throws IOException { + boolean runOnCallerThread = numberOfRunningTasksInCurrentThread.get() > 0; + for (Runnable runnable : futures) { + if (runOnCallerThread) { + runnable.run(); + } else { + executor.execute(runnable); + } + } + Throwable exc = null; + List results = new ArrayList<>(futures.size()); + for (Future future : futures) { + try { + results.add(future.get()); + } catch (InterruptedException e) { + var newException = new ThreadInterruptedException(e); + if (exc == null) { + exc = newException; + } else { + exc.addSuppressed(newException); + } + } catch (ExecutionException e) { + if (exc == null) { + exc = e.getCause(); + } else { + exc.addSuppressed(e.getCause()); + } + } + } + assert assertAllFuturesCompleted() : "Some tasks are still running?"; + if (exc != null) { + throw IOUtils.rethrowAlways(exc); + } + return results; + } + + private boolean assertAllFuturesCompleted() { + for (RunnableFuture future : futures) { + if (future.isDone() == false) { + return false; + } + } + return true; + } + + private void cancelAll() { + for (Future future : futures) { + future.cancel(false); + } + } + } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTaskExecutor.java b/lucene/core/src/test/org/apache/lucene/search/TestTaskExecutor.java index c18be1ad4f4..341ff5ba39d 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestTaskExecutor.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestTaskExecutor.java @@ -22,8 +22,10 @@ import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.atomic.AtomicInteger; import org.apache.lucene.document.Document; import 
org.apache.lucene.index.DirectoryReader; @@ -32,6 +34,8 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.NamedThreadFactory; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -43,7 +47,8 @@ public class TestTaskExecutor extends LuceneTestCase { public static void createExecutor() { executorService = Executors.newFixedThreadPool( - 1, new NamedThreadFactory(TestTaskExecutor.class.getSimpleName())); + random().nextBoolean() ? 1 : 2, + new NamedThreadFactory(TestTaskExecutor.class.getSimpleName())); } @AfterClass @@ -228,11 +233,21 @@ public class TestTaskExecutor extends LuceneTestCase { } public void testInvokeAllDoesNotLeaveTasksBehind() { - TaskExecutor taskExecutor = new TaskExecutor(executorService); + AtomicInteger tasksStarted = new AtomicInteger(0); + TaskExecutor taskExecutor = + new TaskExecutor( + command -> { + executorService.execute( + () -> { + tasksStarted.incrementAndGet(); + command.run(); + }); + }); AtomicInteger tasksExecuted = new AtomicInteger(0); List> callables = new ArrayList<>(); callables.add( () -> { + tasksExecuted.incrementAndGet(); throw new RuntimeException(); }); int tasksWithNormalExit = 99; @@ -244,7 +259,14 @@ public class TestTaskExecutor extends LuceneTestCase { }); } expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(callables)); - assertEquals(tasksWithNormalExit, tasksExecuted.get()); + int maximumPoolSize = ((ThreadPoolExecutor) executorService).getMaximumPoolSize(); + if (maximumPoolSize == 1) { + assertEquals(1, tasksExecuted.get()); + } else { + MatcherAssert.assertThat(tasksExecuted.get(), Matchers.greaterThanOrEqualTo(1)); + } + // the callables are technically all run, but the cancelled ones will be no-op + assertEquals(100, tasksStarted.get()); } /** @@ -253,36 +275,83 @@ public class 
TestTaskExecutor extends LuceneTestCase { */ public void testInvokeAllCatchesMultipleExceptions() { TaskExecutor taskExecutor = new TaskExecutor(executorService); - AtomicInteger tasksExecuted = new AtomicInteger(0); List> callables = new ArrayList<>(); + int maximumPoolSize = ((ThreadPoolExecutor) executorService).getMaximumPoolSize(); + // if we have multiple threads, make sure both are started before an exception is thrown, + // otherwise there may or may not be a suppressed exception + CountDownLatch latchA = new CountDownLatch(1); + CountDownLatch latchB = new CountDownLatch(1); callables.add( () -> { + if (maximumPoolSize > 1) { + latchA.countDown(); + latchB.await(); + } throw new RuntimeException("exception A"); }); - int tasksWithNormalExit = 50; - for (int i = 0; i < tasksWithNormalExit; i++) { - callables.add( - () -> { - tasksExecuted.incrementAndGet(); - return null; - }); - } callables.add( () -> { + if (maximumPoolSize > 1) { + latchB.countDown(); + latchA.await(); + } throw new IllegalStateException("exception B"); }); RuntimeException exc = expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(callables)); Throwable[] suppressed = exc.getSuppressed(); - assertEquals(1, suppressed.length); - if (exc.getMessage().equals("exception A")) { - assertEquals("exception B", suppressed[0].getMessage()); - } else { - assertEquals("exception A", suppressed[0].getMessage()); - assertEquals("exception B", exc.getMessage()); - } - assertEquals(tasksWithNormalExit, tasksExecuted.get()); + if (maximumPoolSize == 1) { + assertEquals(0, suppressed.length); + } else { + assertEquals(1, suppressed.length); + if (exc.getMessage().equals("exception A")) { + assertEquals("exception B", suppressed[0].getMessage()); + } else { + assertEquals("exception A", suppressed[0].getMessage()); + assertEquals("exception B", exc.getMessage()); + } + } + } + + public void testCancelTasksOnException() { + TaskExecutor taskExecutor = new TaskExecutor(executorService); + int 
maximumPoolSize = ((ThreadPoolExecutor) executorService).getMaximumPoolSize(); + final int numTasks = random().nextInt(10, 50); + final int throwingTask = random().nextInt(numTasks); + boolean error = random().nextBoolean(); + List> tasks = new ArrayList<>(numTasks); + AtomicInteger executedTasks = new AtomicInteger(0); + for (int i = 0; i < numTasks; i++) { + final int index = i; + tasks.add( + () -> { + if (index == throwingTask) { + if (error) { + throw new OutOfMemoryError(); + } else { + throw new RuntimeException(); + } + } + if (index > throwingTask && maximumPoolSize == 1) { + throw new AssertionError("task should not have started"); + } + executedTasks.incrementAndGet(); + return null; + }); + } + Throwable throwable; + if (error) { + throwable = expectThrows(OutOfMemoryError.class, () -> taskExecutor.invokeAll(tasks)); + } else { + throwable = expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(tasks)); + } + assertEquals(0, throwable.getSuppressed().length); + if (maximumPoolSize == 1) { + assertEquals(throwingTask, executedTasks.get()); + } else { + MatcherAssert.assertThat(executedTasks.get(), Matchers.greaterThanOrEqualTo(throwingTask)); + } } }