Restore thread's original context before returning to the ThreadPool

This commit ensures that we always restore the thread's original context after execution of
a context preserving runnable. We always wrap runnables in a wrapper that restores the context
at the time it was submitted to the execute method. The ContextPreservingAbstractRunnable
would restore the calling context in the doRun method and then in a try with resources
block would restore the thread's original context. However, the onFailure and onAfter methods
of a AbstractRunnable could modify the thread context and this modified thread context would
continue on as the thread's context after it was returned to the pool and potentially used
for a different purpose.
This commit is contained in:
Jay Modi 2016-11-08 17:51:31 -05:00 committed by GitHub
parent b326f0bc51
commit 6ecb023468
4 changed files with 203 additions and 5 deletions

View File

@ -109,6 +109,13 @@ public class EsThreadPoolExecutor extends ThreadPoolExecutor {
}
}
@Override
protected void afterExecute(Runnable r, Throwable t) {
super.afterExecute(r, t);
assert contextHolder.isDefaultContext() : "the thread context is not the default context and the thread [" +
Thread.currentThread().getName() + "] is being returned to the pool after executing [" + r + "]";
}
/**
* Returns a stream of all pending tasks. This is similar to {@link #getQueue()} but will expose the originally submitted
* {@link Runnable} instances rather than potentially wrapped ones.

View File

@ -107,6 +107,7 @@ public class PrioritizedEsThreadPoolExecutor extends EsThreadPoolExecutor {
@Override
protected void afterExecute(Runnable r, Throwable t) {
super.afterExecute(r, t);
current.remove(r);
}

View File

@ -246,6 +246,13 @@ public final class ThreadContext implements Closeable, Writeable {
return command;
}
/**
* Returns true if the current context is the default context.
*/
boolean isDefaultContext() {
return threadLocal.get() == DEFAULT_CONTEXT;
}
@FunctionalInterface
public interface StoredContext extends AutoCloseable {
@Override
@ -468,10 +475,12 @@ public final class ThreadContext implements Closeable, Writeable {
*/
private class ContextPreservingAbstractRunnable extends AbstractRunnable {
private final AbstractRunnable in;
private final ThreadContext.StoredContext ctx;
private final ThreadContext.StoredContext creatorsContext;
private ThreadContext.StoredContext threadsOriginalContext = null;
private ContextPreservingAbstractRunnable(AbstractRunnable in) {
ctx = newStoredContext();
creatorsContext = newStoredContext();
this.in = in;
}
@ -482,7 +491,13 @@ public final class ThreadContext implements Closeable, Writeable {
@Override
public void onAfter() {
in.onAfter();
try {
in.onAfter();
} finally {
if (threadsOriginalContext != null) {
threadsOriginalContext.restore();
}
}
}
@Override
@ -498,8 +513,9 @@ public final class ThreadContext implements Closeable, Writeable {
@Override
protected void doRun() throws Exception {
boolean whileRunning = false;
try (ThreadContext.StoredContext ignore = stashContext()){
ctx.restore();
threadsOriginalContext = stashContext();
try {
creatorsContext.restore();
whileRunning = true;
in.doRun();
whileRunning = false;

View File

@ -319,6 +319,9 @@ public class ThreadContextTests extends ESTestCase {
// But we do inside of it
withContext.run();
// but not after
assertNull(threadContext.getHeader("foo"));
}
}
@ -350,6 +353,177 @@ public class ThreadContextTests extends ESTestCase {
}
}
public void testPreservesThreadsOriginalContextOnRunException() throws IOException {
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
Runnable withContext;
// create a abstract runnable, add headers and transient objects and verify in the methods
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
threadContext.putHeader("foo", "bar");
threadContext.putTransient("foo", "bar_transient");
withContext = threadContext.preserveContext(new AbstractRunnable() {
@Override
public void onAfter() {
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("bar_transient", threadContext.getTransient("foo"));
assertNotNull(threadContext.getTransient("failure"));
assertEquals("exception from doRun", ((RuntimeException)threadContext.getTransient("failure")).getMessage());
assertFalse(threadContext.isDefaultContext());
threadContext.putTransient("after", "after");
}
@Override
public void onFailure(Exception e) {
assertEquals("exception from doRun", e.getMessage());
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("bar_transient", threadContext.getTransient("foo"));
assertFalse(threadContext.isDefaultContext());
threadContext.putTransient("failure", e);
}
@Override
protected void doRun() throws Exception {
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("bar_transient", threadContext.getTransient("foo"));
assertFalse(threadContext.isDefaultContext());
throw new RuntimeException("exception from doRun");
}
});
}
// We don't see the header outside of the runnable
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertNull(threadContext.getTransient("failure"));
assertNull(threadContext.getTransient("after"));
assertTrue(threadContext.isDefaultContext());
// But we do inside of it
withContext.run();
// verify not seen after
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertNull(threadContext.getTransient("failure"));
assertNull(threadContext.getTransient("after"));
assertTrue(threadContext.isDefaultContext());
// repeat with regular runnable
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
threadContext.putHeader("foo", "bar");
threadContext.putTransient("foo", "bar_transient");
withContext = threadContext.preserveContext(() -> {
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("bar_transient", threadContext.getTransient("foo"));
assertFalse(threadContext.isDefaultContext());
threadContext.putTransient("run", true);
throw new RuntimeException("exception from run");
});
}
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertNull(threadContext.getTransient("run"));
assertTrue(threadContext.isDefaultContext());
final Runnable runnable = withContext;
RuntimeException e = expectThrows(RuntimeException.class, runnable::run);
assertEquals("exception from run", e.getMessage());
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertNull(threadContext.getTransient("run"));
assertTrue(threadContext.isDefaultContext());
}
}
public void testPreservesThreadsOriginalContextOnFailureException() throws IOException {
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
Runnable withContext;
// a runnable that throws from onFailure
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
threadContext.putHeader("foo", "bar");
threadContext.putTransient("foo", "bar_transient");
withContext = threadContext.preserveContext(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
throw new RuntimeException("from onFailure", e);
}
@Override
protected void doRun() throws Exception {
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("bar_transient", threadContext.getTransient("foo"));
assertFalse(threadContext.isDefaultContext());
throw new RuntimeException("from doRun");
}
});
}
// We don't see the header outside of the runnable
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertTrue(threadContext.isDefaultContext());
// But we do inside of it
RuntimeException e = expectThrows(RuntimeException.class, withContext::run);
assertEquals("from onFailure", e.getMessage());
assertEquals("from doRun", e.getCause().getMessage());
// but not after
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertTrue(threadContext.isDefaultContext());
}
}
public void testPreservesThreadsOriginalContextOnAfterException() throws IOException {
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
Runnable withContext;
// a runnable that throws from onAfter
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
threadContext.putHeader("foo", "bar");
threadContext.putTransient("foo", "bar_transient");
withContext = threadContext.preserveContext(new AbstractRunnable() {
@Override
public void onAfter() {
throw new RuntimeException("from onAfter");
}
@Override
public void onFailure(Exception e) {
throw new RuntimeException("from onFailure", e);
}
@Override
protected void doRun() throws Exception {
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("bar_transient", threadContext.getTransient("foo"));
assertFalse(threadContext.isDefaultContext());
}
});
}
// We don't see the header outside of the runnable
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertTrue(threadContext.isDefaultContext());
// But we do inside of it
RuntimeException e = expectThrows(RuntimeException.class, withContext::run);
assertEquals("from onAfter", e.getMessage());
assertNull(e.getCause());
// but not after
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertTrue(threadContext.isDefaultContext());
}
}
/**
* Sometimes wraps a Runnable in an AbstractRunnable.
*/