diff --git a/qa/evil-tests/src/test/java/org/elasticsearch/threadpool/EvilThreadPoolTests.java b/qa/evil-tests/src/test/java/org/elasticsearch/threadpool/EvilThreadPoolTests.java new file mode 100644 index 00000000000..c7848267ff1 --- /dev/null +++ b/qa/evil-tests/src/test/java/org/elasticsearch/threadpool/EvilThreadPoolTests.java @@ -0,0 +1,105 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.threadpool; + +import org.elasticsearch.test.ESTestCase; +import org.junit.After; +import org.junit.Before; + +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.hamcrest.Matchers.instanceOf; + +public class EvilThreadPoolTests extends ESTestCase { + + private ThreadPool threadPool; + + @Before + public void setUpThreadPool() { + threadPool = new TestThreadPool(EvilThreadPoolTests.class.getName()); + } + + @After + public void tearDownThreadPool() throws InterruptedException { + terminate(threadPool); + } + + public void testExecutionException() throws InterruptedException { + runExecutionExceptionTest( + () -> { + throw new Error("future error"); + }, + true, + o -> { + assertTrue(o.isPresent()); + assertThat(o.get(), instanceOf(Error.class)); + assertThat(o.get(), hasToString(containsString("future error"))); + }); + runExecutionExceptionTest( + () -> { + throw new IllegalStateException("future exception"); + }, + false, + o -> assertFalse(o.isPresent())); + } + + private void runExecutionExceptionTest( + final Runnable runnable, + final boolean expectThrowable, + final Consumer> consumer) throws InterruptedException { + final AtomicReference throwableReference = new AtomicReference<>(); + final Thread.UncaughtExceptionHandler uncaughtExceptionHandler = Thread.getDefaultUncaughtExceptionHandler(); + final CountDownLatch uncaughtExceptionHandlerLatch = new CountDownLatch(1); + + try { + Thread.setDefaultUncaughtExceptionHandler((t, e) -> { + assertTrue(expectThrowable); + throwableReference.set(e); + uncaughtExceptionHandlerLatch.countDown(); + }); + + final CountDownLatch supplierLatch = new CountDownLatch(1); + + threadPool.generic().submit(() -> { + try { + runnable.run(); + } finally { + supplierLatch.countDown(); + } + }); + + supplierLatch.await(); + + if (expectThrowable) { + uncaughtExceptionHandlerLatch.await(); + } + consumer.accept(Optional.ofNullable(throwableReference.get())); + } finally { + Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler); + } + } + +} diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java index 6427368c4b9..8f950c5434b 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java @@ -19,9 +19,11 @@ package org.elasticsearch.common.util.concurrent; import org.apache.lucene.util.CloseableThreadLocal; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.logging.ESLoggerFactory; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.settings.Settings; @@ -33,7 +35,12 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.FutureTask; +import java.util.concurrent.RunnableFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import java.util.function.Supplier; @@ -564,6 +571,36 @@ public final class ThreadContext implements Closeable, Writeable { ctx.restore(); whileRunning = true; in.run(); + if (in instanceof RunnableFuture) { + /* + * The wrapped runnable arose from asynchronous submission of a task to an executor. If an uncaught exception was thrown + * during the execution of this task, we need to inspect this runnable and see if it is an error that should be + * propagated to the uncaught exception handler. + */ + try { + ((RunnableFuture) in).get(); + } catch (final Exception e) { + /* + * In theory, Future#get can only throw a cancellation exception, an interrupted exception, or an execution + * exception. We want to ignore cancellation exceptions, restore the interrupt status on interrupted exceptions, and + * inspect the cause of an execution. We are going to be extra paranoid here though and completely unwrap the + * exception to ensure that there is not a buried error anywhere. We assume that a general exception has been + * handled by the executed task or the task submitter. + */ + assert e instanceof CancellationException + || e instanceof InterruptedException + || e instanceof ExecutionException : e; + final Optional maybeError = ExceptionsHelper.maybeError(e, ESLoggerFactory.getLogger(ThreadContext.class)); + if (maybeError.isPresent()) { + // throw this error where it will propagate to the uncaught exception handler + throw maybeError.get(); + } + if (e instanceof InterruptedException) { + // restore the interrupt status + Thread.currentThread().interrupt(); + } + } + } whileRunning = false; } catch (IllegalStateException ex) { if (whileRunning || threadLocal.closed.get() == false) {