Restore the original ThreadContext after a preserved context is restored

This change adds the restoration of the original context inside the listeners and handlers where
we restore another context. This prevents us from polluting the context of the thread that called
the listener and leaving around a different user in the thread context.


Original commit: elastic/x-pack-elasticsearch@0f30363ef7
This commit is contained in:
Jay Modi 2016-11-09 16:02:43 -05:00 committed by GitHub
parent 743458705a
commit 65db63cac4
6 changed files with 212 additions and 20 deletions

View File

@ -16,21 +16,27 @@ public final class ContextPreservingActionListener<R> implements ActionListener<
private final ActionListener<R> delegate;
private final ThreadContext.StoredContext context;
private final ThreadContext threadContext;
public ContextPreservingActionListener(ThreadContext.StoredContext context, ActionListener<R> delegate) {
public ContextPreservingActionListener(ThreadContext threadContext, ThreadContext.StoredContext context, ActionListener<R> delegate) {
this.delegate = delegate;
this.context = context;
this.threadContext = threadContext;
}
@Override
public void onResponse(R r) {
context.restore();
delegate.onResponse(r);
try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) {
context.restore();
delegate.onResponse(r);
}
}
@Override
public void onFailure(Exception e) {
context.restore();
delegate.onFailure(e);
try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) {
context.restore();
delegate.onFailure(e);
}
}
}

View File

@ -74,10 +74,10 @@ public class InternalClient extends FilterClient {
final ThreadContext threadContext = threadPool().getThreadContext();
final ThreadContext.StoredContext storedContext = threadContext.newStoredContext();
// we need to preserve the context here otherwise we execute the response with the XPack user which we can cause problems
// since we expect the callback to run wiht the authenticated user calling the doExecute method
// since we expect the callback to run with the authenticated user calling the doExecute method
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
processContext(threadContext);
super.doExecute(action, request, new ContextPreservingActionListener<>(storedContext, listener));
super.doExecute(action, request, new ContextPreservingActionListener<>(threadContext, storedContext, listener));
}
}

View File

@ -112,18 +112,20 @@ public class SecurityActionFilter extends AbstractComponent implements ActionFil
// only restore the context if it is not empty. This is needed because sometimes a response is sent to the user
// and then a cleanup action is executed (like for search without a scroll)
final boolean restoreOriginalContext = securityContext.getAuthentication() != null;
final ThreadContext.StoredContext originalContext = threadContext.newStoredContext();
final boolean useSystemUser = AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, action);
// we should always restore the original here because we forcefully changed to the system user
final ThreadContext.StoredContext toRestore = restoreOriginalContext || useSystemUser ? threadContext.newStoredContext() : () -> {};
final ActionListener<ActionResponse> signingListener = new ContextPreservingActionListener<>(toRestore, ActionListener.wrap(r -> {
final ThreadContext.StoredContext toRestore = useSystemUser ? originalContext : () -> {};
final ActionListener<ActionResponse> signingListener =
new ContextPreservingActionListener<>(threadContext, toRestore, ActionListener.wrap(r -> {
try {
listener.onResponse(sign(r));
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}, listener::onFailure));
ActionListener<Void> authenticatedListener = new ActionListener<Void>() {
ActionListener<Void> authenticatedListener = new ContextPreservingActionListener<>(threadContext, toRestore,
new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
chain.proceed(task, action, request, signingListener);
@ -132,7 +134,7 @@ public class SecurityActionFilter extends AbstractComponent implements ActionFil
public void onFailure(Exception e) {
signingListener.onFailure(e);
}
};
});
try {
if (useSystemUser) {
securityContext.executeAsUser(SystemUser.INSTANCE, (original) -> {

View File

@ -85,7 +85,7 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor
// which means that the user is copied over to system actions so we need to change the user
if (AuthorizationUtils.shouldReplaceUserWithSystem(threadPool.getThreadContext(), action)) {
securityContext.executeAsUser(SystemUser.INSTANCE, (original) -> sendWithUser(node, action, request, options,
new ContextRestoreResponseHandler<>(original, handler), sender));
new ContextRestoreResponseHandler<>(threadPool.getThreadContext(), original, handler), sender));
} else {
sendWithUser(node, action, request, options, handler, sender);
}
@ -260,11 +260,14 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor
*/
static final class ContextRestoreResponseHandler<T extends TransportResponse> implements TransportResponseHandler<T> {
private final TransportResponseHandler<T> delegate;
private final ThreadContext.StoredContext threadContext;
private final ThreadContext.StoredContext context;
private final ThreadContext threadContext;
// pkg private for testing
ContextRestoreResponseHandler(ThreadContext.StoredContext threadContext, TransportResponseHandler<T> delegate) {
ContextRestoreResponseHandler(ThreadContext threadContext, ThreadContext.StoredContext context,
TransportResponseHandler<T> delegate) {
this.delegate = delegate;
this.context = context;
this.threadContext = threadContext;
}
@ -275,14 +278,18 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor
@Override
public void handleResponse(T response) {
threadContext.restore();
delegate.handleResponse(response);
try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) {
context.restore();
delegate.handleResponse(response);
}
}
@Override
public void handleException(TransportException exp) {
threadContext.restore();
delegate.handleException(exp);
try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) {
context.restore();
delegate.handleException(exp);
}
}
@Override

View File

@ -0,0 +1,128 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.common;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
public class ContextPreservingActionListenerTests extends ESTestCase {
public void testOriginalContextIsPreservedAfterOnResponse() throws IOException {
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
final boolean nonEmptyContext = randomBoolean();
if (nonEmptyContext) {
threadContext.putHeader("not empty", "value");
}
ContextPreservingActionListener<Void> actionListener;
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.putHeader("foo", "bar");
actionListener = new ContextPreservingActionListener<>(threadContext, threadContext.newStoredContext(),
new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
assertEquals("bar", threadContext.getHeader("foo"));
assertNull(threadContext.getHeader("not empty"));
}
@Override
public void onFailure(Exception e) {
throw new RuntimeException("onFailure shouldn't be called", e);
}
});
}
assertNull(threadContext.getHeader("foo"));
assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty"));
actionListener.onResponse(null);
assertNull(threadContext.getHeader("foo"));
assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty"));
}
}
public void testOriginalContextIsPreservedAfterOnFailure() throws Exception {
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
final boolean nonEmptyContext = randomBoolean();
if (nonEmptyContext) {
threadContext.putHeader("not empty", "value");
}
ContextPreservingActionListener<Void> actionListener;
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.putHeader("foo", "bar");
actionListener = new ContextPreservingActionListener<>(threadContext, threadContext.newStoredContext(),
new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
throw new RuntimeException("onResponse shouldn't be called");
}
@Override
public void onFailure(Exception e) {
assertEquals("bar", threadContext.getHeader("foo"));
assertNull(threadContext.getHeader("not empty"));
}
});
}
assertNull(threadContext.getHeader("foo"));
assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty"));
actionListener.onFailure(null);
assertNull(threadContext.getHeader("foo"));
assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty"));
}
}
public void testOriginalContextIsWhenListenerThrows() throws Exception {
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
final boolean nonEmptyContext = randomBoolean();
if (nonEmptyContext) {
threadContext.putHeader("not empty", "value");
}
ContextPreservingActionListener<Void> actionListener;
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.putHeader("foo", "bar");
actionListener = new ContextPreservingActionListener<>(threadContext, threadContext.newStoredContext(),
new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
assertEquals("bar", threadContext.getHeader("foo"));
assertNull(threadContext.getHeader("not empty"));
throw new RuntimeException("onResponse called");
}
@Override
public void onFailure(Exception e) {
assertEquals("bar", threadContext.getHeader("foo"));
assertNull(threadContext.getHeader("not empty"));
throw new RuntimeException("onFailure called");
}
});
}
assertNull(threadContext.getHeader("foo"));
assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty"));
RuntimeException e = expectThrows(RuntimeException.class, () -> actionListener.onResponse(null));
assertEquals("onResponse called", e.getMessage());
assertNull(threadContext.getHeader("foo"));
assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty"));
e = expectThrows(RuntimeException.class, () -> actionListener.onFailure(null));
assertEquals("onFailure called", e.getMessage());
assertNull(threadContext.getHeader("foo"));
assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty"));
}
}
}

