ListenableFuture should preserve ThreadContext (#34394)
ListenableFuture may run a listener on the same thread that called the addListener method or it may execute on another thread after the future has completed. Whenever the ListenableFuture stores the listener for execution later, it should preserve the thread context which is what this change does.
This commit is contained in:
parent
7bc11a8099
commit
6d99d7dafc
|
@ -20,6 +20,7 @@
|
|||
package org.elasticsearch.common.util.concurrent;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.ContextPreservingActionListener;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -47,7 +48,7 @@ public final class ListenableFuture<V> extends BaseFuture<V> implements ActionLi
|
|||
* If the future has completed, the listener will be notified immediately without forking to
|
||||
* a different thread.
|
||||
*/
|
||||
public void addListener(ActionListener<V> listener, ExecutorService executor) {
|
||||
public void addListener(ActionListener<V> listener, ExecutorService executor, ThreadContext threadContext) {
|
||||
if (done) {
|
||||
// run the callback directly, we don't hold the lock and don't need to fork!
|
||||
notifyListener(listener, EsExecutors.newDirectExecutorService());
|
||||
|
@ -59,7 +60,7 @@ public final class ListenableFuture<V> extends BaseFuture<V> implements ActionLi
|
|||
if (done) {
|
||||
run = true;
|
||||
} else {
|
||||
listeners.add(new Tuple<>(listener, executor));
|
||||
listeners.add(new Tuple<>(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext), executor));
|
||||
run = false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
package org.elasticsearch.common.util.concurrent;
|
||||
|
||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
@ -30,9 +31,12 @@ import java.util.concurrent.CyclicBarrier;
|
|||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class ListenableFutureTests extends ESTestCase {
|
||||
|
||||
private ExecutorService executorService;
|
||||
private ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
|
||||
|
||||
@After
|
||||
public void stopExecutorService() throws InterruptedException {
|
||||
|
@ -46,7 +50,7 @@ public class ListenableFutureTests extends ESTestCase {
|
|||
AtomicInteger notifications = new AtomicInteger(0);
|
||||
final int numberOfListeners = scaledRandomIntBetween(1, 12);
|
||||
for (int i = 0; i < numberOfListeners; i++) {
|
||||
future.addListener(ActionListener.wrap(notifications::incrementAndGet), EsExecutors.newDirectExecutorService());
|
||||
future.addListener(ActionListener.wrap(notifications::incrementAndGet), EsExecutors.newDirectExecutorService(), threadContext);
|
||||
}
|
||||
|
||||
future.onResponse("");
|
||||
|
@ -63,7 +67,7 @@ public class ListenableFutureTests extends ESTestCase {
|
|||
future.addListener(ActionListener.wrap(s -> fail("this should never be called"), e -> {
|
||||
assertEquals(exception, e);
|
||||
notifications.incrementAndGet();
|
||||
}), EsExecutors.newDirectExecutorService());
|
||||
}), EsExecutors.newDirectExecutorService(), threadContext);
|
||||
}
|
||||
|
||||
future.onFailure(exception);
|
||||
|
@ -76,7 +80,7 @@ public class ListenableFutureTests extends ESTestCase {
|
|||
final int completingThread = randomIntBetween(0, numberOfThreads - 1);
|
||||
final ListenableFuture<String> future = new ListenableFuture<>();
|
||||
executorService = EsExecutors.newFixed("testConcurrentListenerRegistrationAndCompletion", numberOfThreads, 1000,
|
||||
EsExecutors.daemonThreadFactory("listener"), new ThreadContext(Settings.EMPTY));
|
||||
EsExecutors.daemonThreadFactory("listener"), threadContext);
|
||||
final CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
|
||||
final CountDownLatch listenersLatch = new CountDownLatch(numberOfThreads - 1);
|
||||
final AtomicInteger numResponses = new AtomicInteger(0);
|
||||
|
@ -85,20 +89,31 @@ public class ListenableFutureTests extends ESTestCase {
|
|||
for (int i = 0; i < numberOfThreads; i++) {
|
||||
final int threadNum = i;
|
||||
Thread thread = new Thread(() -> {
|
||||
threadContext.putTransient("key", threadNum);
|
||||
try {
|
||||
barrier.await();
|
||||
if (threadNum == completingThread) {
|
||||
// we need to do more than just call onResponse as this often results in synchronous
|
||||
// execution of the listeners instead of actually going async
|
||||
final int waitTime = randomIntBetween(0, 50);
|
||||
Thread.sleep(waitTime);
|
||||
logger.info("completing the future after sleeping {}ms", waitTime);
|
||||
future.onResponse("");
|
||||
logger.info("future received response");
|
||||
} else {
|
||||
logger.info("adding listener {}", threadNum);
|
||||
future.addListener(ActionListener.wrap(s -> {
|
||||
logger.info("listener {} received value {}", threadNum, s);
|
||||
assertEquals("", s);
|
||||
assertThat(threadContext.getTransient("key"), is(threadNum));
|
||||
numResponses.incrementAndGet();
|
||||
listenersLatch.countDown();
|
||||
}, e -> {
|
||||
logger.error("caught unexpected exception", e);
|
||||
logger.error(new ParameterizedMessage("listener {} caught unexpected exception", threadNum), e);
|
||||
numExceptions.incrementAndGet();
|
||||
listenersLatch.countDown();
|
||||
}), executorService);
|
||||
}), executorService, threadContext);
|
||||
logger.info("listener {} added", threadNum);
|
||||
}
|
||||
barrier.await();
|
||||
} catch (InterruptedException | BrokenBarrierException e) {
|
||||
|
|
|
@ -153,7 +153,7 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
|
|||
// is cleared of the failed authentication
|
||||
cache.invalidate(token.principal(), listenableCacheEntry);
|
||||
authenticateWithCache(token, listener);
|
||||
}), threadPool.executor(ThreadPool.Names.GENERIC));
|
||||
}), threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext());
|
||||
} else {
|
||||
// attempt authentication against the authentication source
|
||||
doAuthenticate(token, ActionListener.wrap(authResult -> {
|
||||
|
@ -255,7 +255,7 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
|
|||
} else {
|
||||
listener.onResponse(null);
|
||||
}
|
||||
}, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC));
|
||||
}, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext());
|
||||
} catch (final ExecutionException e) {
|
||||
listener.onFailure(e);
|
||||
}
|
||||
|
|
|
@ -469,7 +469,9 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
|
|||
List<Thread> threads = new ArrayList<>(numberOfThreads);
|
||||
for (int i = 0; i < numberOfThreads; i++) {
|
||||
final boolean invalidPassword = randomBoolean();
|
||||
final int threadNum = i;
|
||||
threads.add(new Thread(() -> {
|
||||
threadPool.getThreadContext().putTransient("key", threadNum);
|
||||
try {
|
||||
latch.countDown();
|
||||
latch.await();
|
||||
|
@ -477,6 +479,7 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
|
|||
UsernamePasswordToken token = new UsernamePasswordToken(username, invalidPassword ? randomPassword : password);
|
||||
|
||||
realm.authenticate(token, ActionListener.wrap((result) -> {
|
||||
assertThat(threadPool.getThreadContext().getTransient("key"), is(threadNum));
|
||||
if (invalidPassword && result.isAuthenticated()) {
|
||||
throw new RuntimeException("invalid password led to an authenticated user: " + result);
|
||||
} else if (invalidPassword == false && result.isAuthenticated() == false) {
|
||||
|
@ -529,12 +532,15 @@ public class CachingUsernamePasswordRealmTests extends ESTestCase {
|
|||
final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads);
|
||||
List<Thread> threads = new ArrayList<>(numberOfThreads);
|
||||
for (int i = 0; i < numberOfThreads; i++) {
|
||||
final int threadNum = i;
|
||||
threads.add(new Thread(() -> {
|
||||
try {
|
||||
threadPool.getThreadContext().putTransient("key", threadNum);
|
||||
latch.countDown();
|
||||
latch.await();
|
||||
for (int i1 = 0; i1 < numberOfIterations; i1++) {
|
||||
realm.lookupUser(username, ActionListener.wrap((user) -> {
|
||||
assertThat(threadPool.getThreadContext().getTransient("key"), is(threadNum));
|
||||
if (user == null) {
|
||||
throw new RuntimeException("failed to lookup user");
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue