From 1200ecce3a299f798095e04584cc11ac530ddea8 Mon Sep 17 00:00:00 2001 From: Luca Cavanna Date: Tue, 24 Oct 2023 17:38:25 +0200 Subject: [PATCH] TaskExecutor to cancel all tasks on exception (#12689) When operations are parallelized, like query rewrite, or search, or createWeight, one of the tasks may throw an exception. In that case we wait for all tasks to be completed before re-throwing the exception that were caught. Tasks that were not started when the exception is captured though can be safely skipped. Ideally we would also cancel ongoing tasks but I left that for another time. --- .../apache/lucene/search/TaskExecutor.java | 180 ++++++++++++------ .../lucene/search/TestTaskExecutor.java | 109 +++++++++-- 2 files changed, 209 insertions(+), 80 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java b/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java index 0a383a8e815..f2be51206e4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java +++ b/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java @@ -20,6 +20,7 @@ package org.apache.lucene.search; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.concurrent.Callable; @@ -27,6 +28,8 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Future; import java.util.concurrent.FutureTask; +import java.util.concurrent.RunnableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.ThreadInterruptedException; @@ -64,70 +67,127 @@ public final class TaskExecutor { * @param the return type of the task execution */ public List invokeAll(Collection> callables) throws IOException { - List> tasks = new ArrayList<>(callables.size()); - boolean runOnCallerThread = numberOfRunningTasksInCurrentThread.get() > 0; - for (Callable callable : callables) { - Task task = new Task<>(callable); - tasks.add(task); - if (runOnCallerThread) { - task.run(); - } else { - executor.execute(task); - } - } - - Throwable exc = null; - final List results = new ArrayList<>(); - for (Future future : tasks) { - try { - results.add(future.get()); - } catch (InterruptedException e) { - var newException = new ThreadInterruptedException(e); - if (exc == null) { - exc = newException; - } else { - exc.addSuppressed(newException); - } - } catch (ExecutionException e) { - if (exc == null) { - exc = e.getCause(); - } else { - exc.addSuppressed(e.getCause()); - } - } - } - if (exc != null) { - throw IOUtils.rethrowAlways(exc); - } - return results; - } - - /** - * Extension of {@link FutureTask} that tracks the number of tasks that are running in each - * thread. - * - * @param the return type of the task - */ - private static final class Task extends FutureTask { - private Task(Callable callable) { - super(callable); - } - - @Override - public void run() { - try { - Integer counter = numberOfRunningTasksInCurrentThread.get(); - numberOfRunningTasksInCurrentThread.set(counter + 1); - super.run(); - } finally { - Integer counter = numberOfRunningTasksInCurrentThread.get(); - numberOfRunningTasksInCurrentThread.set(counter - 1); - } - } + TaskGroup taskGroup = new TaskGroup<>(callables); + return taskGroup.invokeAll(executor); } @Override public String toString() { return "TaskExecutor(" + "executor=" + executor + ')'; } + + /** + * Holds all the sub-tasks that a certain operation gets split into as it gets parallelized and + * exposes the ability to invoke such tasks and wait for them all to complete their execution and + * provide their results. Ensures that each task does not get parallelized further: this is + * important to avoid a deadlock in situations where one executor thread waits on other executor + * threads to complete before it can progress. This happens in situations where for instance + * {@link Query#createWeight(IndexSearcher, ScoreMode, float)} is called as part of searching each + * slice, like {@link TopFieldCollector#populateScores(ScoreDoc[], IndexSearcher, Query)} does. + * Additionally, if one task throws an exception, all other tasks from the same group are + * cancelled, to avoid needless computation as their results would not be exposed anyways. Creates + * one {@link FutureTask} for each {@link Callable} provided + * + * @param the return type of all the callables + */ + private static final class TaskGroup { + private final Collection> futures; + + TaskGroup(Collection> callables) { + List> tasks = new ArrayList<>(callables.size()); + for (Callable callable : callables) { + tasks.add(createTask(callable)); + } + this.futures = Collections.unmodifiableCollection(tasks); + } + + RunnableFuture createTask(Callable callable) { + // -1: cancelled; 0: not yet started; 1: started + AtomicBoolean startedOrCancelled = new AtomicBoolean(false); + return new FutureTask<>( + () -> { + if (startedOrCancelled.compareAndSet(false, true)) { + try { + Integer counter = numberOfRunningTasksInCurrentThread.get(); + numberOfRunningTasksInCurrentThread.set(counter + 1); + return callable.call(); + } catch (Throwable t) { + cancelAll(); + throw t; + } finally { + Integer counter = numberOfRunningTasksInCurrentThread.get(); + numberOfRunningTasksInCurrentThread.set(counter - 1); + } + } + // task is cancelled hence it has no results to return. That's fine: they would be + // ignored anyway. + return null; + }) { + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + assert mayInterruptIfRunning == false + : "cancelling tasks that are running is not supported"; + /* + Future#get (called in invokeAll) throws CancellationException when invoked against a running task that has been cancelled but + leaves the task running. We rather want to make sure that invokeAll does not leave any running tasks behind when it returns. + Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will + wait for them to finish instead of throwing CancellationException. A cleaner way would have been to override FutureTask#get and + make it wait for cancelled tasks, but FutureTask#awaitDone is private. Tasks that are cancelled before they are started will be no-op. + */ + return startedOrCancelled.compareAndSet(false, true); + } + }; + } + + List invokeAll(Executor executor) throws IOException { + boolean runOnCallerThread = numberOfRunningTasksInCurrentThread.get() > 0; + for (Runnable runnable : futures) { + if (runOnCallerThread) { + runnable.run(); + } else { + executor.execute(runnable); + } + } + Throwable exc = null; + List results = new ArrayList<>(futures.size()); + for (Future future : futures) { + try { + results.add(future.get()); + } catch (InterruptedException e) { + var newException = new ThreadInterruptedException(e); + if (exc == null) { + exc = newException; + } else { + exc.addSuppressed(newException); + } + } catch (ExecutionException e) { + if (exc == null) { + exc = e.getCause(); + } else { + exc.addSuppressed(e.getCause()); + } + } + } + assert assertAllFuturesCompleted() : "Some tasks are still running?"; + if (exc != null) { + throw IOUtils.rethrowAlways(exc); + } + return results; + } + + private boolean assertAllFuturesCompleted() { + for (RunnableFuture future : futures) { + if (future.isDone() == false) { + return false; + } + } + return true; + } + + private void cancelAll() { + for (Future future : futures) { + future.cancel(false); + } + } + } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTaskExecutor.java b/lucene/core/src/test/org/apache/lucene/search/TestTaskExecutor.java index c18be1ad4f4..341ff5ba39d 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestTaskExecutor.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestTaskExecutor.java @@ -22,8 +22,10 @@ import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.atomic.AtomicInteger; import org.apache.lucene.document.Document; import org.apache.lucene.index.DirectoryReader; @@ -32,6 +34,8 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.NamedThreadFactory; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -43,7 +47,8 @@ public class TestTaskExecutor extends LuceneTestCase { public static void createExecutor() { executorService = Executors.newFixedThreadPool( - 1, new NamedThreadFactory(TestTaskExecutor.class.getSimpleName())); + random().nextBoolean() ? 1 : 2, + new NamedThreadFactory(TestTaskExecutor.class.getSimpleName())); } @AfterClass @@ -228,11 +233,21 @@ public class TestTaskExecutor extends LuceneTestCase { } public void testInvokeAllDoesNotLeaveTasksBehind() { - TaskExecutor taskExecutor = new TaskExecutor(executorService); + AtomicInteger tasksStarted = new AtomicInteger(0); + TaskExecutor taskExecutor = + new TaskExecutor( + command -> { + executorService.execute( + () -> { + tasksStarted.incrementAndGet(); + command.run(); + }); + }); AtomicInteger tasksExecuted = new AtomicInteger(0); List> callables = new ArrayList<>(); callables.add( () -> { + tasksExecuted.incrementAndGet(); throw new RuntimeException(); }); int tasksWithNormalExit = 99; @@ -244,7 +259,14 @@ public class TestTaskExecutor extends LuceneTestCase { }); } expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(callables)); - assertEquals(tasksWithNormalExit, tasksExecuted.get()); + int maximumPoolSize = ((ThreadPoolExecutor) executorService).getMaximumPoolSize(); + if (maximumPoolSize == 1) { + assertEquals(1, tasksExecuted.get()); + } else { + MatcherAssert.assertThat(tasksExecuted.get(), Matchers.greaterThanOrEqualTo(1)); + } + // the callables are technically all run, but the cancelled ones will be no-op + assertEquals(100, tasksStarted.get()); } /** @@ -253,36 +275,83 @@ public class TestTaskExecutor extends LuceneTestCase { */ public void testInvokeAllCatchesMultipleExceptions() { TaskExecutor taskExecutor = new TaskExecutor(executorService); - AtomicInteger tasksExecuted = new AtomicInteger(0); List> callables = new ArrayList<>(); + int maximumPoolSize = ((ThreadPoolExecutor) executorService).getMaximumPoolSize(); + // if we have multiple threads, make sure both are started before an exception is thrown, + // otherwise there may or may not be a suppressed exception + CountDownLatch latchA = new CountDownLatch(1); + CountDownLatch latchB = new CountDownLatch(1); callables.add( () -> { + if (maximumPoolSize > 1) { + latchA.countDown(); + latchB.await(); + } throw new RuntimeException("exception A"); }); - int tasksWithNormalExit = 50; - for (int i = 0; i < tasksWithNormalExit; i++) { - callables.add( - () -> { - tasksExecuted.incrementAndGet(); - return null; - }); - } callables.add( () -> { + if (maximumPoolSize > 1) { + latchB.countDown(); + latchA.await(); + } throw new IllegalStateException("exception B"); }); RuntimeException exc = expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(callables)); Throwable[] suppressed = exc.getSuppressed(); - assertEquals(1, suppressed.length); - if (exc.getMessage().equals("exception A")) { - assertEquals("exception B", suppressed[0].getMessage()); - } else { - assertEquals("exception A", suppressed[0].getMessage()); - assertEquals("exception B", exc.getMessage()); - } - assertEquals(tasksWithNormalExit, tasksExecuted.get()); + if (maximumPoolSize == 1) { + assertEquals(0, suppressed.length); + } else { + assertEquals(1, suppressed.length); + if (exc.getMessage().equals("exception A")) { + assertEquals("exception B", suppressed[0].getMessage()); + } else { + assertEquals("exception A", suppressed[0].getMessage()); + assertEquals("exception B", exc.getMessage()); + } + } + } + + public void testCancelTasksOnException() { + TaskExecutor taskExecutor = new TaskExecutor(executorService); + int maximumPoolSize = ((ThreadPoolExecutor) executorService).getMaximumPoolSize(); + final int numTasks = random().nextInt(10, 50); + final int throwingTask = random().nextInt(numTasks); + boolean error = random().nextBoolean(); + List> tasks = new ArrayList<>(numTasks); + AtomicInteger executedTasks = new AtomicInteger(0); + for (int i = 0; i < numTasks; i++) { + final int index = i; + tasks.add( + () -> { + if (index == throwingTask) { + if (error) { + throw new OutOfMemoryError(); + } else { + throw new RuntimeException(); + } + } + if (index > throwingTask && maximumPoolSize == 1) { + throw new AssertionError("task should not have started"); + } + executedTasks.incrementAndGet(); + return null; + }); + } + Throwable throwable; + if (error) { + throwable = expectThrows(OutOfMemoryError.class, () -> taskExecutor.invokeAll(tasks)); + } else { + throwable = expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(tasks)); + } + assertEquals(0, throwable.getSuppressed().length); + if (maximumPoolSize == 1) { + assertEquals(throwingTask, executedTasks.get()); + } else { + MatcherAssert.assertThat(executedTasks.get(), Matchers.greaterThanOrEqualTo(throwingTask)); + } } }