diff --git a/core/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/core/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index a2bd82fd03f..e7c2018ad31 100644 --- a/core/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/core/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -377,11 +377,9 @@ public class TransportBulkAction extends HandledTransportAction handler = - new TransportChannelResponseHandler<>(logger, channel, extraMessage, () -> TransportResponse.Empty.INSTANCE); + new TransportChannelResponseHandler<>(logger, channel, extraMessage, + () -> TransportResponse.Empty.INSTANCE); transportService.sendRequest(clusterService.localNode(), transportReplicaAction, new ConcreteShardRequest<>(request, targetAllocationID), handler); @@ -809,11 +808,9 @@ public abstract class TransportReplicationAction< } setPhase(task, "waiting_for_retry"); request.onRetry(); - final ThreadContext.StoredContext context = threadPool.getThreadContext().newStoredContext(); observer.waitForNextChange(new ClusterStateObserver.Listener() { @Override public void onNewClusterState(ClusterState state) { - context.close(); run(); } @@ -824,7 +821,6 @@ public abstract class TransportReplicationAction< @Override public void onTimeout(TimeValue timeout) { - context.close(); // Try one more time... run(); } diff --git a/core/src/main/java/org/elasticsearch/cluster/ClusterStateObserver.java b/core/src/main/java/org/elasticsearch/cluster/ClusterStateObserver.java index cad98198a80..4964a14dfca 100644 --- a/core/src/main/java/org/elasticsearch/cluster/ClusterStateObserver.java +++ b/core/src/main/java/org/elasticsearch/cluster/ClusterStateObserver.java @@ -29,6 +29,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Predicate; +import java.util.function.Supplier; /** * A utility class which simplifies interacting with the cluster state in cases where @@ -118,7 +119,7 @@ public class ClusterStateObserver { * @param timeOutValue a timeout for waiting. If null the global observer timeout will be used. */ public void waitForNextChange(Listener listener, Predicate statePredicate, @Nullable TimeValue timeOutValue) { - + listener = new ContextPreservingListener(listener, contextHolder.newRestorableContext(false)); if (observingContext.get() != null) { throw new ElasticsearchException("already waiting for a cluster state change"); } @@ -157,8 +158,7 @@ public class ClusterStateObserver { listener.onNewClusterState(newState); } else { logger.trace("observer: sampled state rejected by predicate ({}). adding listener to ClusterService", newState); - ObservingContext context = - new ObservingContext(new ContextPreservingListener(listener, contextHolder.newStoredContext()), statePredicate); + final ObservingContext context = new ObservingContext(listener, statePredicate); if (!observingContext.compareAndSet(null, context)) { throw new ElasticsearchException("already waiting for a cluster state change"); } @@ -279,30 +279,33 @@ public class ClusterStateObserver { private static final class ContextPreservingListener implements Listener { private final Listener delegate; - private final ThreadContext.StoredContext tempContext; + private final Supplier contextSupplier; - private ContextPreservingListener(Listener delegate, ThreadContext.StoredContext storedContext) { - this.tempContext = storedContext; + private ContextPreservingListener(Listener delegate, Supplier contextSupplier) { + this.contextSupplier = contextSupplier; this.delegate = delegate; } @Override public void onNewClusterState(ClusterState state) { - tempContext.restore(); - delegate.onNewClusterState(state); + try (ThreadContext.StoredContext context = contextSupplier.get()) { + delegate.onNewClusterState(state); + } } @Override public void onClusterServiceClose() { - tempContext.restore(); - delegate.onClusterServiceClose(); + try (ThreadContext.StoredContext context = contextSupplier.get()) { + delegate.onClusterServiceClose(); + } } @Override public void onTimeout(TimeValue timeout) { - tempContext.restore(); - delegate.onTimeout(timeout); + try (ThreadContext.StoredContext context = contextSupplier.get()) { + delegate.onTimeout(timeout); + } } } } diff --git a/core/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java b/core/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java index ca1d364ffd2..d439696b720 100644 --- a/core/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java +++ b/core/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.store.Store; import java.io.Closeable; import java.io.IOException; @@ -34,6 +35,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * A ThreadContext is a map of string headers and a transient map of keyed objects that are associated with @@ -115,12 +119,57 @@ public final class ThreadContext implements Closeable, Writeable { return () -> threadLocal.set(context); } + /** * Just like {@link #stashContext()} but no default context is set. + * @param preserveResponseHeaders if set to true the response headers of the restore thread will be preserved. */ - public StoredContext newStoredContext() { + public StoredContext newStoredContext(boolean preserveResponseHeaders) { final ThreadContextStruct context = threadLocal.get(); - return () -> threadLocal.set(context); + return () -> { + if (preserveResponseHeaders && threadLocal.get() != context) { + threadLocal.set(context.putResponseHeaders(threadLocal.get().responseHeaders)); + } else { + threadLocal.set(context); + } + }; + } + + /** + * Returns a supplier that gathers a {@link #newStoredContext(boolean)} and restores it once the + * returned supplier is invoked. The context returned from the supplier is a stored version of the + * suppliers callers context that should be restored once the originally gathered context is not needed anymore. + * For instance this method should be used like this: + * + *
+     *     Supplier<ThreadContext.StoredContext> restorable = context.newRestorableContext(true);
+     *     new Thread() {
+     *         public void run() {
+     *             try (ThreadContext.StoredContext ctx = restorable.get()) {
+     *                 // execute with the parents context and restore the threads context afterwards
+     *             }
+     *         }
+     *
+     *     }.start();
+     * 
+ * + * @param preserveResponseHeaders if set to true the response headers of the restore thread will be preserved. + * @return a restorable context supplier + */ + public Supplier newRestorableContext(boolean preserveResponseHeaders) { + return wrapRestorable(newStoredContext(preserveResponseHeaders)); + } + + /** + * Same as {@link #newRestorableContext(boolean)} but wraps an existing context to restore. + * @param storedContext the context to restore + */ + public Supplier wrapRestorable(StoredContext storedContext) { + return () -> { + StoredContext context = newStoredContext(false); + storedContext.restore(); + return context; + }; } @Override @@ -327,6 +376,26 @@ public final class ThreadContext implements Closeable, Writeable { } } + private ThreadContextStruct putResponseHeaders(Map> headers) { + assert headers != null; + if (headers.isEmpty()) { + return this; + } + final Map> newResponseHeaders = new HashMap<>(this.responseHeaders); + for (Map.Entry> entry : headers.entrySet()) { + String key = entry.getKey(); + final List existingValues = newResponseHeaders.get(key); + if (existingValues != null) { + List newValues = Stream.concat(entry.getValue().stream(), + existingValues.stream()).distinct().collect(Collectors.toList()); + newResponseHeaders.put(key, Collections.unmodifiableList(newValues)); + } else { + newResponseHeaders.put(key, entry.getValue()); + } + } + return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders); + } + private ThreadContextStruct putResponse(String key, String value) { assert value != null; @@ -445,7 +514,7 @@ public final class ThreadContext implements Closeable, Writeable { private final ThreadContext.StoredContext ctx; private ContextPreservingRunnable(Runnable in) { - ctx = newStoredContext(); + ctx = newStoredContext(false); this.in = in; } @@ -487,7 +556,7 @@ public final class ThreadContext implements Closeable, Writeable { private ThreadContext.StoredContext threadsOriginalContext = null; private ContextPreservingAbstractRunnable(AbstractRunnable in) { - creatorsContext = newStoredContext(); + creatorsContext = newStoredContext(false); this.in = in; } diff --git a/core/src/main/java/org/elasticsearch/transport/TransportService.java b/core/src/main/java/org/elasticsearch/transport/TransportService.java index 8c816d12be1..6ba0abb6503 100644 --- a/core/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/core/src/main/java/org/elasticsearch/transport/TransportService.java @@ -543,8 +543,8 @@ public class TransportService extends AbstractLifecycleComponent { } else { timeoutHandler = new TimeoutHandler(requestId); } - TransportResponseHandler responseHandler = - new ContextRestoreResponseHandler<>(threadPool.getThreadContext().newStoredContext(), handler); + Supplier storedContextSupplier = threadPool.getThreadContext().newRestorableContext(true); + TransportResponseHandler responseHandler = new ContextRestoreResponseHandler<>(storedContextSupplier, handler); clientHandlers.put(requestId, new RequestHolder<>(responseHandler, connection.getNode(), action, timeoutHandler)); if (lifecycle.stoppedOrClosed()) { // if we are not started the exception handling will remove the RequestHolder again and calls the handler to notify @@ -1000,14 +1000,14 @@ public class TransportService extends AbstractLifecycleComponent { * This handler wrapper ensures that the response thread executes with the correct thread context. Before any of the4 handle methods * are invoked we restore the context. */ - private static final class ContextRestoreResponseHandler implements TransportResponseHandler { + public static final class ContextRestoreResponseHandler implements TransportResponseHandler { private final TransportResponseHandler delegate; - private final ThreadContext.StoredContext threadContext; + private final Supplier contextSupplier; - private ContextRestoreResponseHandler(ThreadContext.StoredContext threadContext, TransportResponseHandler delegate) { + public ContextRestoreResponseHandler(Supplier contextSupplier, TransportResponseHandler delegate) { this.delegate = delegate; - this.threadContext = threadContext; + this.contextSupplier = contextSupplier; } @Override @@ -1017,14 +1017,16 @@ public class TransportService extends AbstractLifecycleComponent { @Override public void handleResponse(T response) { - threadContext.restore(); - delegate.handleResponse(response); + try (ThreadContext.StoredContext ignore = contextSupplier.get()) { + delegate.handleResponse(response); + } } @Override public void handleException(TransportException exp) { - threadContext.restore(); - delegate.handleException(exp); + try (ThreadContext.StoredContext ignore = contextSupplier.get()) { + delegate.handleException(exp); + } } @Override @@ -1081,13 +1083,7 @@ public class TransportService extends AbstractLifecycleComponent { if (ThreadPool.Names.SAME.equals(executor)) { processResponse(handler, response); } else { - threadPool.executor(executor).execute(new Runnable() { - @SuppressWarnings({"unchecked"}) - @Override - public void run() { - processResponse(handler, response); - } - }); + threadPool.executor(executor).execute(() -> processResponse(handler, response)); } } } diff --git a/core/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java b/core/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java index 3f914f61d48..582f5d58d80 100644 --- a/core/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java +++ b/core/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java @@ -27,6 +27,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; @@ -86,7 +87,7 @@ public class ThreadContextTests extends ESTestCase { assertEquals("bar", threadContext.getHeader("foo")); assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); - ThreadContext.StoredContext storedContext = threadContext.newStoredContext(); + ThreadContext.StoredContext storedContext = threadContext.newStoredContext(false); threadContext.putHeader("foo.bar", "baz"); try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { assertNull(threadContext.getHeader("foo")); @@ -109,6 +110,63 @@ public class ThreadContextTests extends ESTestCase { assertNull(threadContext.getHeader("foo.bar")); } + public void testRestorableContext() { + Settings build = Settings.builder().put("request.headers.default", "1").build(); + ThreadContext threadContext = new ThreadContext(build); + threadContext.putHeader("foo", "bar"); + threadContext.putTransient("ctx.foo", 1); + threadContext.addResponseHeader("resp.header", "baaaam"); + Supplier contextSupplier = threadContext.newRestorableContext(true); + + try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { + assertNull(threadContext.getHeader("foo")); + assertEquals("1", threadContext.getHeader("default")); + threadContext.addResponseHeader("resp.header", "boom"); + try (ThreadContext.StoredContext tmp = contextSupplier.get()) { + assertEquals("bar", threadContext.getHeader("foo")); + assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); + assertEquals("1", threadContext.getHeader("default")); + assertEquals(2, threadContext.getResponseHeaders().get("resp.header").size()); + assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0)); + assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(1)); + } + assertNull(threadContext.getHeader("foo")); + assertNull(threadContext.getTransient("ctx.foo")); + assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size()); + assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0)); + } + assertEquals("bar", threadContext.getHeader("foo")); + assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); + assertEquals("1", threadContext.getHeader("default")); + assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size()); + assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0)); + + contextSupplier = threadContext.newRestorableContext(false); + + try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { + assertNull(threadContext.getHeader("foo")); + assertEquals("1", threadContext.getHeader("default")); + threadContext.addResponseHeader("resp.header", "boom"); + try (ThreadContext.StoredContext tmp = contextSupplier.get()) { + assertEquals("bar", threadContext.getHeader("foo")); + assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); + assertEquals("1", threadContext.getHeader("default")); + assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size()); + assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0)); + } + assertNull(threadContext.getHeader("foo")); + assertNull(threadContext.getTransient("ctx.foo")); + assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size()); + assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0)); + } + + assertEquals("bar", threadContext.getHeader("foo")); + assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); + assertEquals("1", threadContext.getHeader("default")); + assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size()); + assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0)); + } + public void testResponseHeaders() { final boolean expectThird = randomBoolean(); diff --git a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSource.java b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSource.java index d4dfe69a727..28425d9145c 100644 --- a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSource.java +++ b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSource.java @@ -142,7 +142,7 @@ public class RemoteScrollableHitSource extends ScrollableHitSource { private void execute(String method, String uri, Map params, HttpEntity entity, BiFunction parser, Consumer listener) { // Preserve the thread context so headers survive after the call - ThreadContext.StoredContext ctx = threadPool.getThreadContext().newStoredContext(); + java.util.function.Supplier contextSupplier = threadPool.getThreadContext().newRestorableContext(true); class RetryHelper extends AbstractRunnable { private final Iterator retries = backoffPolicy.iterator(); @@ -152,63 +152,68 @@ public class RemoteScrollableHitSource extends ScrollableHitSource { @Override public void onSuccess(org.elasticsearch.client.Response response) { // Restore the thread context to get the precious headers - ctx.restore(); - T parsedResponse; - try { - HttpEntity responseEntity = response.getEntity(); - InputStream content = responseEntity.getContent(); - XContentType xContentType = null; - if (responseEntity.getContentType() != null) { - xContentType = XContentType.fromMediaTypeOrFormat(responseEntity.getContentType().getValue()); - } - if (xContentType == null) { - try { - throw new ElasticsearchException( - "Response didn't include Content-Type: " + bodyMessage(response.getEntity())); - } catch (IOException e) { - ElasticsearchException ee = new ElasticsearchException("Error extracting body from response"); - ee.addSuppressed(e); - throw ee; + try (ThreadContext.StoredContext ctx = contextSupplier.get()) { + assert ctx != null; // eliminates compiler warning + T parsedResponse; + try { + HttpEntity responseEntity = response.getEntity(); + InputStream content = responseEntity.getContent(); + XContentType xContentType = null; + if (responseEntity.getContentType() != null) { + xContentType = XContentType.fromMediaTypeOrFormat(responseEntity.getContentType().getValue()); } - } - // EMPTY is safe here because we don't call namedObject - try (XContentParser xContentParser = xContentType.xContent().createParser(NamedXContentRegistry.EMPTY, + if (xContentType == null) { + try { + throw new ElasticsearchException( + "Response didn't include Content-Type: " + bodyMessage(response.getEntity())); + } catch (IOException e) { + ElasticsearchException ee = new ElasticsearchException("Error extracting body from response"); + ee.addSuppressed(e); + throw ee; + } + } + // EMPTY is safe here because we don't call namedObject + try (XContentParser xContentParser = xContentType.xContent().createParser(NamedXContentRegistry.EMPTY, content)) { - parsedResponse = parser.apply(xContentParser, null); - } catch (ParsingException e) { + parsedResponse = parser.apply(xContentParser, null); + } catch (ParsingException e) { /* Because we're streaming the response we can't get a copy of it here. The best we can do is hint that it * is totally wrong and we're probably not talking to Elasticsearch. */ - throw new ElasticsearchException( + throw new ElasticsearchException( "Error parsing the response, remote is likely not an Elasticsearch instance", e); + } + } catch (IOException e) { + throw new ElasticsearchException( + "Error deserializing response, remote is likely not an Elasticsearch instance", e); } - } catch (IOException e) { - throw new ElasticsearchException("Error deserializing response, remote is likely not an Elasticsearch instance", - e); + listener.accept(parsedResponse); } - listener.accept(parsedResponse); } @Override public void onFailure(Exception e) { - if (e instanceof ResponseException) { - ResponseException re = (ResponseException) e; - if (RestStatus.TOO_MANY_REQUESTS.getStatus() == re.getResponse().getStatusLine().getStatusCode()) { - if (retries.hasNext()) { - TimeValue delay = retries.next(); - logger.trace( - (Supplier) () -> new ParameterizedMessage("retrying rejected search after [{}]", delay), e); - countSearchRetry.run(); - threadPool.schedule(delay, ThreadPool.Names.SAME, RetryHelper.this); - return; + try (ThreadContext.StoredContext ctx = contextSupplier.get()) { + assert ctx != null; // eliminates compiler warning + if (e instanceof ResponseException) { + ResponseException re = (ResponseException) e; + if (RestStatus.TOO_MANY_REQUESTS.getStatus() == re.getResponse().getStatusLine().getStatusCode()) { + if (retries.hasNext()) { + TimeValue delay = retries.next(); + logger.trace( + (Supplier) () -> new ParameterizedMessage("retrying rejected search after [{}]", delay), e); + countSearchRetry.run(); + threadPool.schedule(delay, ThreadPool.Names.SAME, RetryHelper.this); + return; + } } - } - e = wrapExceptionToPreserveStatus(re.getResponse().getStatusLine().getStatusCode(), + e = wrapExceptionToPreserveStatus(re.getResponse().getStatusLine().getStatusCode(), re.getResponse().getEntity(), re); - } else if (e instanceof ContentTooLongException) { - e = new IllegalArgumentException( + } else if (e instanceof ContentTooLongException) { + e = new IllegalArgumentException( "Remote responded with a chunk that was too large. Use a smaller batch size.", e); + } + fail.accept(e); } - fail.accept(e); } }); } diff --git a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java index e54d7c72efe..de33688a4c8 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java @@ -150,7 +150,6 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { client.close(); } client = new MyMockClient(new NoOpClient(threadPool)); - client.threadPool().getThreadContext().newStoredContext(); client.threadPool().getThreadContext().putHeader(expectedHeaders); } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 9cb4ccba39a..e7c0020c0bf 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -21,6 +21,7 @@ package org.elasticsearch.transport; import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.util.Supplier; +import org.apache.lucene.util.CollectionUtil; import org.apache.lucene.util.Constants; import org.apache.lucene.util.IOUtils; import org.elasticsearch.ExceptionsHelper; @@ -1932,4 +1933,68 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { t.join(); } } + + public void testResponseHeadersArePreserved() throws InterruptedException { + List executors = new ArrayList<>(ThreadPool.THREAD_POOL_TYPES.keySet()); + CollectionUtil.timSort(executors); // makes sure it's reproducible + serviceA.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME, + (request, channel) -> { + + threadPool.getThreadContext().putTransient("boom", new Object()); + threadPool.getThreadContext().addResponseHeader("foo.bar", "baz"); + if ("fail".equals(request.info)) { + throw new RuntimeException("boom"); + } else { + channel.sendResponse(TransportResponse.Empty.INSTANCE); + } + }); + + CountDownLatch latch = new CountDownLatch(2); + + TransportResponseHandler transportResponseHandler = new TransportResponseHandler() { + @Override + public TransportResponse newInstance() { + return TransportResponse.Empty.INSTANCE; + } + + @Override + public void handleResponse(TransportResponse response) { + try { + assertSame(response, TransportResponse.Empty.INSTANCE); + assertTrue(threadPool.getThreadContext().getResponseHeaders().containsKey("foo.bar")); + assertEquals(1, threadPool.getThreadContext().getResponseHeaders().get("foo.bar").size()); + assertEquals("baz", threadPool.getThreadContext().getResponseHeaders().get("foo.bar").get(0)); + assertNull(threadPool.getThreadContext().getTransient("boom")); + } finally { + latch.countDown(); + } + + } + + @Override + public void handleException(TransportException exp) { + try { + assertTrue(threadPool.getThreadContext().getResponseHeaders().containsKey("foo.bar")); + assertEquals(1, threadPool.getThreadContext().getResponseHeaders().get("foo.bar").size()); + assertEquals("baz", threadPool.getThreadContext().getResponseHeaders().get("foo.bar").get(0)); + assertNull(threadPool.getThreadContext().getTransient("boom")); + } finally { + latch.countDown(); + } + } + + @Override + public String executor() { + if (1 == 1) + return "same"; + + return randomFrom(executors); + } + }; + + serviceB.sendRequest(nodeA, "action", new TestRequest(randomFrom("fail", "pass")), transportResponseHandler); + serviceA.sendRequest(nodeA, "action", new TestRequest(randomFrom("fail", "pass")), transportResponseHandler); + latch.await(); + } + }