IteratingActionListener should store context before calling consumers (elastic/x-pack-elasticsearch#675)

As part of authentication, we use a iterating action listener to perform asynchronous authentication against the realm
chain. When this listener is called with a response or a failure, it could be called from a thread that is not owned by
the Elasticsearch threadpool such as a LDAPConnectionReader thread. When this happens, we need to ensure that the
ThreadContext is not left with items in it otherwise we leave behind things like Authentication and hit obscure errors.

This commit stores the context when the listener calls the consumer or onResponse/onFailure is invoked, which prevents
us from polluting a external thread's ThreadContext.

Original commit: elastic/x-pack-elasticsearch@0f50fb6c10
This commit is contained in:
Jay Modi 2017-03-01 10:40:42 -05:00 committed by GitHub
parent 61ca6d435f
commit 02579c7acc
4 changed files with 196 additions and 84 deletions

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.common;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import java.util.Collections;
import java.util.List;
@ -27,13 +28,16 @@ public final class IteratingActionListener<T, U> implements ActionListener<T>, R
private final List<U> consumables;
private final ActionListener<T> delegate;
private final BiConsumer<U, ActionListener<T>> consumer;
private final ThreadContext threadContext;
private int position = 0;
public IteratingActionListener(ActionListener<T> delegate, BiConsumer<U, ActionListener<T>> consumer, List<U> consumables) {
public IteratingActionListener(ActionListener<T> delegate, BiConsumer<U, ActionListener<T>> consumer, List<U> consumables,
ThreadContext threadContext) {
this.delegate = delegate;
this.consumer = consumer;
this.consumables = Collections.unmodifiableList(consumables);
this.threadContext = threadContext;
}
@Override
@ -43,25 +47,35 @@ public final class IteratingActionListener<T, U> implements ActionListener<T>, R
} else if (position < 0 || position >= consumables.size()) {
onFailure(new IllegalStateException("invalid position [" + position + "]. List size [" + consumables.size() + "]"));
} else {
consumer.accept(consumables.get(position++), this);
try (ThreadContext.StoredContext ignore = threadContext.newStoredContext(false)) {
consumer.accept(consumables.get(position++), this);
}
}
}
@Override
public void onResponse(T response) {
if (response == null) {
if (position == consumables.size()) {
delegate.onResponse(null);
// we need to store the context here as there is a chance that this method is called from a thread outside of the ThreadPool
// like a LDAP connection reader thread and we can pollute the context in certain cases
try (ThreadContext.StoredContext ignore = threadContext.newStoredContext(false)) {
if (response == null) {
if (position == consumables.size()) {
delegate.onResponse(null);
} else {
consumer.accept(consumables.get(position++), this);
}
} else {
consumer.accept(consumables.get(position++), this);
delegate.onResponse(response);
}
} else {
delegate.onResponse(response);
}
}
@Override
public void onFailure(Exception e) {
delegate.onFailure(e);
// we need to store the context here as there is a chance that this method is called from a thread outside of the ThreadPool
// like a LDAP connection reader thread and we can pollute the context in certain cases
try (ThreadContext.StoredContext ignore = threadContext.newStoredContext(false)) {
delegate.onFailure(e);
}
}
}

View File

@ -251,7 +251,7 @@ public class AuthenticationService extends AbstractComponent {
final IteratingActionListener<User, Realm> authenticatingListener =
new IteratingActionListener<>(ActionListener.wrap(this::consumeUser,
(e) -> listener.onFailure(request.exceptionProcessingRequest(e, token))),
realmAuthenticatingConsumer, realmsList);
realmAuthenticatingConsumer, realmsList, threadContext);
try {
authenticatingListener.run();
} catch (Exception e) {
@ -352,7 +352,7 @@ public class AuthenticationService extends AbstractComponent {
final IteratingActionListener<User, Realm> userLookupListener =
new IteratingActionListener<>(ActionListener.wrap((lookupUser) -> userConsumer.accept(new User(user, lookupUser)),
(e) -> listener.onFailure(request.exceptionProcessingRequest(e, authenticationToken))),
realmLookupConsumer, realmsList);
realmLookupConsumer, realmsList, threadContext);
try {
userLookupListener.run();
} catch (Exception e) {

View File

@ -8,6 +8,8 @@ package org.elasticsearch.xpack.common;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.collect.HppcMaps.Object;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase;
import org.junit.Assert;
@ -46,13 +48,52 @@ public class IteratingActionListenerTests extends ESTestCase {
}, (e) -> {
logger.error("unexpected exception", e);
fail("exception should not have been thrown");
}), consumer, items);
}), consumer, items, new ThreadContext(Settings.EMPTY));
iteratingListener.run();
// we never really went async, its all chained together so verify this for sanity
assertEquals(numberOfIterations, iterations.get());
}
public void testIterationDoesntAllowThreadContextLeak() {
final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
final int numberOfItems = scaledRandomIntBetween(1, 32);
final int numberOfIterations = scaledRandomIntBetween(1, numberOfItems);
List<Object> items = new ArrayList<>(numberOfItems);
for (int i = 0; i < numberOfItems; i++) {
items.add(new Object());
}
threadContext.putHeader("outside", "listener");
final AtomicInteger iterations = new AtomicInteger(0);
final BiConsumer<Object, ActionListener<Object>> consumer = (listValue, listener) -> {
final int current = iterations.incrementAndGet();
assertEquals("listener", threadContext.getHeader("outside"));
if (current == numberOfIterations) {
threadContext.putHeader("foo", "bar");
listener.onResponse(items.get(current - 1));
} else {
listener.onResponse(null);
}
};
IteratingActionListener<Object, Object> iteratingListener = new IteratingActionListener<>(ActionListener.wrap((object) -> {
assertNotNull(object);
assertThat(object, sameInstance(items.get(numberOfIterations - 1)));
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("listener", threadContext.getHeader("outside"));
}, (e) -> {
logger.error("unexpected exception", e);
fail("exception should not have been thrown");
}), consumer, items, threadContext);
iteratingListener.run();
// we never really went async, its all chained together so verify this for sanity
assertEquals(numberOfIterations, iterations.get());
assertNull(threadContext.getHeader("foo"));
assertEquals("listener", threadContext.getHeader("outside"));
}
public void testIterationEmptyList() {
IteratingActionListener<Object, Object> listener = new IteratingActionListener<>(ActionListener.wrap(Assert::assertNull,
(e) -> {
@ -60,7 +101,7 @@ public class IteratingActionListenerTests extends ESTestCase {
fail("exception should not have been thrown");
}), (listValue, iteratingListener) -> {
fail("consumer should not have been called!!!");
}, Collections.emptyList());
}, Collections.emptyList(), new ThreadContext(Settings.EMPTY));
listener.run();
}
@ -88,7 +129,7 @@ public class IteratingActionListenerTests extends ESTestCase {
}, (e) -> {
assertEquals("expected exception", e.getMessage());
assertTrue(onFailureCalled.compareAndSet(false, true));
}), consumer, items);
}), consumer, items, new ThreadContext(Settings.EMPTY));
iteratingListener.run();
// we never really went async, its all chained together so verify this for sanity

