Restore thread's original context before returning to the ThreadPool
This commit ensures that we always restore the thread's original context after execution of a context preserving runnable. We always wrap runnables in a wrapper that restores the context at the time it was submitted to the execute method. The ContextPreservingAbstractRunnable would restore the calling context in the doRun method and then in a try with resources block would restore the thread's original context. However, the onFailure and onAfter methods of a AbstractRunnable could modify the thread context and this modified thread context would continue on as the thread's context after it was returned to the pool and potentially used for a different purpose.
This commit is contained in:
parent
b326f0bc51
commit
6ecb023468
|
@ -109,6 +109,13 @@ public class EsThreadPoolExecutor extends ThreadPoolExecutor {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void afterExecute(Runnable r, Throwable t) {
|
||||
super.afterExecute(r, t);
|
||||
assert contextHolder.isDefaultContext() : "the thread context is not the default context and the thread [" +
|
||||
Thread.currentThread().getName() + "] is being returned to the pool after executing [" + r + "]";
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a stream of all pending tasks. This is similar to {@link #getQueue()} but will expose the originally submitted
|
||||
* {@link Runnable} instances rather than potentially wrapped ones.
|
||||
|
|
|
@ -107,6 +107,7 @@ public class PrioritizedEsThreadPoolExecutor extends EsThreadPoolExecutor {
|
|||
|
||||
@Override
|
||||
protected void afterExecute(Runnable r, Throwable t) {
|
||||
super.afterExecute(r, t);
|
||||
current.remove(r);
|
||||
}
|
||||
|
||||
|
|
|
@ -246,6 +246,13 @@ public final class ThreadContext implements Closeable, Writeable {
|
|||
return command;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the current context is the default context.
|
||||
*/
|
||||
boolean isDefaultContext() {
|
||||
return threadLocal.get() == DEFAULT_CONTEXT;
|
||||
}
|
||||
|
||||
@FunctionalInterface
|
||||
public interface StoredContext extends AutoCloseable {
|
||||
@Override
|
||||
|
@ -468,10 +475,12 @@ public final class ThreadContext implements Closeable, Writeable {
|
|||
*/
|
||||
private class ContextPreservingAbstractRunnable extends AbstractRunnable {
|
||||
private final AbstractRunnable in;
|
||||
private final ThreadContext.StoredContext ctx;
|
||||
private final ThreadContext.StoredContext creatorsContext;
|
||||
|
||||
private ThreadContext.StoredContext threadsOriginalContext = null;
|
||||
|
||||
private ContextPreservingAbstractRunnable(AbstractRunnable in) {
|
||||
ctx = newStoredContext();
|
||||
creatorsContext = newStoredContext();
|
||||
this.in = in;
|
||||
}
|
||||
|
||||
|
@ -482,7 +491,13 @@ public final class ThreadContext implements Closeable, Writeable {
|
|||
|
||||
@Override
|
||||
public void onAfter() {
|
||||
in.onAfter();
|
||||
try {
|
||||
in.onAfter();
|
||||
} finally {
|
||||
if (threadsOriginalContext != null) {
|
||||
threadsOriginalContext.restore();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -498,8 +513,9 @@ public final class ThreadContext implements Closeable, Writeable {
|
|||
@Override
|
||||
protected void doRun() throws Exception {
|
||||
boolean whileRunning = false;
|
||||
try (ThreadContext.StoredContext ignore = stashContext()){
|
||||
ctx.restore();
|
||||
threadsOriginalContext = stashContext();
|
||||
try {
|
||||
creatorsContext.restore();
|
||||
whileRunning = true;
|
||||
in.doRun();
|
||||
whileRunning = false;
|
||||
|
|
|
@ -319,6 +319,9 @@ public class ThreadContextTests extends ESTestCase {
|
|||
|
||||
// But we do inside of it
|
||||
withContext.run();
|
||||
|
||||
// but not after
|
||||
assertNull(threadContext.getHeader("foo"));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -350,6 +353,177 @@ public class ThreadContextTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testPreservesThreadsOriginalContextOnRunException() throws IOException {
|
||||
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
|
||||
Runnable withContext;
|
||||
|
||||
// create a abstract runnable, add headers and transient objects and verify in the methods
|
||||
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
|
||||
threadContext.putHeader("foo", "bar");
|
||||
threadContext.putTransient("foo", "bar_transient");
|
||||
withContext = threadContext.preserveContext(new AbstractRunnable() {
|
||||
|
||||
@Override
|
||||
public void onAfter() {
|
||||
assertEquals("bar", threadContext.getHeader("foo"));
|
||||
assertEquals("bar_transient", threadContext.getTransient("foo"));
|
||||
assertNotNull(threadContext.getTransient("failure"));
|
||||
assertEquals("exception from doRun", ((RuntimeException)threadContext.getTransient("failure")).getMessage());
|
||||
assertFalse(threadContext.isDefaultContext());
|
||||
threadContext.putTransient("after", "after");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
assertEquals("exception from doRun", e.getMessage());
|
||||
assertEquals("bar", threadContext.getHeader("foo"));
|
||||
assertEquals("bar_transient", threadContext.getTransient("foo"));
|
||||
assertFalse(threadContext.isDefaultContext());
|
||||
threadContext.putTransient("failure", e);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doRun() throws Exception {
|
||||
assertEquals("bar", threadContext.getHeader("foo"));
|
||||
assertEquals("bar_transient", threadContext.getTransient("foo"));
|
||||
assertFalse(threadContext.isDefaultContext());
|
||||
throw new RuntimeException("exception from doRun");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// We don't see the header outside of the runnable
|
||||
assertNull(threadContext.getHeader("foo"));
|
||||
assertNull(threadContext.getTransient("foo"));
|
||||
assertNull(threadContext.getTransient("failure"));
|
||||
assertNull(threadContext.getTransient("after"));
|
||||
assertTrue(threadContext.isDefaultContext());
|
||||
|
||||
// But we do inside of it
|
||||
withContext.run();
|
||||
|
||||
// verify not seen after
|
||||
assertNull(threadContext.getHeader("foo"));
|
||||
assertNull(threadContext.getTransient("foo"));
|
||||
assertNull(threadContext.getTransient("failure"));
|
||||
assertNull(threadContext.getTransient("after"));
|
||||
assertTrue(threadContext.isDefaultContext());
|
||||
|
||||
// repeat with regular runnable
|
||||
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
|
||||
threadContext.putHeader("foo", "bar");
|
||||
threadContext.putTransient("foo", "bar_transient");
|
||||
withContext = threadContext.preserveContext(() -> {
|
||||
assertEquals("bar", threadContext.getHeader("foo"));
|
||||
assertEquals("bar_transient", threadContext.getTransient("foo"));
|
||||
assertFalse(threadContext.isDefaultContext());
|
||||
threadContext.putTransient("run", true);
|
||||
throw new RuntimeException("exception from run");
|
||||
});
|
||||
}
|
||||
|
||||
assertNull(threadContext.getHeader("foo"));
|
||||
assertNull(threadContext.getTransient("foo"));
|
||||
assertNull(threadContext.getTransient("run"));
|
||||
assertTrue(threadContext.isDefaultContext());
|
||||
|
||||
final Runnable runnable = withContext;
|
||||
RuntimeException e = expectThrows(RuntimeException.class, runnable::run);
|
||||
assertEquals("exception from run", e.getMessage());
|
||||
assertNull(threadContext.getHeader("foo"));
|
||||
assertNull(threadContext.getTransient("foo"));
|
||||
assertNull(threadContext.getTransient("run"));
|
||||
assertTrue(threadContext.isDefaultContext());
|
||||
}
|
||||
}
|
||||
|
||||
public void testPreservesThreadsOriginalContextOnFailureException() throws IOException {
|
||||
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
|
||||
Runnable withContext;
|
||||
|
||||
// a runnable that throws from onFailure
|
||||
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
|
||||
threadContext.putHeader("foo", "bar");
|
||||
threadContext.putTransient("foo", "bar_transient");
|
||||
withContext = threadContext.preserveContext(new AbstractRunnable() {
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
throw new RuntimeException("from onFailure", e);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doRun() throws Exception {
|
||||
assertEquals("bar", threadContext.getHeader("foo"));
|
||||
assertEquals("bar_transient", threadContext.getTransient("foo"));
|
||||
assertFalse(threadContext.isDefaultContext());
|
||||
throw new RuntimeException("from doRun");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// We don't see the header outside of the runnable
|
||||
assertNull(threadContext.getHeader("foo"));
|
||||
assertNull(threadContext.getTransient("foo"));
|
||||
assertTrue(threadContext.isDefaultContext());
|
||||
|
||||
// But we do inside of it
|
||||
RuntimeException e = expectThrows(RuntimeException.class, withContext::run);
|
||||
assertEquals("from onFailure", e.getMessage());
|
||||
assertEquals("from doRun", e.getCause().getMessage());
|
||||
|
||||
// but not after
|
||||
assertNull(threadContext.getHeader("foo"));
|
||||
assertNull(threadContext.getTransient("foo"));
|
||||
assertTrue(threadContext.isDefaultContext());
|
||||
}
|
||||
}
|
||||
|
||||
public void testPreservesThreadsOriginalContextOnAfterException() throws IOException {
|
||||
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
|
||||
Runnable withContext;
|
||||
|
||||
// a runnable that throws from onAfter
|
||||
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
|
||||
threadContext.putHeader("foo", "bar");
|
||||
threadContext.putTransient("foo", "bar_transient");
|
||||
withContext = threadContext.preserveContext(new AbstractRunnable() {
|
||||
|
||||
@Override
|
||||
public void onAfter() {
|
||||
throw new RuntimeException("from onAfter");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
throw new RuntimeException("from onFailure", e);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doRun() throws Exception {
|
||||
assertEquals("bar", threadContext.getHeader("foo"));
|
||||
assertEquals("bar_transient", threadContext.getTransient("foo"));
|
||||
assertFalse(threadContext.isDefaultContext());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// We don't see the header outside of the runnable
|
||||
assertNull(threadContext.getHeader("foo"));
|
||||
assertNull(threadContext.getTransient("foo"));
|
||||
assertTrue(threadContext.isDefaultContext());
|
||||
|
||||
// But we do inside of it
|
||||
RuntimeException e = expectThrows(RuntimeException.class, withContext::run);
|
||||
assertEquals("from onAfter", e.getMessage());
|
||||
assertNull(e.getCause());
|
||||
|
||||
// but not after
|
||||
assertNull(threadContext.getHeader("foo"));
|
||||
assertNull(threadContext.getTransient("foo"));
|
||||
assertTrue(threadContext.isDefaultContext());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sometimes wraps a Runnable in an AbstractRunnable.
|
||||
*/
|
||||
|
|
Loading…
Reference in New Issue