View File

@ -185,7 +185,7 @@ public class SecurityServerTransportInterceptorTests extends ESTestCase {
try (ThreadContext.StoredContext storedContext = threadContext.stashContext()) {
threadContext.putTransient("foo", "different_bar");
threadContext.putHeader("key", "value2");
TransportResponseHandler<Empty> handler = new ContextRestoreResponseHandler<>(storedContext,
TransportResponseHandler<Empty> handler = new ContextRestoreResponseHandler<>(threadContext, storedContext,
new TransportResponseHandler<Empty>() {
@Override
@ -215,4 +215,53 @@ public class SecurityServerTransportInterceptorTests extends ESTestCase {
handler.handleException(null);
}
}
public void testContextRestoreResponseHandlerRestoreOriginalContext() throws Exception {
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
threadContext.putTransient("foo", "bar");
threadContext.putHeader("key", "value");
TransportResponseHandler<Empty> handler;
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.putTransient("foo", "different_bar");
threadContext.putHeader("key", "value2");
handler = new ContextRestoreResponseHandler<>(threadContext,
threadContext.newStoredContext(),
new TransportResponseHandler<Empty>() {
@Override
public Empty newInstance() {
return Empty.INSTANCE;
}
@Override
public void handleResponse(Empty response) {
assertEquals("different_bar", threadContext.getTransient("foo"));
assertEquals("value2", threadContext.getHeader("key"));
}
@Override
public void handleException(TransportException exp) {
assertEquals("different_bar", threadContext.getTransient("foo"));
assertEquals("value2", threadContext.getHeader("key"));
}
@Override
public String executor() {
return null;
}
});
}
assertEquals("bar", threadContext.getTransient("foo"));
assertEquals("value", threadContext.getHeader("key"));
handler.handleResponse(null);
assertEquals("bar", threadContext.getTransient("foo"));
assertEquals("value", threadContext.getHeader("key"));
handler.handleException(null);
assertEquals("bar", threadContext.getTransient("foo"));
assertEquals("value", threadContext.getHeader("key"));
}
}
}