View File

@ -10,7 +10,10 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.ActionListener;
@ -146,13 +149,17 @@ public class AuthenticationServiceTests extends ESTestCase {
when(secondRealm.token(threadContext)).thenReturn(token);
}
Authentication result = authenticateBlocking("_action", message, null);
assertThat(result, notNullValue());
assertThat(result.getUser(), is(user));
assertThat(result.getLookedUpBy(), is(nullValue()));
assertThat(result.getAuthenticatedBy(), is(notNullValue())); // TODO implement equals
final AtomicBoolean completed = new AtomicBoolean(false);
service.authenticate("_action", message, null, ActionListener.wrap(result -> {
assertThat(result, notNullValue());
assertThat(result.getUser(), is(user));
assertThat(result.getLookedUpBy(), is(nullValue()));
assertThat(result.getAuthenticatedBy(), is(notNullValue())); // TODO implement equals
assertThreadContextContainsAuthentication(result);
setCompletedToTrue(completed);
}, this::logAndFail));
assertTrue(completed.get());
verify(auditTrail).authenticationFailed(firstRealm.name(), token, "_action", message);
assertThreadContextContainsAuthentication(result);
}
public void testAuthenticateFirstNotSupportingSecondSucceeds() throws Exception {
@ -162,13 +169,17 @@ public class AuthenticationServiceTests extends ESTestCase {
mockAuthenticate(secondRealm, token, user);
when(secondRealm.token(threadContext)).thenReturn(token);
Authentication result = authenticateBlocking("_action", message, null);
assertThat(result, notNullValue());
assertThat(result.getUser(), is(user));
final AtomicBoolean completed = new AtomicBoolean(false);
service.authenticate("_action", message, null, ActionListener.wrap(result -> {
assertThat(result, notNullValue());
assertThat(result.getUser(), is(user));
assertThreadContextContainsAuthentication(result);
setCompletedToTrue(completed);
}, this::logAndFail));
verify(auditTrail).authenticationSuccess(secondRealm.name(), user, "_action", message);
verifyNoMoreInteractions(auditTrail);
verify(firstRealm, never()).authenticate(eq(token), any(ActionListener.class));
assertThreadContextContainsAuthentication(result);
assertTrue(completed.get());
}
public void testAuthenticateCached() throws Exception {
@ -295,12 +306,17 @@ public class AuthenticationServiceTests extends ESTestCase {
when(firstRealm.supports(token)).thenReturn(true);
mockAuthenticate(firstRealm, token, user);
Authentication result = authenticateBlocking("_action", message, fallback);
assertThat(result, notNullValue());
assertThat(result.getUser(), sameInstance(user));
assertThreadContextContainsAuthentication(result);
final AtomicBoolean completed = new AtomicBoolean(false);
service.authenticate("_action", message, fallback, ActionListener.wrap(result -> {
assertThat(result, notNullValue());
assertThat(result.getUser(), sameInstance(user));
assertThreadContextContainsAuthentication(result);
setCompletedToTrue(completed);
}, this::logAndFail));
verify(auditTrail).authenticationSuccess(firstRealm.name(), user, "_action", message);
verifyNoMoreInteractions(auditTrail);
assertTrue(completed.get());
}
public void testAuthenticateRestSuccess() throws Exception {
@ -308,12 +324,17 @@ public class AuthenticationServiceTests extends ESTestCase {
when(firstRealm.token(threadContext)).thenReturn(token);
when(firstRealm.supports(token)).thenReturn(true);
mockAuthenticate(firstRealm, token, user1);
Authentication result = authenticateBlocking(restRequest);
assertThat(result, notNullValue());
assertThat(result.getUser(), sameInstance(user1));
assertThreadContextContainsAuthentication(result);
// this call does not actually go async
final AtomicBoolean completed = new AtomicBoolean(false);
service.authenticate(restRequest, ActionListener.wrap(authentication -> {
assertThat(authentication, notNullValue());
assertThat(authentication.getUser(), sameInstance(user1));
assertThreadContextContainsAuthentication(authentication);
setCompletedToTrue(completed);
}, this::logAndFail));
verify(auditTrail).authenticationSuccess(firstRealm.name(), user1, restRequest);
verifyNoMoreInteractions(auditTrail);
assertTrue(completed.get());
}
public void testAutheticateTransportContextAndHeader() throws Exception {
@ -321,45 +342,60 @@ public class AuthenticationServiceTests extends ESTestCase {
when(firstRealm.token(threadContext)).thenReturn(token);
when(firstRealm.supports(token)).thenReturn(true);
mockAuthenticate(firstRealm, token, user1);
Authentication authentication = authenticateBlocking("_action", message, SystemUser.INSTANCE);
assertThat(authentication, notNullValue());
assertThat(authentication.getUser(), sameInstance(user1));
assertThreadContextContainsAuthentication(authentication);
final AtomicBoolean completed = new AtomicBoolean(false);
final SetOnce<Authentication> authRef = new SetOnce<>();
final SetOnce<String> authHeaderRef = new SetOnce<>();
service.authenticate("_action", message, SystemUser.INSTANCE, ActionListener.wrap(authentication -> {
assertThat(authentication, notNullValue());
assertThat(authentication.getUser(), sameInstance(user1));
assertThreadContextContainsAuthentication(authentication);
authRef.set(authentication);
authHeaderRef.set(threadContext.getHeader(Authentication.AUTHENTICATION_KEY));
setCompletedToTrue(completed);
}, this::logAndFail));
assertTrue(completed.compareAndSet(true, false));
reset(firstRealm);
// checking authentication from the context
InternalMessage message1 = new InternalMessage();
ThreadContext threadContext1 = new ThreadContext(Settings.EMPTY);
final ThreadContext threadContext1 = new ThreadContext(Settings.EMPTY);
when(threadPool.getThreadContext()).thenReturn(threadContext1);
service = new AuthenticationService(Settings.EMPTY, realms, auditTrail,
new DefaultAuthenticationFailureHandler(), threadPool, new AnonymousUser(Settings.EMPTY));
threadContext1.putTransient(Authentication.AUTHENTICATION_KEY, threadContext.getTransient(Authentication.AUTHENTICATION_KEY));
threadContext1.putHeader(Authentication.AUTHENTICATION_KEY, threadContext.getHeader(Authentication.AUTHENTICATION_KEY));
Authentication ctxAuth = authenticateBlocking("_action", message1, SystemUser.INSTANCE);
assertThat(ctxAuth, sameInstance(authentication));
threadContext1.putTransient(Authentication.AUTHENTICATION_KEY, authRef.get());
threadContext1.putHeader(Authentication.AUTHENTICATION_KEY, authHeaderRef.get());
service.authenticate("_action", message1, SystemUser.INSTANCE, ActionListener.wrap(ctxAuth -> {
assertThat(ctxAuth, sameInstance(authRef.get()));
assertThat(threadContext1.getHeader(Authentication.AUTHENTICATION_KEY), sameInstance(authHeaderRef.get()));
setCompletedToTrue(completed);
}, this::logAndFail));
assertTrue(completed.compareAndSet(true, false));
verifyZeroInteractions(firstRealm);
reset(firstRealm);
// checking authentication from the user header
threadContext1 = new ThreadContext(Settings.EMPTY);
when(threadPool.getThreadContext()).thenReturn(threadContext1);
ThreadContext threadContext2 = new ThreadContext(Settings.EMPTY);
when(threadPool.getThreadContext()).thenReturn(threadContext2);
service = new AuthenticationService(Settings.EMPTY, realms, auditTrail,
new DefaultAuthenticationFailureHandler(), threadPool, new AnonymousUser(Settings.EMPTY));
threadContext1.putHeader(Authentication.AUTHENTICATION_KEY, threadContext.getHeader(Authentication.AUTHENTICATION_KEY));
threadContext2.putHeader(Authentication.AUTHENTICATION_KEY, authHeaderRef.get());
BytesStreamOutput output = new BytesStreamOutput();
threadContext1.writeTo(output);
threadContext2.writeTo(output);
StreamInput input = output.bytes().streamInput();
threadContext1 = new ThreadContext(Settings.EMPTY);
threadContext1.readHeaders(input);
threadContext2 = new ThreadContext(Settings.EMPTY);
threadContext2.readHeaders(input);
when(threadPool.getThreadContext()).thenReturn(threadContext1);
when(threadPool.getThreadContext()).thenReturn(threadContext2);
service = new AuthenticationService(Settings.EMPTY, realms, auditTrail,
new DefaultAuthenticationFailureHandler(), threadPool, new AnonymousUser(Settings.EMPTY));
Authentication result = authenticateBlocking("_action", new InternalMessage(), SystemUser.INSTANCE);
assertThat(result, notNullValue());
assertThat(result.getUser(), equalTo(user1));
service.authenticate("_action", new InternalMessage(), SystemUser.INSTANCE, ActionListener.wrap(result -> {
assertThat(result, notNullValue());
assertThat(result.getUser(), equalTo(user1));
setCompletedToTrue(completed);
}, this::logAndFail));
assertTrue(completed.get());
verifyZeroInteractions(firstRealm);
}
@ -589,27 +625,33 @@ public class AuthenticationServiceTests extends ESTestCase {
return null;
}).when(secondRealm).lookupUser(eq("run_as"), any(ActionListener.class));
Authentication result;
final AtomicBoolean completed = new AtomicBoolean(false);
ActionListener<Authentication> listener = ActionListener.wrap(result -> {
assertThat(result, notNullValue());
User authenticated = result.getUser();
assertThat(SystemUser.is(authenticated), is(false));
assertThat(authenticated.runAs(), is(notNullValue()));
assertThat(authenticated.principal(), is("lookup user"));
assertThat(authenticated.roles(), arrayContaining("user"));
assertEquals(user.metadata(), authenticated.metadata());
assertEquals(user.email(), authenticated.email());
assertEquals(user.enabled(), authenticated.enabled());
assertEquals(user.fullName(), authenticated.fullName());
assertThat(authenticated.runAs().principal(), is("looked up user"));
assertThat(authenticated.runAs().roles(), arrayContaining("some role"));
assertThreadContextContainsAuthentication(result);
setCompletedToTrue(completed);
}, this::logAndFail);
// we do not actually go async
if (randomBoolean()) {
result = authenticateBlocking("_action", message, null);
service.authenticate("_action", message, null, listener);
} else {
result = authenticateBlocking(restRequest);
service.authenticate(restRequest, listener);
}
assertThat(result, notNullValue());
User authenticated = result.getUser();
assertThat(SystemUser.is(authenticated), is(false));
assertThat(authenticated.runAs(), is(notNullValue()));
assertThat(authenticated.principal(), is("lookup user"));
assertThat(authenticated.roles(), arrayContaining("user"));
assertEquals(user.metadata(), authenticated.metadata());
assertEquals(user.email(), authenticated.email());
assertEquals(user.enabled(), authenticated.enabled());
assertEquals(user.fullName(), authenticated.fullName());
assertThat(authenticated.runAs().principal(), is("looked up user"));
assertThat(authenticated.runAs().roles(), arrayContaining("some role"));
assertThreadContextContainsAuthentication(result);
assertTrue(completed.get());
}
public void testRunAsLookupDifferentRealm() throws Exception {
@ -624,22 +666,28 @@ public class AuthenticationServiceTests extends ESTestCase {
return null;
}).when(firstRealm).lookupUser(eq("run_as"), any(ActionListener.class));
Authentication result;
if (randomBoolean()) {
result = authenticateBlocking("_action", message, null);
} else {
result = authenticateBlocking(restRequest);
}
assertThat(result, notNullValue());
User authenticated = result.getUser();
final AtomicBoolean completed = new AtomicBoolean(false);
ActionListener<Authentication> listener = ActionListener.wrap(result -> {
assertThat(result, notNullValue());
User authenticated = result.getUser();
assertThat(SystemUser.is(authenticated), is(false));
assertThat(authenticated.runAs(), is(notNullValue()));
assertThat(authenticated.principal(), is("lookup user"));
assertThat(authenticated.roles(), arrayContaining("user"));
assertThat(authenticated.runAs().principal(), is("looked up user"));
assertThat(authenticated.runAs().roles(), arrayContaining("some role"));
assertThreadContextContainsAuthentication(result);
assertThat(SystemUser.is(authenticated), is(false));
assertThat(authenticated.runAs(), is(notNullValue()));
assertThat(authenticated.principal(), is("lookup user"));
assertThat(authenticated.roles(), arrayContaining("user"));
assertThat(authenticated.runAs().principal(), is("looked up user"));
assertThat(authenticated.runAs().roles(), arrayContaining("some role"));
assertThreadContextContainsAuthentication(result);
setCompletedToTrue(completed);
}, this::logAndFail);
// call service asynchronously but it doesn't actually go async
if (randomBoolean()) {
service.authenticate("_action", message, null, listener);
} else {
service.authenticate(restRequest, listener);
}
assertTrue(completed.get());
}
public void testRunAsWithEmptyRunAsUsernameRest() throws Exception {
@ -763,4 +811,13 @@ public class AuthenticationServiceTests extends ESTestCase {
this.internalRealmsOnly = internalRealms;
}
}
private void logAndFail(Exception e) {
logger.error("unexpected exception", e);
fail("unexpected exception " + e.getMessage());
}
private void setCompletedToTrue(AtomicBoolean completed) {
assertTrue(completed.compareAndSet(false, true));
}
}