diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/common/ContextPreservingActionListener.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/common/ContextPreservingActionListener.java index 9cf1bb8292f..b4641625ee9 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/common/ContextPreservingActionListener.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/common/ContextPreservingActionListener.java @@ -16,21 +16,27 @@ public final class ContextPreservingActionListener implements ActionListener< private final ActionListener delegate; private final ThreadContext.StoredContext context; + private final ThreadContext threadContext; - public ContextPreservingActionListener(ThreadContext.StoredContext context, ActionListener delegate) { + public ContextPreservingActionListener(ThreadContext threadContext, ThreadContext.StoredContext context, ActionListener delegate) { this.delegate = delegate; this.context = context; + this.threadContext = threadContext; } @Override public void onResponse(R r) { - context.restore(); - delegate.onResponse(r); + try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) { + context.restore(); + delegate.onResponse(r); + } } @Override public void onFailure(Exception e) { - context.restore(); - delegate.onFailure(e); + try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) { + context.restore(); + delegate.onFailure(e); + } } } diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java index 0f630626e73..b07e20fc039 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java @@ -74,10 +74,10 @@ public class InternalClient extends FilterClient { final ThreadContext threadContext = threadPool().getThreadContext(); final ThreadContext.StoredContext storedContext = threadContext.newStoredContext(); // we need to preserve the context here otherwise we execute the response with the XPack user which we can cause problems - // since we expect the callback to run wiht the authenticated user calling the doExecute method + // since we expect the callback to run with the authenticated user calling the doExecute method try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { processContext(threadContext); - super.doExecute(action, request, new ContextPreservingActionListener<>(storedContext, listener)); + super.doExecute(action, request, new ContextPreservingActionListener<>(threadContext, storedContext, listener)); } } diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java index 0f714ec4ebc..19070056d82 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java @@ -112,18 +112,20 @@ public class SecurityActionFilter extends AbstractComponent implements ActionFil // only restore the context if it is not empty. This is needed because sometimes a response is sent to the user // and then a cleanup action is executed (like for search without a scroll) - final boolean restoreOriginalContext = securityContext.getAuthentication() != null; + final ThreadContext.StoredContext originalContext = threadContext.newStoredContext(); final boolean useSystemUser = AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, action); // we should always restore the original here because we forcefully changed to the system user - final ThreadContext.StoredContext toRestore = restoreOriginalContext || useSystemUser ? threadContext.newStoredContext() : () -> {}; - final ActionListener signingListener = new ContextPreservingActionListener<>(toRestore, ActionListener.wrap(r -> { + final ThreadContext.StoredContext toRestore = useSystemUser ? originalContext : () -> {}; + final ActionListener signingListener = + new ContextPreservingActionListener<>(threadContext, toRestore, ActionListener.wrap(r -> { try { listener.onResponse(sign(r)); } catch (IOException e) { throw new UncheckedIOException(e); } }, listener::onFailure)); - ActionListener authenticatedListener = new ActionListener() { + ActionListener authenticatedListener = new ContextPreservingActionListener<>(threadContext, toRestore, + new ActionListener() { @Override public void onResponse(Void aVoid) { chain.proceed(task, action, request, signingListener); @@ -132,7 +134,7 @@ public class SecurityActionFilter extends AbstractComponent implements ActionFil public void onFailure(Exception e) { signingListener.onFailure(e); } - }; + }); try { if (useSystemUser) { securityContext.executeAsUser(SystemUser.INSTANCE, (original) -> { diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java index 322f700434f..67f6904a963 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java @@ -85,7 +85,7 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor // which means that the user is copied over to system actions so we need to change the user if (AuthorizationUtils.shouldReplaceUserWithSystem(threadPool.getThreadContext(), action)) { securityContext.executeAsUser(SystemUser.INSTANCE, (original) -> sendWithUser(node, action, request, options, - new ContextRestoreResponseHandler<>(original, handler), sender)); + new ContextRestoreResponseHandler<>(threadPool.getThreadContext(), original, handler), sender)); } else { sendWithUser(node, action, request, options, handler, sender); } @@ -260,11 +260,14 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor */ static final class ContextRestoreResponseHandler implements TransportResponseHandler { private final TransportResponseHandler delegate; - private final ThreadContext.StoredContext threadContext; + private final ThreadContext.StoredContext context; + private final ThreadContext threadContext; // pkg private for testing - ContextRestoreResponseHandler(ThreadContext.StoredContext threadContext, TransportResponseHandler delegate) { + ContextRestoreResponseHandler(ThreadContext threadContext, ThreadContext.StoredContext context, + TransportResponseHandler delegate) { this.delegate = delegate; + this.context = context; this.threadContext = threadContext; } @@ -275,14 +278,18 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor @Override public void handleResponse(T response) { - threadContext.restore(); - delegate.handleResponse(response); + try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) { + context.restore(); + delegate.handleResponse(response); + } } @Override public void handleException(TransportException exp) { - threadContext.restore(); - delegate.handleException(exp); + try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) { + context.restore(); + delegate.handleException(exp); + } } @Override diff --git a/elasticsearch/src/test/java/org/elasticsearch/xpack/common/ContextPreservingActionListenerTests.java b/elasticsearch/src/test/java/org/elasticsearch/xpack/common/ContextPreservingActionListenerTests.java new file mode 100644 index 00000000000..c0a79121547 --- /dev/null +++ b/elasticsearch/src/test/java/org/elasticsearch/xpack/common/ContextPreservingActionListenerTests.java @@ -0,0 +1,128 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.common; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +public class ContextPreservingActionListenerTests extends ESTestCase { + + public void testOriginalContextIsPreservedAfterOnResponse() throws IOException { + try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) { + final boolean nonEmptyContext = randomBoolean(); + if (nonEmptyContext) { + threadContext.putHeader("not empty", "value"); + } + ContextPreservingActionListener actionListener; + try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + threadContext.putHeader("foo", "bar"); + actionListener = new ContextPreservingActionListener<>(threadContext, threadContext.newStoredContext(), + new ActionListener() { + @Override + public void onResponse(Void aVoid) { + assertEquals("bar", threadContext.getHeader("foo")); + assertNull(threadContext.getHeader("not empty")); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("onFailure shouldn't be called", e); + } + }); + } + + assertNull(threadContext.getHeader("foo")); + assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty")); + + actionListener.onResponse(null); + + assertNull(threadContext.getHeader("foo")); + assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty")); + } + } + + public void testOriginalContextIsPreservedAfterOnFailure() throws Exception { + try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) { + final boolean nonEmptyContext = randomBoolean(); + if (nonEmptyContext) { + threadContext.putHeader("not empty", "value"); + } + ContextPreservingActionListener actionListener; + try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + threadContext.putHeader("foo", "bar"); + actionListener = new ContextPreservingActionListener<>(threadContext, threadContext.newStoredContext(), + new ActionListener() { + @Override + public void onResponse(Void aVoid) { + throw new RuntimeException("onResponse shouldn't be called"); + } + + @Override + public void onFailure(Exception e) { + assertEquals("bar", threadContext.getHeader("foo")); + assertNull(threadContext.getHeader("not empty")); + } + }); + } + + assertNull(threadContext.getHeader("foo")); + assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty")); + + actionListener.onFailure(null); + + assertNull(threadContext.getHeader("foo")); + assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty")); + } + } + + public void testOriginalContextIsWhenListenerThrows() throws Exception { + try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) { + final boolean nonEmptyContext = randomBoolean(); + if (nonEmptyContext) { + threadContext.putHeader("not empty", "value"); + } + ContextPreservingActionListener actionListener; + try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + threadContext.putHeader("foo", "bar"); + actionListener = new ContextPreservingActionListener<>(threadContext, threadContext.newStoredContext(), + new ActionListener() { + @Override + public void onResponse(Void aVoid) { + assertEquals("bar", threadContext.getHeader("foo")); + assertNull(threadContext.getHeader("not empty")); + throw new RuntimeException("onResponse called"); + } + + @Override + public void onFailure(Exception e) { + assertEquals("bar", threadContext.getHeader("foo")); + assertNull(threadContext.getHeader("not empty")); + throw new RuntimeException("onFailure called"); + } + }); + } + + assertNull(threadContext.getHeader("foo")); + assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty")); + + RuntimeException e = expectThrows(RuntimeException.class, () -> actionListener.onResponse(null)); + assertEquals("onResponse called", e.getMessage()); + + assertNull(threadContext.getHeader("foo")); + assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty")); + + e = expectThrows(RuntimeException.class, () -> actionListener.onFailure(null)); + assertEquals("onFailure called", e.getMessage()); + + assertNull(threadContext.getHeader("foo")); + assertEquals(nonEmptyContext ? "value" : null, threadContext.getHeader("not empty")); + } + } +} diff --git a/elasticsearch/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java b/elasticsearch/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java index 290f5a84c0c..3f1fce8f1cc 100644 --- a/elasticsearch/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java +++ b/elasticsearch/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java @@ -185,7 +185,7 @@ public class SecurityServerTransportInterceptorTests extends ESTestCase { try (ThreadContext.StoredContext storedContext = threadContext.stashContext()) { threadContext.putTransient("foo", "different_bar"); threadContext.putHeader("key", "value2"); - TransportResponseHandler handler = new ContextRestoreResponseHandler<>(storedContext, + TransportResponseHandler handler = new ContextRestoreResponseHandler<>(threadContext, storedContext, new TransportResponseHandler() { @Override @@ -215,4 +215,53 @@ public class SecurityServerTransportInterceptorTests extends ESTestCase { handler.handleException(null); } } + + public void testContextRestoreResponseHandlerRestoreOriginalContext() throws Exception { + try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) { + threadContext.putTransient("foo", "bar"); + threadContext.putHeader("key", "value"); + TransportResponseHandler handler; + try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + threadContext.putTransient("foo", "different_bar"); + threadContext.putHeader("key", "value2"); + handler = new ContextRestoreResponseHandler<>(threadContext, + threadContext.newStoredContext(), + new TransportResponseHandler() { + + @Override + public Empty newInstance() { + return Empty.INSTANCE; + } + + @Override + public void handleResponse(Empty response) { + assertEquals("different_bar", threadContext.getTransient("foo")); + assertEquals("value2", threadContext.getHeader("key")); + } + + @Override + public void handleException(TransportException exp) { + assertEquals("different_bar", threadContext.getTransient("foo")); + assertEquals("value2", threadContext.getHeader("key")); + } + + @Override + public String executor() { + return null; + } + }); + } + + assertEquals("bar", threadContext.getTransient("foo")); + assertEquals("value", threadContext.getHeader("key")); + handler.handleResponse(null); + + assertEquals("bar", threadContext.getTransient("foo")); + assertEquals("value", threadContext.getHeader("key")); + handler.handleException(null); + + assertEquals("bar", threadContext.getTransient("foo")); + assertEquals("value", threadContext.getHeader("key")); + } + } }