Restore thread-context when executing with InternalClient (elastic/elasticsearch#3859)

Today when a request is executed with InternalClient the thread context might
be lost if another component like security exchanges it by executing an async call
or an internal action. This can be a serious security problem since if the async
call executes as the system user all subsequent calls made by the response
thread will also execute as the system user instead.

Original commit: elastic/x-pack-elasticsearch@80682f338d
This commit is contained in:
Simon Willnauer 2016-10-24 14:39:00 +02:00 committed by GitHub
parent 51b871f344
commit f8ba7f6fd8
3 changed files with 211 additions and 8 deletions

View File

@ -0,0 +1,36 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.common;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.ThreadContext;
/**
* Restores the given {@link org.elasticsearch.common.util.concurrent.ThreadContext.StoredContext}
* once the listener is invoked
*/
public final class ContextPreservingActionListener<R> implements ActionListener<R> {
private final ActionListener<R> delegate;
private final ThreadContext.StoredContext context;
public ContextPreservingActionListener(ThreadContext.StoredContext context, ActionListener<R> delegate) {
this.delegate = delegate;
this.context = context;
}
@Override
public void onResponse(R r) {
context.restore();
delegate.onResponse(r);
}
@Override
public void onFailure(Exception e) {
context.restore();
delegate.onFailure(e);
}
}

View File

@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.node.Node; import org.elasticsearch.node.Node;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.common.ContextPreservingActionListener;
import org.elasticsearch.xpack.security.authc.Authentication; import org.elasticsearch.xpack.security.authc.Authentication;
import org.elasticsearch.xpack.security.authc.AuthenticationService; import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.security.crypto.CryptoService; import org.elasticsearch.xpack.security.crypto.CryptoService;
@ -57,15 +58,23 @@ public class InternalClient extends FilterClient {
return; return;
} }
try (ThreadContext.StoredContext ctx = threadPool().getThreadContext().stashContext()) { final ThreadContext threadContext = threadPool().getThreadContext();
try { final ThreadContext.StoredContext storedContext = threadContext.newStoredContext();
Authentication authentication = new Authentication(XPackUser.INSTANCE, // we need to preserve the context here otherwise we execute the response with the XPack user which we can cause problems
// since we expect the callback to run wiht the authenticated user calling the doExecute method
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
processContext(threadContext);
super.doExecute(action, request, new ContextPreservingActionListener<>(storedContext, listener));
}
}
protected void processContext(ThreadContext threadContext) {
try {
Authentication authentication = new Authentication(XPackUser.INSTANCE,
new Authentication.RealmRef("__attach", "__attach", nodeName), null); new Authentication.RealmRef("__attach", "__attach", nodeName), null);
authentication.writeToContext(threadPool().getThreadContext(), cryptoService, signUserHeader); authentication.writeToContext(threadContext, cryptoService, signUserHeader);
} catch (IOException ioe) { } catch (IOException ioe) {
throw new ElasticsearchException("failed to attach internal user to request", ioe); throw new ElasticsearchException("failed to attach internal user to request", ioe);
}
super.doExecute(action, request, listener);
} }
} }
} }

View File

@ -0,0 +1,158 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.security;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.Action;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestBuilder;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.FilterClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.Environment;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.yaml.section.Assertion;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.security.crypto.CryptoService;
import java.io.IOException;
import java.nio.file.Path;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
public class InternalClientTests extends ESTestCase {
private ThreadPool threadPool;
@Override
public void setUp() throws Exception {
super.setUp();
threadPool = new TestThreadPool(InternalClientTests.class.getName());
}
@Override
public void tearDown() throws Exception {
super.tearDown();
ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS);
}
public void testContextIsPreserved() throws IOException, InterruptedException {
FilterClient dummy = new FilterClient(Settings.EMPTY, threadPool, null) {
@Override
protected <Request extends ActionRequest<Request>, Response extends ActionResponse, RequestBuilder extends
ActionRequestBuilder<Request, Response, RequestBuilder>> void doExecute(Action<Request, Response, RequestBuilder>
action, Request request,
ActionListener<Response> listener) {
threadPool.executor(ThreadPool.Names.GENERIC).execute(() -> listener.onResponse(null));
}
};
Path tempDir = createTempDir();
InternalClient client = new InternalClient(Settings.EMPTY, threadPool, dummy, new CryptoService(Settings.EMPTY,
new Environment(Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), tempDir.toString()).build()))) {
@Override
protected void processContext(ThreadContext threadContext) {
threadContext.putTransient("foo", "boom");
}
};
try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) {
threadPool.getThreadContext().putTransient("foo", "bar");
client.prepareSearch("boom").get();
assertEquals("bar", threadPool.getThreadContext().getTransient("foo"));
}
try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) {
threadPool.getThreadContext().putTransient("foo", "bar");
CountDownLatch latch = new CountDownLatch(1);
client.prepareSearch("boom").execute(new ActionListener<SearchResponse>() {
@Override
public void onResponse(SearchResponse searchResponse) {
try {
assertEquals("bar", threadPool.getThreadContext().getTransient("foo"));
} finally {
latch.countDown();
}
}
@Override
public void onFailure(Exception e) {
try {
throw new AssertionError(e);
} finally {
latch.countDown();
}
}
});
latch.await();
assertEquals("bar", threadPool.getThreadContext().getTransient("foo"));
}
}
public void testContextIsPreservedOnError() throws IOException, InterruptedException {
FilterClient dummy = new FilterClient(Settings.EMPTY, threadPool, null) {
@Override
protected <Request extends ActionRequest<Request>, Response extends ActionResponse, RequestBuilder extends
ActionRequestBuilder<Request, Response, RequestBuilder>> void doExecute(Action<Request, Response, RequestBuilder>
action, Request request,
ActionListener<Response> listener) {
threadPool.executor(ThreadPool.Names.GENERIC).execute(() -> listener.onFailure(new Exception("boom bam bang")));
}
};
Path tempDir = createTempDir();
InternalClient client = new InternalClient(Settings.EMPTY, threadPool, dummy, new CryptoService(Settings.EMPTY,
new Environment(Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), tempDir.toString()).build()))) {
@Override
protected void processContext(ThreadContext threadContext) {
threadContext.putTransient("foo", "boom");
}
};
try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) {
threadPool.getThreadContext().putTransient("foo", "bar");
try {
client.prepareSearch("boom").get();
} catch (Exception ex) {
assertEquals("boom bam bang", ex.getCause().getCause().getMessage());
}
assertEquals("bar", threadPool.getThreadContext().getTransient("foo"));
}
try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) {
threadPool.getThreadContext().putTransient("foo", "bar");
CountDownLatch latch = new CountDownLatch(1);
client.prepareSearch("boom").execute(new ActionListener<SearchResponse>() {
@Override
public void onResponse(SearchResponse searchResponse) {
try {
throw new AssertionError("exception expected");
} finally {
latch.countDown();
}
}
@Override
public void onFailure(Exception e) {
try {
assertEquals("boom bam bang", e.getMessage());
assertEquals("bar", threadPool.getThreadContext().getTransient("foo"));
} finally {
latch.countDown();
}
}
});
latch.await();
assertEquals("bar", threadPool.getThreadContext().getTransient("foo"));
}
}
}