ListenableFuture should preserve ThreadContext (#34394)

ListenableFuture may run a listener on the same thread that called the
addListener method or it may execute on another thread after the future
has completed. Whenever the ListenableFuture stores the listener for
execution later, it should preserve the thread context which is what
this change does.
This commit is contained in:
Jay Modi 2018-10-11 15:24:38 +01:00 committed by GitHub
parent 7bc11a8099
commit 6d99d7dafc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 31 additions and 9 deletions

View File

@ -20,6 +20,7 @@
package org.elasticsearch.common.util.concurrent;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.common.collect.Tuple;
import java.util.ArrayList;
@ -47,7 +48,7 @@ public final class ListenableFuture<V> extends BaseFuture<V> implements ActionLi
* If the future has completed, the listener will be notified immediately without forking to
* a different thread.
*/
public void addListener(ActionListener<V> listener, ExecutorService executor) {
public void addListener(ActionListener<V> listener, ExecutorService executor, ThreadContext threadContext) {
if (done) {
// run the callback directly, we don't hold the lock and don't need to fork!
notifyListener(listener, EsExecutors.newDirectExecutorService());
@ -59,7 +60,7 @@ public final class ListenableFuture<V> extends BaseFuture<V> implements ActionLi
if (done) {
run = true;
} else {
listeners.add(new Tuple<>(listener, executor));
listeners.add(new Tuple<>(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext), executor));
run = false;
}
}

View File

@ -19,6 +19,7 @@
package org.elasticsearch.common.util.concurrent;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.test.ESTestCase;
@ -30,9 +31,12 @@ import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import static org.hamcrest.Matchers.is;
public class ListenableFutureTests extends ESTestCase {
private ExecutorService executorService;
private ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
@After
public void stopExecutorService() throws InterruptedException {
@ -46,7 +50,7 @@ public class ListenableFutureTests extends ESTestCase {
AtomicInteger notifications = new AtomicInteger(0);
final int numberOfListeners = scaledRandomIntBetween(1, 12);
for (int i = 0; i < numberOfListeners; i++) {
future.addListener(ActionListener.wrap(notifications::incrementAndGet), EsExecutors.newDirectExecutorService());
future.addListener(ActionListener.wrap(notifications::incrementAndGet), EsExecutors.newDirectExecutorService(), threadContext);
}
future.onResponse("");
@ -63,7 +67,7 @@ public class ListenableFutureTests extends ESTestCase {
future.addListener(ActionListener.wrap(s -> fail("this should never be called"), e -> {
assertEquals(exception, e);
notifications.incrementAndGet();
}), EsExecutors.newDirectExecutorService());
}), EsExecutors.newDirectExecutorService(), threadContext);
}
future.onFailure(exception);
@ -76,7 +80,7 @@ public class ListenableFutureTests extends ESTestCase {
final int completingThread = randomIntBetween(0, numberOfThreads - 1);
final ListenableFuture<String> future = new ListenableFuture<>();
executorService = EsExecutors.newFixed("testConcurrentListenerRegistrationAndCompletion", numberOfThreads, 1000,
EsExecutors.daemonThreadFactory("listener"), new ThreadContext(Settings.EMPTY));
EsExecutors.daemonThreadFactory("listener"), threadContext);
final CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
final CountDownLatch listenersLatch = new CountDownLatch(numberOfThreads - 1);
final AtomicInteger numResponses = new AtomicInteger(0);
@ -85,20 +89,31 @@ public class ListenableFutureTests extends ESTestCase {
for (int i = 0; i < numberOfThreads; i++) {
final int threadNum = i;
Thread thread = new Thread(() -> {
threadContext.putTransient("key", threadNum);
try {
barrier.await();
if (threadNum == completingThread) {
// we need to do more than just call onResponse as this often results in synchronous
// execution of the listeners instead of actually going async
final int waitTime = randomIntBetween(0, 50);
Thread.sleep(waitTime);
logger.info("completing the future after sleeping {}ms", waitTime);
future.onResponse("");
logger.info("future received response");
} else {
logger.info("adding listener {}", threadNum);
future.addListener(ActionListener.wrap(s -> {
logger.info("listener {} received value {}", threadNum, s);
assertEquals("", s);
assertThat(threadContext.getTransient("key"), is(threadNum));
numResponses.incrementAndGet();
listenersLatch.countDown();
}, e -> {
logger.error("caught unexpected exception", e);
logger.error(new ParameterizedMessage("listener {} caught unexpected exception", threadNum), e);
numExceptions.incrementAndGet();
listenersLatch.countDown();
}), executorService);
}), executorService, threadContext);
logger.info("listener {} added", threadNum);
}
barrier.await();
} catch (InterruptedException | BrokenBarrierException e) {

View File

@ -153,7 +153,7 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
// is cleared of the failed authentication
cache.invalidate(token.principal(), listenableCacheEntry);
authenticateWithCache(token, listener);
}), threadPool.executor(ThreadPool.Names.GENERIC));
}), threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext());
} else {
// attempt authentication against the authentication source
doAuthenticate(token, ActionListener.wrap(authResult -> {
@ -255,7 +255,7 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
} else {
listener.onResponse(null);
}
}, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC));
}, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext());
} catch (final ExecutionException e) {
listener.onFailure(e);
}

View File

@ -469,7 +469,9 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
List<Thread> threads = new ArrayList<>(numberOfThreads);
for (int i = 0; i < numberOfThreads; i++) {
final boolean invalidPassword = randomBoolean();
final int threadNum = i;
threads.add(new Thread(() -> {
threadPool.getThreadContext().putTransient("key", threadNum);
try {
latch.countDown();
latch.await();
@ -477,6 +479,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
UsernamePasswordToken token = new UsernamePasswordToken(username, invalidPassword ? randomPassword : password);
realm.authenticate(token, ActionListener.wrap((result) -> {
assertThat(threadPool.getThreadContext().getTransient("key"), is(threadNum));
if (invalidPassword && result.isAuthenticated()) {
throw new RuntimeException("invalid password led to an authenticated user: " + result);
} else if (invalidPassword == false && result.isAuthenticated() == false) {
@ -529,12 +532,15 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads);
List<Thread> threads = new ArrayList<>(numberOfThreads);
for (int i = 0; i < numberOfThreads; i++) {
final int threadNum = i;
threads.add(new Thread(() -> {
try {
threadPool.getThreadContext().putTransient("key", threadNum);
latch.countDown();
latch.await();
for (int i1 = 0; i1 < numberOfIterations; i1++) {
realm.lookupUser(username, ActionListener.wrap((user) -> {
assertThat(threadPool.getThreadContext().getTransient("key"), is(threadNum));
if (user == null) {
throw new RuntimeException("failed to lookup user");
}