From 6d99d7dafce78a8d6c8cb48bcfeced069e48351b Mon Sep 17 00:00:00 2001 From: Jay Modi Date: Thu, 11 Oct 2018 15:24:38 +0100 Subject: [PATCH] ListenableFuture should preserve ThreadContext (#34394) ListenableFuture may run a listener on the same thread that called the addListener method or it may execute on another thread after the future has completed. Whenever the ListenableFuture stores the listener for execution later, it should preserve the thread context which is what this change does. --- .../util/concurrent/ListenableFuture.java | 5 ++-- .../concurrent/ListenableFutureTests.java | 25 +++++++++++++++---- .../support/CachingUsernamePasswordRealm.java | 4 +-- .../CachingUsernamePasswordRealmTests.java | 6 +++++ 4 files changed, 31 insertions(+), 9 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/ListenableFuture.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/ListenableFuture.java index d50f57aaafa..4d6bd51c5c3 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/ListenableFuture.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/ListenableFuture.java @@ -20,6 +20,7 @@ package org.elasticsearch.common.util.concurrent; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.common.collect.Tuple; import java.util.ArrayList; @@ -47,7 +48,7 @@ public final class ListenableFuture extends BaseFuture implements ActionLi * If the future has completed, the listener will be notified immediately without forking to * a different thread. */ - public void addListener(ActionListener listener, ExecutorService executor) { + public void addListener(ActionListener listener, ExecutorService executor, ThreadContext threadContext) { if (done) { // run the callback directly, we don't hold the lock and don't need to fork! notifyListener(listener, EsExecutors.newDirectExecutorService()); @@ -59,7 +60,7 @@ public final class ListenableFuture extends BaseFuture implements ActionLi if (done) { run = true; } else { - listeners.add(new Tuple<>(listener, executor)); + listeners.add(new Tuple<>(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext), executor)); run = false; } } diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/ListenableFutureTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/ListenableFutureTests.java index 712656777f9..75a2e299461 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/ListenableFutureTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/ListenableFutureTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.common.util.concurrent; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ESTestCase; @@ -30,9 +31,12 @@ import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicInteger; +import static org.hamcrest.Matchers.is; + public class ListenableFutureTests extends ESTestCase { private ExecutorService executorService; + private ThreadContext threadContext = new ThreadContext(Settings.EMPTY); @After public void stopExecutorService() throws InterruptedException { @@ -46,7 +50,7 @@ public class ListenableFutureTests extends ESTestCase { AtomicInteger notifications = new AtomicInteger(0); final int numberOfListeners = scaledRandomIntBetween(1, 12); for (int i = 0; i < numberOfListeners; i++) { - future.addListener(ActionListener.wrap(notifications::incrementAndGet), EsExecutors.newDirectExecutorService()); + future.addListener(ActionListener.wrap(notifications::incrementAndGet), EsExecutors.newDirectExecutorService(), threadContext); } future.onResponse(""); @@ -63,7 +67,7 @@ public class ListenableFutureTests extends ESTestCase { future.addListener(ActionListener.wrap(s -> fail("this should never be called"), e -> { assertEquals(exception, e); notifications.incrementAndGet(); - }), EsExecutors.newDirectExecutorService()); + }), EsExecutors.newDirectExecutorService(), threadContext); } future.onFailure(exception); @@ -76,7 +80,7 @@ public class ListenableFutureTests extends ESTestCase { final int completingThread = randomIntBetween(0, numberOfThreads - 1); final ListenableFuture future = new ListenableFuture<>(); executorService = EsExecutors.newFixed("testConcurrentListenerRegistrationAndCompletion", numberOfThreads, 1000, - EsExecutors.daemonThreadFactory("listener"), new ThreadContext(Settings.EMPTY)); + EsExecutors.daemonThreadFactory("listener"), threadContext); final CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads); final CountDownLatch listenersLatch = new CountDownLatch(numberOfThreads - 1); final AtomicInteger numResponses = new AtomicInteger(0); @@ -85,20 +89,31 @@ public class ListenableFutureTests extends ESTestCase { for (int i = 0; i < numberOfThreads; i++) { final int threadNum = i; Thread thread = new Thread(() -> { + threadContext.putTransient("key", threadNum); try { barrier.await(); if (threadNum == completingThread) { + // we need to do more than just call onResponse as this often results in synchronous + // execution of the listeners instead of actually going async + final int waitTime = randomIntBetween(0, 50); + Thread.sleep(waitTime); + logger.info("completing the future after sleeping {}ms", waitTime); future.onResponse(""); + logger.info("future received response"); } else { + logger.info("adding listener {}", threadNum); future.addListener(ActionListener.wrap(s -> { + logger.info("listener {} received value {}", threadNum, s); assertEquals("", s); + assertThat(threadContext.getTransient("key"), is(threadNum)); numResponses.incrementAndGet(); listenersLatch.countDown(); }, e -> { - logger.error("caught unexpected exception", e); + logger.error(new ParameterizedMessage("listener {} caught unexpected exception", threadNum), e); numExceptions.incrementAndGet(); listenersLatch.countDown(); - }), executorService); + }), executorService, threadContext); + logger.info("listener {} added", threadNum); } barrier.await(); } catch (InterruptedException | BrokenBarrierException e) { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java index 0d8609d61d9..fdb2fd0f33d 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java @@ -153,7 +153,7 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm // is cleared of the failed authentication cache.invalidate(token.principal(), listenableCacheEntry); authenticateWithCache(token, listener); - }), threadPool.executor(ThreadPool.Names.GENERIC)); + }), threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext()); } else { // attempt authentication against the authentication source doAuthenticate(token, ActionListener.wrap(authResult -> { @@ -255,7 +255,7 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm } else { listener.onResponse(null); } - }, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC)); + }, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext()); } catch (final ExecutionException e) { listener.onFailure(e); } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java index 6d84dfb2a80..6230c637b89 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java @@ -469,7 +469,9 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase { List threads = new ArrayList<>(numberOfThreads); for (int i = 0; i < numberOfThreads; i++) { final boolean invalidPassword = randomBoolean(); + final int threadNum = i; threads.add(new Thread(() -> { + threadPool.getThreadContext().putTransient("key", threadNum); try { latch.countDown(); latch.await(); @@ -477,6 +479,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase { UsernamePasswordToken token = new UsernamePasswordToken(username, invalidPassword ? randomPassword : password); realm.authenticate(token, ActionListener.wrap((result) -> { + assertThat(threadPool.getThreadContext().getTransient("key"), is(threadNum)); if (invalidPassword && result.isAuthenticated()) { throw new RuntimeException("invalid password led to an authenticated user: " + result); } else if (invalidPassword == false && result.isAuthenticated() == false) { @@ -529,12 +532,15 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase { final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads); List threads = new ArrayList<>(numberOfThreads); for (int i = 0; i < numberOfThreads; i++) { + final int threadNum = i; threads.add(new Thread(() -> { try { + threadPool.getThreadContext().putTransient("key", threadNum); latch.countDown(); latch.await(); for (int i1 = 0; i1 < numberOfIterations; i1++) { realm.lookupUser(username, ActionListener.wrap((user) -> { + assertThat(threadPool.getThreadContext().getTransient("key"), is(threadNum)); if (user == null) { throw new RuntimeException("failed to lookup user"); }