Restore the original ThreadContext after a preserved context is restored
This change adds the restoration of the original context inside the listeners and handlers where we restore another context. This prevents us from polluting the context of the thread that called the listener and leaving around a different user in the thread context. Original commit: elastic/x-pack-elasticsearch@0f30363ef7
This commit is contained in:
parent
743458705a
commit
65db63cac4
|
@ -16,21 +16,27 @@ public final class ContextPreservingActionListener<R> implements ActionListener<
|
||||||
|
|
||||||
private final ActionListener<R> delegate;
|
private final ActionListener<R> delegate;
|
||||||
private final ThreadContext.StoredContext context;
|
private final ThreadContext.StoredContext context;
|
||||||
|
private final ThreadContext threadContext;
|
||||||
|
|
||||||
public ContextPreservingActionListener(ThreadContext.StoredContext context, ActionListener<R> delegate) {
|
public ContextPreservingActionListener(ThreadContext threadContext, ThreadContext.StoredContext context, ActionListener<R> delegate) {
|
||||||
this.delegate = delegate;
|
this.delegate = delegate;
|
||||||
this.context = context;
|
this.context = context;
|
||||||
|
this.threadContext = threadContext;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onResponse(R r) {
|
public void onResponse(R r) {
|
||||||
context.restore();
|
try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) {
|
||||||
delegate.onResponse(r);
|
context.restore();
|
||||||
|
delegate.onResponse(r);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onFailure(Exception e) {
|
public void onFailure(Exception e) {
|
||||||
context.restore();
|
try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) {
|
||||||
delegate.onFailure(e);
|
context.restore();
|
||||||
|
delegate.onFailure(e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,10 +74,10 @@ public class InternalClient extends FilterClient {
|
||||||
final ThreadContext threadContext = threadPool().getThreadContext();
|
final ThreadContext threadContext = threadPool().getThreadContext();
|
||||||
final ThreadContext.StoredContext storedContext = threadContext.newStoredContext();
|
final ThreadContext.StoredContext storedContext = threadContext.newStoredContext();
|
||||||
// we need to preserve the context here otherwise we execute the response with the XPack user which we can cause problems
|
// we need to preserve the context here otherwise we execute the response with the XPack user which we can cause problems
|
||||||
// since we expect the callback to run wiht the authenticated user calling the doExecute method
|
// since we expect the callback to run with the authenticated user calling the doExecute method
|
||||||
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
|
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
|
||||||
processContext(threadContext);
|
processContext(threadContext);
|
||||||
super.doExecute(action, request, new ContextPreservingActionListener<>(storedContext, listener));
|
super.doExecute(action, request, new ContextPreservingActionListener<>(threadContext, storedContext, listener));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -112,18 +112,20 @@ public class SecurityActionFilter extends AbstractComponent implements ActionFil
|
||||||
|
|
||||||
// only restore the context if it is not empty. This is needed because sometimes a response is sent to the user
|
// only restore the context if it is not empty. This is needed because sometimes a response is sent to the user
|
||||||
// and then a cleanup action is executed (like for search without a scroll)
|
// and then a cleanup action is executed (like for search without a scroll)
|
||||||
final boolean restoreOriginalContext = securityContext.getAuthentication() != null;
|
final ThreadContext.StoredContext originalContext = threadContext.newStoredContext();
|
||||||
final boolean useSystemUser = AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, action);
|
final boolean useSystemUser = AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, action);
|
||||||
// we should always restore the original here because we forcefully changed to the system user
|
// we should always restore the original here because we forcefully changed to the system user
|
||||||
final ThreadContext.StoredContext toRestore = restoreOriginalContext || useSystemUser ? threadContext.newStoredContext() : () -> {};
|
final ThreadContext.StoredContext toRestore = useSystemUser ? originalContext : () -> {};
|
||||||
final ActionListener<ActionResponse> signingListener = new ContextPreservingActionListener<>(toRestore, ActionListener.wrap(r -> {
|
final ActionListener<ActionResponse> signingListener =
|
||||||
|
new ContextPreservingActionListener<>(threadContext, toRestore, ActionListener.wrap(r -> {
|
||||||
try {
|
try {
|
||||||
listener.onResponse(sign(r));
|
listener.onResponse(sign(r));
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new UncheckedIOException(e);
|
throw new UncheckedIOException(e);
|
||||||
}
|
}
|
||||||
}, listener::onFailure));
|
}, listener::onFailure));
|
||||||
ActionListener<Void> authenticatedListener = new ActionListener<Void>() {
|
ActionListener<Void> authenticatedListener = new ContextPreservingActionListener<>(threadContext, toRestore,
|
||||||
|
new ActionListener<Void>() {
|
||||||
@Override
|
@Override
|
||||||
public void onResponse(Void aVoid) {
|
public void onResponse(Void aVoid) {
|
||||||
chain.proceed(task, action, request, signingListener);
|
chain.proceed(task, action, request, signingListener);
|
||||||
|
@ -132,7 +134,7 @@ public class SecurityActionFilter extends AbstractComponent implements ActionFil
|
||||||
public void onFailure(Exception e) {
|
public void onFailure(Exception e) {
|
||||||
signingListener.onFailure(e);
|
signingListener.onFailure(e);
|
||||||
}
|
}
|
||||||
};
|
});
|
||||||
try {
|
try {
|
||||||
if (useSystemUser) {
|
if (useSystemUser) {
|
||||||
securityContext.executeAsUser(SystemUser.INSTANCE, (original) -> {
|
securityContext.executeAsUser(SystemUser.INSTANCE, (original) -> {
|
||||||
|
|
|
@ -85,7 +85,7 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor
|
||||||
// which means that the user is copied over to system actions so we need to change the user
|
// which means that the user is copied over to system actions so we need to change the user
|
||||||
if (AuthorizationUtils.shouldReplaceUserWithSystem(threadPool.getThreadContext(), action)) {
|
if (AuthorizationUtils.shouldReplaceUserWithSystem(threadPool.getThreadContext(), action)) {
|
||||||
securityContext.executeAsUser(SystemUser.INSTANCE, (original) -> sendWithUser(node, action, request, options,
|
securityContext.executeAsUser(SystemUser.INSTANCE, (original) -> sendWithUser(node, action, request, options,
|
||||||
new ContextRestoreResponseHandler<>(original, handler), sender));
|
new ContextRestoreResponseHandler<>(threadPool.getThreadContext(), original, handler), sender));
|
||||||
} else {
|
} else {
|
||||||
sendWithUser(node, action, request, options, handler, sender);
|
sendWithUser(node, action, request, options, handler, sender);
|
||||||
}
|
}
|
||||||
|
@ -260,11 +260,14 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor
|
||||||
*/
|
*/
|
||||||
static final class ContextRestoreResponseHandler<T extends TransportResponse> implements TransportResponseHandler<T> {
|
static final class ContextRestoreResponseHandler<T extends TransportResponse> implements TransportResponseHandler<T> {
|
||||||
private final TransportResponseHandler<T> delegate;
|
private final TransportResponseHandler<T> delegate;
|
||||||
private final ThreadContext.StoredContext threadContext;
|
private final ThreadContext.StoredContext context;
|
||||||
|
private final ThreadContext threadContext;
|
||||||
|
|
||||||
// pkg private for testing
|
// pkg private for testing
|
||||||
ContextRestoreResponseHandler(ThreadContext.StoredContext threadContext, TransportResponseHandler<T> delegate) {
|
ContextRestoreResponseHandler(ThreadContext threadContext, ThreadContext.StoredContext context,
|
||||||
|
TransportResponseHandler<T> delegate) {
|
||||||
this.delegate = delegate;
|
this.delegate = delegate;
|
||||||
|
this.context = context;
|
||||||
this.threadContext = threadContext;
|
this.threadContext = threadContext;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -275,14 +278,18 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void handleResponse(T response) {
|
public void handleResponse(T response) {
|
||||||
threadContext.restore();
|
try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) {
|
||||||
delegate.handleResponse(response);
|
context.restore();
|
||||||
|
delegate.handleResponse(response);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void handleException(TransportException exp) {
|
public void handleException(TransportException exp) {
|
||||||
threadContext.restore();
|
try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) {
|
||||||
delegate.handleException(exp);
|
context.restore();
|
||||||
|
delegate.handleException(exp);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -0,0 +1,128 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.xpack.common;
|
||||||
|
|
||||||
|
import org.elasticsearch.action.ActionListener;
|
||||||
|
import org.elasticsearch.common.settings.Settings;
|
||||||
|
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
public class ContextPreservingActionListenerTests extends ESTestCase {
|
||||||
|
|
||||||
|
public void testOriginalContextIsPreservedAfterOnResponse() throws IOException {
|
||||||
|
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
|
||||||
|
final boolean nonEmptyContext = randomBoolean();
|
||||||
|
if (nonEmptyContext) {
|
||||||
|
threadContext.putHeader("not empty", "value");
|
||||||
|
}
|
||||||
|
ContextPreservingActionListener<Void> actionListener;
|
||||||
|
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
|
||||||
|
threadContext.putHeader("foo", "bar");
|
||||||
|
actionListener = new ContextPreservingActionListener<>(threadContext, threadContext.newStoredContext(),
|
||||||
|
new ActionListener<Void>() {
|
||||||
|
@Override
|
||||||
|
public void onResponse(Void aVoid) {
|
||||||
|
assertEquals("bar", threadContext.getHeader("foo"));
|
||||||
|
assertNull(threadContext.getHeader("not empty"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onFailure(Exception e) {
|
||||||
|
throw new RuntimeException("onFailure shouldn't be called", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
assertNull(threadContext.getHeader("foo"));
|
||||||
|
assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty"));
|
||||||
|
|
||||||
|
actionListener.onResponse(null);
|
||||||
|
|
||||||
|
assertNull(threadContext.getHeader("foo"));
|
||||||
|
assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testOriginalContextIsPreservedAfterOnFailure() throws Exception {
|
||||||
|
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
|
||||||
|
final boolean nonEmptyContext = randomBoolean();
|
||||||
|
if (nonEmptyContext) {
|
||||||
|
threadContext.putHeader("not empty", "value");
|
||||||
|
}
|
||||||
|
ContextPreservingActionListener<Void> actionListener;
|
||||||
|
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
|
||||||
|
threadContext.putHeader("foo", "bar");
|
||||||
|
actionListener = new ContextPreservingActionListener<>(threadContext, threadContext.newStoredContext(),
|
||||||
|
new ActionListener<Void>() {
|
||||||
|
@Override
|
||||||
|
public void onResponse(Void aVoid) {
|
||||||
|
throw new RuntimeException("onResponse shouldn't be called");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onFailure(Exception e) {
|
||||||
|
assertEquals("bar", threadContext.getHeader("foo"));
|
||||||
|
assertNull(threadContext.getHeader("not empty"));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
assertNull(threadContext.getHeader("foo"));
|
||||||
|
assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty"));
|
||||||
|
|
||||||
|
actionListener.onFailure(null);
|
||||||
|
|
||||||
|
assertNull(threadContext.getHeader("foo"));
|
||||||
|
assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testOriginalContextIsWhenListenerThrows() throws Exception {
|
||||||
|
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
|
||||||
|
final boolean nonEmptyContext = randomBoolean();
|
||||||
|
if (nonEmptyContext) {
|
||||||
|
threadContext.putHeader("not empty", "value");
|
||||||
|
}
|
||||||
|
ContextPreservingActionListener<Void> actionListener;
|
||||||
|
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
|
||||||
|
threadContext.putHeader("foo", "bar");
|
||||||
|
actionListener = new ContextPreservingActionListener<>(threadContext, threadContext.newStoredContext(),
|
||||||
|
new ActionListener<Void>() {
|
||||||
|
@Override
|
||||||
|
public void onResponse(Void aVoid) {
|
||||||
|
assertEquals("bar", threadContext.getHeader("foo"));
|
||||||
|
assertNull(threadContext.getHeader("not empty"));
|
||||||
|
throw new RuntimeException("onResponse called");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onFailure(Exception e) {
|
||||||
|
assertEquals("bar", threadContext.getHeader("foo"));
|
||||||
|
assertNull(threadContext.getHeader("not empty"));
|
||||||
|
throw new RuntimeException("onFailure called");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
assertNull(threadContext.getHeader("foo"));
|
||||||
|
assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty"));
|
||||||
|
|
||||||
|
RuntimeException e = expectThrows(RuntimeException.class, () -> actionListener.onResponse(null));
|
||||||
|
assertEquals("onResponse called", e.getMessage());
|
||||||
|
|
||||||
|
assertNull(threadContext.getHeader("foo"));
|
||||||
|
assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty"));
|
||||||
|
|
||||||
|
e = expectThrows(RuntimeException.class, () -> actionListener.onFailure(null));
|
||||||
|
assertEquals("onFailure called", e.getMessage());
|
||||||
|
|
||||||
|
assertNull(threadContext.getHeader("foo"));
|
||||||
|
assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -185,7 +185,7 @@ public class SecurityServerTransportInterceptorTests extends ESTestCase {
|
||||||
try (ThreadContext.StoredContext storedContext = threadContext.stashContext()) {
|
try (ThreadContext.StoredContext storedContext = threadContext.stashContext()) {
|
||||||
threadContext.putTransient("foo", "different_bar");
|
threadContext.putTransient("foo", "different_bar");
|
||||||
threadContext.putHeader("key", "value2");
|
threadContext.putHeader("key", "value2");
|
||||||
TransportResponseHandler<Empty> handler = new ContextRestoreResponseHandler<>(storedContext,
|
TransportResponseHandler<Empty> handler = new ContextRestoreResponseHandler<>(threadContext, storedContext,
|
||||||
new TransportResponseHandler<Empty>() {
|
new TransportResponseHandler<Empty>() {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -215,4 +215,53 @@ public class SecurityServerTransportInterceptorTests extends ESTestCase {
|
||||||
handler.handleException(null);
|
handler.handleException(null);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testContextRestoreResponseHandlerRestoreOriginalContext() throws Exception {
|
||||||
|
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
|
||||||
|
threadContext.putTransient("foo", "bar");
|
||||||
|
threadContext.putHeader("key", "value");
|
||||||
|
TransportResponseHandler<Empty> handler;
|
||||||
|
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
|
||||||
|
threadContext.putTransient("foo", "different_bar");
|
||||||
|
threadContext.putHeader("key", "value2");
|
||||||
|
handler = new ContextRestoreResponseHandler<>(threadContext,
|
||||||
|
threadContext.newStoredContext(),
|
||||||
|
new TransportResponseHandler<Empty>() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Empty newInstance() {
|
||||||
|
return Empty.INSTANCE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handleResponse(Empty response) {
|
||||||
|
assertEquals("different_bar", threadContext.getTransient("foo"));
|
||||||
|
assertEquals("value2", threadContext.getHeader("key"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handleException(TransportException exp) {
|
||||||
|
assertEquals("different_bar", threadContext.getTransient("foo"));
|
||||||
|
assertEquals("value2", threadContext.getHeader("key"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String executor() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals("bar", threadContext.getTransient("foo"));
|
||||||
|
assertEquals("value", threadContext.getHeader("key"));
|
||||||
|
handler.handleResponse(null);
|
||||||
|
|
||||||
|
assertEquals("bar", threadContext.getTransient("foo"));
|
||||||
|
assertEquals("value", threadContext.getHeader("key"));
|
||||||
|
handler.handleException(null);
|
||||||
|
|
||||||
|
assertEquals("bar", threadContext.getTransient("foo"));
|
||||||
|
assertEquals("value", threadContext.getHeader("key"));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue