diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/common/ContextPreservingActionListener.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/common/ContextPreservingActionListener.java new file mode 100644 index 00000000000..9cf1bb8292f --- /dev/null +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/common/ContextPreservingActionListener.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.common; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.concurrent.ThreadContext; + +/** + * Restores the given {@link org.elasticsearch.common.util.concurrent.ThreadContext.StoredContext} + * once the listener is invoked + */ +public final class ContextPreservingActionListener implements ActionListener { + + private final ActionListener delegate; + private final ThreadContext.StoredContext context; + + public ContextPreservingActionListener(ThreadContext.StoredContext context, ActionListener delegate) { + this.delegate = delegate; + this.context = context; + } + + @Override + public void onResponse(R r) { + context.restore(); + delegate.onResponse(r); + } + + @Override + public void onFailure(Exception e) { + context.restore(); + delegate.onFailure(e); + } +} diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java index c3f4865871b..9eebfab525a 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.node.Node; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.common.ContextPreservingActionListener; import org.elasticsearch.xpack.security.authc.Authentication; import org.elasticsearch.xpack.security.authc.AuthenticationService; import org.elasticsearch.xpack.security.crypto.CryptoService; @@ -57,15 +58,23 @@ public class InternalClient extends FilterClient { return; } - try (ThreadContext.StoredContext ctx = threadPool().getThreadContext().stashContext()) { - try { - Authentication authentication = new Authentication(XPackUser.INSTANCE, + final ThreadContext threadContext = threadPool().getThreadContext(); + final ThreadContext.StoredContext storedContext = threadContext.newStoredContext(); + // we need to preserve the context here otherwise we execute the response with the XPack user which we can cause problems + // since we expect the callback to run wiht the authenticated user calling the doExecute method + try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { + processContext(threadContext); + super.doExecute(action, request, new ContextPreservingActionListener<>(storedContext, listener)); + } + } + + protected void processContext(ThreadContext threadContext) { + try { + Authentication authentication = new Authentication(XPackUser.INSTANCE, new Authentication.RealmRef("__attach", "__attach", nodeName), null); - authentication.writeToContext(threadPool().getThreadContext(), cryptoService, signUserHeader); - } catch (IOException ioe) { - throw new ElasticsearchException("failed to attach internal user to request", ioe); - } - super.doExecute(action, request, listener); + authentication.writeToContext(threadContext, cryptoService, signUserHeader); + } catch (IOException ioe) { + throw new ElasticsearchException("failed to attach internal user to request", ioe); } } } diff --git a/elasticsearch/src/test/java/org/elasticsearch/xpack/security/InternalClientTests.java b/elasticsearch/src/test/java/org/elasticsearch/xpack/security/InternalClientTests.java new file mode 100644 index 00000000000..dcae61a7742 --- /dev/null +++ b/elasticsearch/src/test/java/org/elasticsearch/xpack/security/InternalClientTests.java @@ -0,0 +1,158 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.security; + +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.client.FilterClient; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.env.Environment; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.yaml.section.Assertion; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.security.crypto.CryptoService; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +public class InternalClientTests extends ESTestCase { + private ThreadPool threadPool; + + @Override + public void setUp() throws Exception { + super.setUp(); + threadPool = new TestThreadPool(InternalClientTests.class.getName()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + } + + public void testContextIsPreserved() throws IOException, InterruptedException { + FilterClient dummy = new FilterClient(Settings.EMPTY, threadPool, null) { + @Override + protected , Response extends ActionResponse, RequestBuilder extends + ActionRequestBuilder> void doExecute(Action + action, Request request, + ActionListener listener) { + threadPool.executor(ThreadPool.Names.GENERIC).execute(() -> listener.onResponse(null)); + } + }; + + Path tempDir = createTempDir(); + InternalClient client = new InternalClient(Settings.EMPTY, threadPool, dummy, new CryptoService(Settings.EMPTY, + new Environment(Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), tempDir.toString()).build()))) { + @Override + protected void processContext(ThreadContext threadContext) { + threadContext.putTransient("foo", "boom"); + } + }; + try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) { + threadPool.getThreadContext().putTransient("foo", "bar"); + client.prepareSearch("boom").get(); + assertEquals("bar", threadPool.getThreadContext().getTransient("foo")); + } + + try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) { + threadPool.getThreadContext().putTransient("foo", "bar"); + CountDownLatch latch = new CountDownLatch(1); + client.prepareSearch("boom").execute(new ActionListener() { + @Override + public void onResponse(SearchResponse searchResponse) { + try { + assertEquals("bar", threadPool.getThreadContext().getTransient("foo")); + } finally { + latch.countDown(); + } + } + + @Override + public void onFailure(Exception e) { + try { + throw new AssertionError(e); + } finally { + latch.countDown(); + } + + } + }); + latch.await(); + assertEquals("bar", threadPool.getThreadContext().getTransient("foo")); + } + } + + public void testContextIsPreservedOnError() throws IOException, InterruptedException { + FilterClient dummy = new FilterClient(Settings.EMPTY, threadPool, null) { + @Override + protected , Response extends ActionResponse, RequestBuilder extends + ActionRequestBuilder> void doExecute(Action + action, Request request, + ActionListener listener) { + threadPool.executor(ThreadPool.Names.GENERIC).execute(() -> listener.onFailure(new Exception("boom bam bang"))); + } + }; + + Path tempDir = createTempDir(); + InternalClient client = new InternalClient(Settings.EMPTY, threadPool, dummy, new CryptoService(Settings.EMPTY, + new Environment(Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), tempDir.toString()).build()))) { + @Override + protected void processContext(ThreadContext threadContext) { + threadContext.putTransient("foo", "boom"); + } + }; + try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) { + threadPool.getThreadContext().putTransient("foo", "bar"); + try { + client.prepareSearch("boom").get(); + } catch (Exception ex) { + assertEquals("boom bam bang", ex.getCause().getCause().getMessage()); + + } + assertEquals("bar", threadPool.getThreadContext().getTransient("foo")); + } + + try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) { + threadPool.getThreadContext().putTransient("foo", "bar"); + CountDownLatch latch = new CountDownLatch(1); + client.prepareSearch("boom").execute(new ActionListener() { + @Override + public void onResponse(SearchResponse searchResponse) { + try { + throw new AssertionError("exception expected"); + } finally { + latch.countDown(); + } + + } + + @Override + public void onFailure(Exception e) { + try { + assertEquals("boom bam bang", e.getMessage()); + assertEquals("bar", threadPool.getThreadContext().getTransient("foo")); + } finally { + latch.countDown(); + } + + } + }); + latch.await(); + assertEquals("bar", threadPool.getThreadContext().getTransient("foo")); + } + } +}