From f8ba7f6fd8d1b92c34e26e30ef088dd9a3c3ad07 Mon Sep 17 00:00:00 2001 From: Simon Willnauer Date: Mon, 24 Oct 2016 14:39:00 +0200 Subject: [PATCH] Restore thread-context when executing with InternalClient (elastic/elasticsearch#3859) Today when a request is executed with InternalClient the thread context might be lost if another component like security exchanges it by executing an async call or an internal action. This can be a serious security problem since if the async call executes as the system user all subsequent calls made by the response thread will also execute as the system user instead. Original commit: elastic/x-pack-elasticsearch@80682f338dde3a492b4eee82b131bac2dd66da6c --- .../ContextPreservingActionListener.java | 36 ++++ .../xpack/security/InternalClient.java | 25 ++- .../xpack/security/InternalClientTests.java | 158 ++++++++++++++++++ 3 files changed, 211 insertions(+), 8 deletions(-) create mode 100644 elasticsearch/src/main/java/org/elasticsearch/xpack/common/ContextPreservingActionListener.java create mode 100644 elasticsearch/src/test/java/org/elasticsearch/xpack/security/InternalClientTests.java diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/common/ContextPreservingActionListener.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/common/ContextPreservingActionListener.java new file mode 100644 index 00000000000..9cf1bb8292f --- /dev/null +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/common/ContextPreservingActionListener.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.common; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.concurrent.ThreadContext; + +/** + * Restores the given {@link org.elasticsearch.common.util.concurrent.ThreadContext.StoredContext} + * once the listener is invoked + */ +public final class ContextPreservingActionListener implements ActionListener { + + private final ActionListener delegate; + private final ThreadContext.StoredContext context; + + public ContextPreservingActionListener(ThreadContext.StoredContext context, ActionListener delegate) { + this.delegate = delegate; + this.context = context; + } + + @Override + public void onResponse(R r) { + context.restore(); + delegate.onResponse(r); + } + + @Override + public void onFailure(Exception e) { + context.restore(); + delegate.onFailure(e); + } +} diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java index c3f4865871b..9eebfab525a 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/InternalClient.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.node.Node; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.common.ContextPreservingActionListener; import org.elasticsearch.xpack.security.authc.Authentication; import org.elasticsearch.xpack.security.authc.AuthenticationService; import org.elasticsearch.xpack.security.crypto.CryptoService; @@ -57,15 +58,23 @@ public class InternalClient extends FilterClient { return; } - try (ThreadContext.StoredContext ctx = threadPool().getThreadContext().stashContext()) { - try { - Authentication authentication = new Authentication(XPackUser.INSTANCE, + final ThreadContext threadContext = threadPool().getThreadContext(); + final ThreadContext.StoredContext storedContext = threadContext.newStoredContext(); + // we need to preserve the context here otherwise we execute the response with the XPack user which we can cause problems + // since we expect the callback to run wiht the authenticated user calling the doExecute method + try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { + processContext(threadContext); + super.doExecute(action, request, new ContextPreservingActionListener<>(storedContext, listener)); + } + } + + protected void processContext(ThreadContext threadContext) { + try { + Authentication authentication = new Authentication(XPackUser.INSTANCE, new Authentication.RealmRef("__attach", "__attach", nodeName), null); - authentication.writeToContext(threadPool().getThreadContext(), cryptoService, signUserHeader); - } catch (IOException ioe) { - throw new ElasticsearchException("failed to attach internal user to request", ioe); - } - super.doExecute(action, request, listener); + authentication.writeToContext(threadContext, cryptoService, signUserHeader); + } catch (IOException ioe) { + throw new ElasticsearchException("failed to attach internal user to request", ioe); } } } diff --git a/elasticsearch/src/test/java/org/elasticsearch/xpack/security/InternalClientTests.java b/elasticsearch/src/test/java/org/elasticsearch/xpack/security/InternalClientTests.java new file mode 100644 index 00000000000..dcae61a7742 --- /dev/null +++ b/elasticsearch/src/test/java/org/elasticsearch/xpack/security/InternalClientTests.java @@ -0,0 +1,158 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.security; + +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.client.FilterClient; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.env.Environment; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.yaml.section.Assertion; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.security.crypto.CryptoService; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +public class InternalClientTests extends ESTestCase { + private ThreadPool threadPool; + + @Override + public void setUp() throws Exception { + super.setUp(); + threadPool = new TestThreadPool(InternalClientTests.class.getName()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + } + + public void testContextIsPreserved() throws IOException, InterruptedException { + FilterClient dummy = new FilterClient(Settings.EMPTY, threadPool, null) { + @Override + protected , Response extends ActionResponse, RequestBuilder extends + ActionRequestBuilder> void doExecute(Action + action, Request request, + ActionListener listener) { + threadPool.executor(ThreadPool.Names.GENERIC).execute(() -> listener.onResponse(null)); + } + }; + + Path tempDir = createTempDir(); + InternalClient client = new InternalClient(Settings.EMPTY, threadPool, dummy, new CryptoService(Settings.EMPTY, + new Environment(Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), tempDir.toString()).build()))) { + @Override + protected void processContext(ThreadContext threadContext) { + threadContext.putTransient("foo", "boom"); + } + }; + try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) { + threadPool.getThreadContext().putTransient("foo", "bar"); + client.prepareSearch("boom").get(); + assertEquals("bar", threadPool.getThreadContext().getTransient("foo")); + } + + try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) { + threadPool.getThreadContext().putTransient("foo", "bar"); + CountDownLatch latch = new CountDownLatch(1); + client.prepareSearch("boom").execute(new ActionListener() { + @Override + public void onResponse(SearchResponse searchResponse) { + try { + assertEquals("bar", threadPool.getThreadContext().getTransient("foo")); + } finally { + latch.countDown(); + } + } + + @Override + public void onFailure(Exception e) { + try { + throw new AssertionError(e); + } finally { + latch.countDown(); + } + + } + }); + latch.await(); + assertEquals("bar", threadPool.getThreadContext().getTransient("foo")); + } + } + + public void testContextIsPreservedOnError() throws IOException, InterruptedException { + FilterClient dummy = new FilterClient(Settings.EMPTY, threadPool, null) { + @Override + protected , Response extends ActionResponse, RequestBuilder extends + ActionRequestBuilder> void doExecute(Action + action, Request request, + ActionListener listener) { + threadPool.executor(ThreadPool.Names.GENERIC).execute(() -> listener.onFailure(new Exception("boom bam bang"))); + } + }; + + Path tempDir = createTempDir(); + InternalClient client = new InternalClient(Settings.EMPTY, threadPool, dummy, new CryptoService(Settings.EMPTY, + new Environment(Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), tempDir.toString()).build()))) { + @Override + protected void processContext(ThreadContext threadContext) { + threadContext.putTransient("foo", "boom"); + } + }; + try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) { + threadPool.getThreadContext().putTransient("foo", "bar"); + try { + client.prepareSearch("boom").get(); + } catch (Exception ex) { + assertEquals("boom bam bang", ex.getCause().getCause().getMessage()); + + } + assertEquals("bar", threadPool.getThreadContext().getTransient("foo")); + } + + try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) { + threadPool.getThreadContext().putTransient("foo", "bar"); + CountDownLatch latch = new CountDownLatch(1); + client.prepareSearch("boom").execute(new ActionListener() { + @Override + public void onResponse(SearchResponse searchResponse) { + try { + throw new AssertionError("exception expected"); + } finally { + latch.countDown(); + } + + } + + @Override + public void onFailure(Exception e) { + try { + assertEquals("boom bam bang", e.getMessage()); + assertEquals("bar", threadPool.getThreadContext().getTransient("foo")); + } finally { + latch.countDown(); + } + + } + }); + latch.await(); + assertEquals("bar", threadPool.getThreadContext().getTransient("foo")); + } + } +}