Restore thread-context when executing with InternalClient (elastic/elasticsearch#3859)
Today when a request is executed with InternalClient the thread context might be lost if another component like security exchanges it by executing an async call or an internal action. This can be a serious security problem since if the async call executes as the system user all subsequent calls made by the response thread will also execute as the system user instead. Original commit: elastic/x-pack-elasticsearch@80682f338d
This commit is contained in:
parent
51b871f344
commit
f8ba7f6fd8
|
@ -0,0 +1,36 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.common;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
|
||||
/**
|
||||
* Restores the given {@link org.elasticsearch.common.util.concurrent.ThreadContext.StoredContext}
|
||||
* once the listener is invoked
|
||||
*/
|
||||
public final class ContextPreservingActionListener<R> implements ActionListener<R> {
|
||||
|
||||
private final ActionListener<R> delegate;
|
||||
private final ThreadContext.StoredContext context;
|
||||
|
||||
public ContextPreservingActionListener(ThreadContext.StoredContext context, ActionListener<R> delegate) {
|
||||
this.delegate = delegate;
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onResponse(R r) {
|
||||
context.restore();
|
||||
delegate.onResponse(r);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
context.restore();
|
||||
delegate.onFailure(e);
|
||||
}
|
||||
}
|
|
@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings;
|
|||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
import org.elasticsearch.node.Node;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.common.ContextPreservingActionListener;
|
||||
import org.elasticsearch.xpack.security.authc.Authentication;
|
||||
import org.elasticsearch.xpack.security.authc.AuthenticationService;
|
||||
import org.elasticsearch.xpack.security.crypto.CryptoService;
|
||||
|
@ -57,15 +58,23 @@ public class InternalClient extends FilterClient {
|
|||
return;
|
||||
}
|
||||
|
||||
try (ThreadContext.StoredContext ctx = threadPool().getThreadContext().stashContext()) {
|
||||
final ThreadContext threadContext = threadPool().getThreadContext();
|
||||
final ThreadContext.StoredContext storedContext = threadContext.newStoredContext();
|
||||
// we need to preserve the context here otherwise we execute the response with the XPack user which we can cause problems
|
||||
// since we expect the callback to run wiht the authenticated user calling the doExecute method
|
||||
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
|
||||
processContext(threadContext);
|
||||
super.doExecute(action, request, new ContextPreservingActionListener<>(storedContext, listener));
|
||||
}
|
||||
}
|
||||
|
||||
protected void processContext(ThreadContext threadContext) {
|
||||
try {
|
||||
Authentication authentication = new Authentication(XPackUser.INSTANCE,
|
||||
new Authentication.RealmRef("__attach", "__attach", nodeName), null);
|
||||
authentication.writeToContext(threadPool().getThreadContext(), cryptoService, signUserHeader);
|
||||
authentication.writeToContext(threadContext, cryptoService, signUserHeader);
|
||||
} catch (IOException ioe) {
|
||||
throw new ElasticsearchException("failed to attach internal user to request", ioe);
|
||||
}
|
||||
super.doExecute(action, request, listener);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.security;
|
||||
|
||||
import org.elasticsearch.ExceptionsHelper;
|
||||
import org.elasticsearch.action.Action;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.ActionRequest;
|
||||
import org.elasticsearch.action.ActionRequestBuilder;
|
||||
import org.elasticsearch.action.ActionResponse;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.client.FilterClient;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
import org.elasticsearch.env.Environment;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.test.rest.yaml.section.Assertion;
|
||||
import org.elasticsearch.threadpool.TestThreadPool;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.security.crypto.CryptoService;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
public class InternalClientTests extends ESTestCase {
|
||||
private ThreadPool threadPool;
|
||||
|
||||
@Override
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
threadPool = new TestThreadPool(InternalClientTests.class.getName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void tearDown() throws Exception {
|
||||
super.tearDown();
|
||||
ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
public void testContextIsPreserved() throws IOException, InterruptedException {
|
||||
FilterClient dummy = new FilterClient(Settings.EMPTY, threadPool, null) {
|
||||
@Override
|
||||
protected <Request extends ActionRequest<Request>, Response extends ActionResponse, RequestBuilder extends
|
||||
ActionRequestBuilder<Request, Response, RequestBuilder>> void doExecute(Action<Request, Response, RequestBuilder>
|
||||
action, Request request,
|
||||
ActionListener<Response> listener) {
|
||||
threadPool.executor(ThreadPool.Names.GENERIC).execute(() -> listener.onResponse(null));
|
||||
}
|
||||
};
|
||||
|
||||
Path tempDir = createTempDir();
|
||||
InternalClient client = new InternalClient(Settings.EMPTY, threadPool, dummy, new CryptoService(Settings.EMPTY,
|
||||
new Environment(Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), tempDir.toString()).build()))) {
|
||||
@Override
|
||||
protected void processContext(ThreadContext threadContext) {
|
||||
threadContext.putTransient("foo", "boom");
|
||||
}
|
||||
};
|
||||
try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) {
|
||||
threadPool.getThreadContext().putTransient("foo", "bar");
|
||||
client.prepareSearch("boom").get();
|
||||
assertEquals("bar", threadPool.getThreadContext().getTransient("foo"));
|
||||
}
|
||||
|
||||
try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) {
|
||||
threadPool.getThreadContext().putTransient("foo", "bar");
|
||||
CountDownLatch latch = new CountDownLatch(1);
|
||||
client.prepareSearch("boom").execute(new ActionListener<SearchResponse>() {
|
||||
@Override
|
||||
public void onResponse(SearchResponse searchResponse) {
|
||||
try {
|
||||
assertEquals("bar", threadPool.getThreadContext().getTransient("foo"));
|
||||
} finally {
|
||||
latch.countDown();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
try {
|
||||
throw new AssertionError(e);
|
||||
} finally {
|
||||
latch.countDown();
|
||||
}
|
||||
|
||||
}
|
||||
});
|
||||
latch.await();
|
||||
assertEquals("bar", threadPool.getThreadContext().getTransient("foo"));
|
||||
}
|
||||
}
|
||||
|
||||
public void testContextIsPreservedOnError() throws IOException, InterruptedException {
|
||||
FilterClient dummy = new FilterClient(Settings.EMPTY, threadPool, null) {
|
||||
@Override
|
||||
protected <Request extends ActionRequest<Request>, Response extends ActionResponse, RequestBuilder extends
|
||||
ActionRequestBuilder<Request, Response, RequestBuilder>> void doExecute(Action<Request, Response, RequestBuilder>
|
||||
action, Request request,
|
||||
ActionListener<Response> listener) {
|
||||
threadPool.executor(ThreadPool.Names.GENERIC).execute(() -> listener.onFailure(new Exception("boom bam bang")));
|
||||
}
|
||||
};
|
||||
|
||||
Path tempDir = createTempDir();
|
||||
InternalClient client = new InternalClient(Settings.EMPTY, threadPool, dummy, new CryptoService(Settings.EMPTY,
|
||||
new Environment(Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), tempDir.toString()).build()))) {
|
||||
@Override
|
||||
protected void processContext(ThreadContext threadContext) {
|
||||
threadContext.putTransient("foo", "boom");
|
||||
}
|
||||
};
|
||||
try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) {
|
||||
threadPool.getThreadContext().putTransient("foo", "bar");
|
||||
try {
|
||||
client.prepareSearch("boom").get();
|
||||
} catch (Exception ex) {
|
||||
assertEquals("boom bam bang", ex.getCause().getCause().getMessage());
|
||||
|
||||
}
|
||||
assertEquals("bar", threadPool.getThreadContext().getTransient("foo"));
|
||||
}
|
||||
|
||||
try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) {
|
||||
threadPool.getThreadContext().putTransient("foo", "bar");
|
||||
CountDownLatch latch = new CountDownLatch(1);
|
||||
client.prepareSearch("boom").execute(new ActionListener<SearchResponse>() {
|
||||
@Override
|
||||
public void onResponse(SearchResponse searchResponse) {
|
||||
try {
|
||||
throw new AssertionError("exception expected");
|
||||
} finally {
|
||||
latch.countDown();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
try {
|
||||
assertEquals("boom bam bang", e.getMessage());
|
||||
assertEquals("bar", threadPool.getThreadContext().getTransient("foo"));
|
||||
} finally {
|
||||
latch.countDown();
|
||||
}
|
||||
|
||||
}
|
||||
});
|
||||
latch.await();
|
||||
assertEquals("bar", threadPool.getThreadContext().getTransient("foo"));
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue