diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/ListenableFuture.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/ListenableFuture.java index d50f57aaafa..4d6bd51c5c3 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/ListenableFuture.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/ListenableFuture.java @@ -20,6 +20,7 @@ package org.elasticsearch.common.util.concurrent; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.common.collect.Tuple; import java.util.ArrayList; @@ -47,7 +48,7 @@ public final class ListenableFuture extends BaseFuture implements ActionLi * If the future has completed, the listener will be notified immediately without forking to * a different thread. */ - public void addListener(ActionListener listener, ExecutorService executor) { + public void addListener(ActionListener listener, ExecutorService executor, ThreadContext threadContext) { if (done) { // run the callback directly, we don't hold the lock and don't need to fork! notifyListener(listener, EsExecutors.newDirectExecutorService()); @@ -59,7 +60,7 @@ public final class ListenableFuture extends BaseFuture implements ActionLi if (done) { run = true; } else { - listeners.add(new Tuple<>(listener, executor)); + listeners.add(new Tuple<>(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext), executor)); run = false; } } diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/ListenableFutureTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/ListenableFutureTests.java index 712656777f9..75a2e299461 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/ListenableFutureTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/ListenableFutureTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.common.util.concurrent; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ESTestCase; @@ -30,9 +31,12 @@ import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicInteger; +import static org.hamcrest.Matchers.is; + public class ListenableFutureTests extends ESTestCase { private ExecutorService executorService; + private ThreadContext threadContext = new ThreadContext(Settings.EMPTY); @After public void stopExecutorService() throws InterruptedException { @@ -46,7 +50,7 @@ public class ListenableFutureTests extends ESTestCase { AtomicInteger notifications = new AtomicInteger(0); final int numberOfListeners = scaledRandomIntBetween(1, 12); for (int i = 0; i < numberOfListeners; i++) { - future.addListener(ActionListener.wrap(notifications::incrementAndGet), EsExecutors.newDirectExecutorService()); + future.addListener(ActionListener.wrap(notifications::incrementAndGet), EsExecutors.newDirectExecutorService(), threadContext); } future.onResponse(""); @@ -63,7 +67,7 @@ public class ListenableFutureTests extends ESTestCase { future.addListener(ActionListener.wrap(s -> fail("this should never be called"), e -> { assertEquals(exception, e); notifications.incrementAndGet(); - }), EsExecutors.newDirectExecutorService()); + }), EsExecutors.newDirectExecutorService(), threadContext); } future.onFailure(exception); @@ -76,7 +80,7 @@ public class ListenableFutureTests extends ESTestCase { final int completingThread = randomIntBetween(0, numberOfThreads - 1); final ListenableFuture future = new ListenableFuture<>(); executorService = EsExecutors.newFixed("testConcurrentListenerRegistrationAndCompletion", numberOfThreads, 1000, - EsExecutors.daemonThreadFactory("listener"), new ThreadContext(Settings.EMPTY)); + EsExecutors.daemonThreadFactory("listener"), threadContext); final CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads); final CountDownLatch listenersLatch = new CountDownLatch(numberOfThreads - 1); final AtomicInteger numResponses = new AtomicInteger(0); @@ -85,20 +89,31 @@ public class ListenableFutureTests extends ESTestCase { for (int i = 0; i < numberOfThreads; i++) { final int threadNum = i; Thread thread = new Thread(() -> { + threadContext.putTransient("key", threadNum); try { barrier.await(); if (threadNum == completingThread) { + // we need to do more than just call onResponse as this often results in synchronous + // execution of the listeners instead of actually going async + final int waitTime = randomIntBetween(0, 50); + Thread.sleep(waitTime); + logger.info("completing the future after sleeping {}ms", waitTime); future.onResponse(""); + logger.info("future received response"); } else { + logger.info("adding listener {}", threadNum); future.addListener(ActionListener.wrap(s -> { + logger.info("listener {} received value {}", threadNum, s); assertEquals("", s); + assertThat(threadContext.getTransient("key"), is(threadNum)); numResponses.incrementAndGet(); listenersLatch.countDown(); }, e -> { - logger.error("caught unexpected exception", e); + logger.error(new ParameterizedMessage("listener {} caught unexpected exception", threadNum), e); numExceptions.incrementAndGet(); listenersLatch.countDown(); - }), executorService); + }), executorService, threadContext); + logger.info("listener {} added", threadNum); } barrier.await(); } catch (InterruptedException | BrokenBarrierException e) { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java index 0d8609d61d9..fdb2fd0f33d 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java @@ -153,7 +153,7 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm // is cleared of the failed authentication cache.invalidate(token.principal(), listenableCacheEntry); authenticateWithCache(token, listener); - }), threadPool.executor(ThreadPool.Names.GENERIC)); + }), threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext()); } else { // attempt authentication against the authentication source doAuthenticate(token, ActionListener.wrap(authResult -> { @@ -255,7 +255,7 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm } else { listener.onResponse(null); } - }, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC)); + }, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext()); } catch (final ExecutionException e) { listener.onFailure(e); } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java index 6d84dfb2a80..6230c637b89 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java @@ -469,7 +469,9 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase { List threads = new ArrayList<>(numberOfThreads); for (int i = 0; i < numberOfThreads; i++) { final boolean invalidPassword = randomBoolean(); + final int threadNum = i; threads.add(new Thread(() -> { + threadPool.getThreadContext().putTransient("key", threadNum); try { latch.countDown(); latch.await(); @@ -477,6 +479,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase { UsernamePasswordToken token = new UsernamePasswordToken(username, invalidPassword ? randomPassword : password); realm.authenticate(token, ActionListener.wrap((result) -> { + assertThat(threadPool.getThreadContext().getTransient("key"), is(threadNum)); if (invalidPassword && result.isAuthenticated()) { throw new RuntimeException("invalid password led to an authenticated user: " + result); } else if (invalidPassword == false && result.isAuthenticated() == false) { @@ -529,12 +532,15 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase { final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads); List threads = new ArrayList<>(numberOfThreads); for (int i = 0; i < numberOfThreads; i++) { + final int threadNum = i; threads.add(new Thread(() -> { try { + threadPool.getThreadContext().putTransient("key", threadNum); latch.countDown(); latch.await(); for (int i1 = 0; i1 < numberOfIterations; i1++) { realm.lookupUser(username, ActionListener.wrap((user) -> { + assertThat(threadPool.getThreadContext().getTransient("key"), is(threadNum)); if (user == null) { throw new RuntimeException("failed to lookup user"); }