diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/common/ContextPreservingActionListener.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/common/ContextPreservingActionListener.java index b4641625ee9..1e7468bfc51 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/common/ContextPreservingActionListener.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/common/ContextPreservingActionListener.java @@ -8,6 +8,8 @@ package org.elasticsearch.xpack.common; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.util.concurrent.ThreadContext; +import java.util.function.Supplier; + /** * Restores the given {@link org.elasticsearch.common.util.concurrent.ThreadContext.StoredContext} * once the listener is invoked @@ -15,27 +17,23 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; public final class ContextPreservingActionListener implements ActionListener { private final ActionListener delegate; - private final ThreadContext.StoredContext context; - private final ThreadContext threadContext; + private final Supplier context; - public ContextPreservingActionListener(ThreadContext threadContext, ThreadContext.StoredContext context, ActionListener delegate) { + public ContextPreservingActionListener(Supplier contextSupplier, ActionListener delegate) { this.delegate = delegate; - this.context = context; - this.threadContext = threadContext; + this.context = contextSupplier; } @Override public void onResponse(R r) { - try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) { - context.restore(); + try (ThreadContext.StoredContext ignore = context.get()) { delegate.onResponse(r); } } @Override public void onFailure(Exception e) { - try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) { - context.restore(); + try (ThreadContext.StoredContext ignore = context.get()) { delegate.onFailure(e); } } diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java index 4095e65a9a2..9d4e5eff09b 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java @@ -12,6 +12,7 @@ import java.util.Collections; import java.util.List; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.Action; @@ -72,12 +73,12 @@ public class InternalClient extends FilterClient { } final ThreadContext threadContext = threadPool().getThreadContext(); - final ThreadContext.StoredContext storedContext = threadContext.newStoredContext(); + final Supplier storedContext = threadContext.newRestorableContext(true); // we need to preserve the context here otherwise we execute the response with the XPack user which we can cause problems // since we expect the callback to run with the authenticated user calling the doExecute method try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { processContext(threadContext); - super.doExecute(action, request, new ContextPreservingActionListener<>(threadContext, storedContext, listener)); + super.doExecute(action, request, new ContextPreservingActionListener<>(storedContext, listener)); } } diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/SecurityContext.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/SecurityContext.java index dc06f0fcd3e..7ca4f825dd9 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/SecurityContext.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/SecurityContext.java @@ -90,7 +90,7 @@ public class SecurityContext { * returns, the original context is restored. */ public void executeAsUser(User user, Consumer consumer) { - final StoredContext original = threadContext.newStoredContext(); + final StoredContext original = threadContext.newStoredContext(true); try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { setUser(user); consumer.accept(original); diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java index 8a3f742ba89..8ca0a570835 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java @@ -49,6 +49,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Set; import java.util.function.Predicate; +import java.util.function.Supplier; import static org.elasticsearch.xpack.security.support.Exceptions.authorizationError; @@ -101,8 +102,8 @@ public class SecurityActionFilter extends AbstractComponent implements ActionFil if (licenseState.isAuthAllowed()) { final boolean useSystemUser = AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, action); - final ThreadContext.StoredContext toRestore = threadContext.newStoredContext(); - final ActionListener signingListener = new ContextPreservingActionListener<>(threadContext, toRestore, + final Supplier toRestore = threadContext.newRestorableContext(true); + final ActionListener signingListener = new ContextPreservingActionListener<>(toRestore, ActionListener.wrap(r -> { try { listener.onResponse(sign(r)); @@ -122,7 +123,7 @@ public class SecurityActionFilter extends AbstractComponent implements ActionFil } }); } else { - try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) { + try (ThreadContext.StoredContext ignore = threadContext.newStoredContext(true)) { applyInternal(action, request, authenticatedListener); } } diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java index 9b85e385a31..2ec8b6015b0 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java @@ -92,14 +92,16 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor // which means that the user is copied over to system actions so we need to change the user if (AuthorizationUtils.shouldReplaceUserWithSystem(threadPool.getThreadContext(), action)) { securityContext.executeAsUser(SystemUser.INSTANCE, (original) -> sendWithUser(connection, action, request, options, - new ContextRestoreResponseHandler<>(threadPool.getThreadContext(), original, handler), sender)); + new TransportService.ContextRestoreResponseHandler<>(threadPool.getThreadContext().wrapRestorable(original) + , handler), sender)); } else if (reservedRealmEnabled && connection.getVersion().before(Version.V_5_2_0_UNRELEASED) && KibanaUser.NAME.equals(securityContext.getUser().principal())) { final User kibanaUser = securityContext.getUser(); final User bwcKibanaUser = new User(kibanaUser.principal(), new String[] { "kibana" }, kibanaUser.fullName(), kibanaUser.email(), kibanaUser.metadata(), kibanaUser.enabled()); securityContext.executeAsUser(bwcKibanaUser, (original) -> sendWithUser(connection, action, request, options, - new ContextRestoreResponseHandler<>(threadPool.getThreadContext(), original, handler), sender)); + new TransportService.ContextRestoreResponseHandler<>(threadPool.getThreadContext().wrapRestorable(original), + handler), sender)); } else { sendWithUser(connection, action, request, options, handler, sender); } @@ -212,7 +214,7 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor RequestContext.removeCurrent(); } }; - try (ThreadContext.StoredContext ctx = threadContext.newStoredContext()) { + try (ThreadContext.StoredContext ctx = threadContext.newStoredContext(true)) { if (licenseState.isAuthAllowed()) { String profile = channel.getProfileName(); ServerTransportFilter filter = profileFilters.get(profile); @@ -265,56 +267,4 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor throw new UnsupportedOperationException("task parameter is required for this operation"); } } - - /** - * This handler wrapper ensures that the response thread executes with the correct thread context. Before any of the handle methods - * are invoked we restore the context. - */ - static final class ContextRestoreResponseHandler implements TransportResponseHandler { - - private final TransportResponseHandler delegate; - private final ThreadContext.StoredContext context; - private final ThreadContext threadContext; - - // pkg private for testing - ContextRestoreResponseHandler(ThreadContext threadContext, ThreadContext.StoredContext context, - TransportResponseHandler delegate) { - this.delegate = delegate; - this.context = context; - this.threadContext = threadContext; - } - - @Override - public T newInstance() { - return delegate.newInstance(); - } - - @Override - public void handleResponse(T response) { - try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) { - context.restore(); - delegate.handleResponse(response); - } - } - - @Override - public void handleException(TransportException exp) { - try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) { - context.restore(); - delegate.handleException(exp); - } - } - - @Override - public String executor() { - return delegate.executor(); - } - - @Override - public String toString() { - return getClass().getName() + "/" + delegate.toString(); - } - - } - } diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/ServerTransportFilter.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/ServerTransportFilter.java index ee5307cbb16..651c0e32b58 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/ServerTransportFilter.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/ServerTransportFilter.java @@ -30,7 +30,6 @@ import org.elasticsearch.xpack.security.authc.AuthenticationService; import org.elasticsearch.xpack.security.authc.pki.PkiRealm; import org.elasticsearch.xpack.security.authz.AuthorizationService; import org.elasticsearch.xpack.security.authz.AuthorizationUtils; -import org.elasticsearch.xpack.security.transport.SecurityServerTransportInterceptor.ContextRestoreResponseHandler; import org.elasticsearch.xpack.security.user.KibanaUser; import org.elasticsearch.xpack.security.user.User; diff --git a/elasticsearch/src/test/java/org/elasticsearch/xpack/common/ContextPreservingActionListenerTests.java b/elasticsearch/src/test/java/org/elasticsearch/xpack/common/ContextPreservingActionListenerTests.java index c0a79121547..5fa6325f349 100644 --- a/elasticsearch/src/test/java/org/elasticsearch/xpack/common/ContextPreservingActionListenerTests.java +++ b/elasticsearch/src/test/java/org/elasticsearch/xpack/common/ContextPreservingActionListenerTests.java @@ -23,7 +23,7 @@ public class ContextPreservingActionListenerTests extends ESTestCase { ContextPreservingActionListener actionListener; try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { threadContext.putHeader("foo", "bar"); - actionListener = new ContextPreservingActionListener<>(threadContext, threadContext.newStoredContext(), + actionListener = new ContextPreservingActionListener<>(threadContext.newRestorableContext(true), new ActionListener() { @Override public void onResponse(Void aVoid) { @@ -57,7 +57,7 @@ public class ContextPreservingActionListenerTests extends ESTestCase { ContextPreservingActionListener actionListener; try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { threadContext.putHeader("foo", "bar"); - actionListener = new ContextPreservingActionListener<>(threadContext, threadContext.newStoredContext(), + actionListener = new ContextPreservingActionListener<>(threadContext.newRestorableContext(true), new ActionListener() { @Override public void onResponse(Void aVoid) { @@ -91,7 +91,7 @@ public class ContextPreservingActionListenerTests extends ESTestCase { ContextPreservingActionListener actionListener; try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { threadContext.putHeader("foo", "bar"); - actionListener = new ContextPreservingActionListener<>(threadContext, threadContext.newStoredContext(), + actionListener = new ContextPreservingActionListener<>(threadContext.newRestorableContext(true), new ActionListener() { @Override public void onResponse(Void aVoid) { diff --git a/elasticsearch/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java b/elasticsearch/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java index 316a0e84cb2..247e0bc3187 100644 --- a/elasticsearch/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java +++ b/elasticsearch/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportResponse.Empty; import org.elasticsearch.transport.TransportResponseHandler; +import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.XPackSettings; import org.elasticsearch.xpack.security.SecurityContext; import org.elasticsearch.xpack.security.authc.Authentication; @@ -29,7 +30,6 @@ import org.elasticsearch.xpack.security.authc.Authentication.RealmRef; import org.elasticsearch.xpack.security.authc.AuthenticationService; import org.elasticsearch.xpack.security.authz.AuthorizationService; import org.elasticsearch.xpack.security.crypto.CryptoService; -import org.elasticsearch.xpack.security.transport.SecurityServerTransportInterceptor.ContextRestoreResponseHandler; import org.elasticsearch.xpack.security.user.KibanaUser; import org.elasticsearch.xpack.security.user.SystemUser; import org.elasticsearch.xpack.security.user.User; @@ -254,8 +254,8 @@ public class SecurityServerTransportInterceptorTests extends ESTestCase { try (ThreadContext.StoredContext storedContext = threadContext.stashContext()) { threadContext.putTransient("foo", "different_bar"); threadContext.putHeader("key", "value2"); - TransportResponseHandler handler = new ContextRestoreResponseHandler<>(threadContext, storedContext, - new TransportResponseHandler() { + TransportResponseHandler handler = new TransportService.ContextRestoreResponseHandler<>( + threadContext.wrapRestorable(storedContext), new TransportResponseHandler() { @Override public Empty newInstance() { @@ -293,8 +293,7 @@ public class SecurityServerTransportInterceptorTests extends ESTestCase { try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { threadContext.putTransient("foo", "different_bar"); threadContext.putHeader("key", "value2"); - handler = new ContextRestoreResponseHandler<>(threadContext, - threadContext.newStoredContext(), + handler = new TransportService.ContextRestoreResponseHandler<>(threadContext.newRestorableContext(true), new TransportResponseHandler() { @Override