From 24e2847af277cca729760ac910f684fdcafe5f0c Mon Sep 17 00:00:00 2001 From: Simon Willnauer Date: Wed, 18 Jan 2017 16:17:54 +0100 Subject: [PATCH] Streamline foreign stored context restore and allow to perserve response headers (#22677) Today we do not preserve response headers if they are present on a transport protocol response. While preserving these headers is not always desired, in the most cases we should pass on these headers to have consistent results for depreciation headers etc. yet, this hasn't been much of a problem since most of the deprecations are detected early ie. on the coordinating node such that this bug wasn't uncovered until #22647 This commit allow to optionally preserve headers when a context is restored and also streamlines the context restore since it leaked frequently into the callers thread context when the callers context wasn't restored again. --- .../action/bulk/TransportBulkAction.java | 3 - .../TransportReplicationAction.java | 8 +- .../cluster/ClusterStateObserver.java | 27 +++--- .../common/util/concurrent/ThreadContext.java | 77 +++++++++++++++- .../transport/TransportService.java | 30 +++--- .../util/concurrent/ThreadContextTests.java | 60 +++++++++++- .../remote/RemoteScrollableHitSource.java | 91 ++++++++++--------- .../reindex/AsyncBulkByScrollActionTests.java | 1 - .../AbstractSimpleTransportTestCase.java | 65 +++++++++++++ 9 files changed, 275 insertions(+), 87 deletions(-) diff --git a/core/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/core/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index a2bd82fd03f..e7c2018ad31 100644 --- a/core/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/core/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -377,11 +377,9 @@ public class TransportBulkAction extends HandledTransportAction handler = - new TransportChannelResponseHandler<>(logger, channel, extraMessage, () -> TransportResponse.Empty.INSTANCE); + new TransportChannelResponseHandler<>(logger, channel, extraMessage, + () -> TransportResponse.Empty.INSTANCE); transportService.sendRequest(clusterService.localNode(), transportReplicaAction, new ConcreteShardRequest<>(request, targetAllocationID), handler); @@ -809,11 +808,9 @@ public abstract class TransportReplicationAction< } setPhase(task, "waiting_for_retry"); request.onRetry(); - final ThreadContext.StoredContext context = threadPool.getThreadContext().newStoredContext(); observer.waitForNextChange(new ClusterStateObserver.Listener() { @Override public void onNewClusterState(ClusterState state) { - context.close(); run(); } @@ -824,7 +821,6 @@ public abstract class TransportReplicationAction< @Override public void onTimeout(TimeValue timeout) { - context.close(); // Try one more time... run(); } diff --git a/core/src/main/java/org/elasticsearch/cluster/ClusterStateObserver.java b/core/src/main/java/org/elasticsearch/cluster/ClusterStateObserver.java index cad98198a80..4964a14dfca 100644 --- a/core/src/main/java/org/elasticsearch/cluster/ClusterStateObserver.java +++ b/core/src/main/java/org/elasticsearch/cluster/ClusterStateObserver.java @@ -29,6 +29,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Predicate; +import java.util.function.Supplier; /** * A utility class which simplifies interacting with the cluster state in cases where @@ -118,7 +119,7 @@ public class ClusterStateObserver { * @param timeOutValue a timeout for waiting. If null the global observer timeout will be used. */ public void waitForNextChange(Listener listener, Predicate statePredicate, @Nullable TimeValue timeOutValue) { - + listener = new ContextPreservingListener(listener, contextHolder.newRestorableContext(false)); if (observingContext.get() != null) { throw new ElasticsearchException("already waiting for a cluster state change"); } @@ -157,8 +158,7 @@ public class ClusterStateObserver { listener.onNewClusterState(newState); } else { logger.trace("observer: sampled state rejected by predicate ({}). adding listener to ClusterService", newState); - ObservingContext context = - new ObservingContext(new ContextPreservingListener(listener, contextHolder.newStoredContext()), statePredicate); + final ObservingContext context = new ObservingContext(listener, statePredicate); if (!observingContext.compareAndSet(null, context)) { throw new ElasticsearchException("already waiting for a cluster state change"); } @@ -279,30 +279,33 @@ public class ClusterStateObserver { private static final class ContextPreservingListener implements Listener { private final Listener delegate; - private final ThreadContext.StoredContext tempContext; + private final Supplier contextSupplier; - private ContextPreservingListener(Listener delegate, ThreadContext.StoredContext storedContext) { - this.tempContext = storedContext; + private ContextPreservingListener(Listener delegate, Supplier contextSupplier) { + this.contextSupplier = contextSupplier; this.delegate = delegate; } @Override public void onNewClusterState(ClusterState state) { - tempContext.restore(); - delegate.onNewClusterState(state); + try (ThreadContext.StoredContext context = contextSupplier.get()) { + delegate.onNewClusterState(state); + } } @Override public void onClusterServiceClose() { - tempContext.restore(); - delegate.onClusterServiceClose(); + try (ThreadContext.StoredContext context = contextSupplier.get()) { + delegate.onClusterServiceClose(); + } } @Override public void onTimeout(TimeValue timeout) { - tempContext.restore(); - delegate.onTimeout(timeout); + try (ThreadContext.StoredContext context = contextSupplier.get()) { + delegate.onTimeout(timeout); + } } } } diff --git a/core/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java b/core/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java index ca1d364ffd2..d439696b720 100644 --- a/core/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java +++ b/core/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.store.Store; import java.io.Closeable; import java.io.IOException; @@ -34,6 +35,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * A ThreadContext is a map of string headers and a transient map of keyed objects that are associated with @@ -115,12 +119,57 @@ public final class ThreadContext implements Closeable, Writeable { return () -> threadLocal.set(context); } + /** * Just like {@link #stashContext()} but no default context is set. + * @param preserveResponseHeaders if set to true the response headers of the restore thread will be preserved. */ - public StoredContext newStoredContext() { + public StoredContext newStoredContext(boolean preserveResponseHeaders) { final ThreadContextStruct context = threadLocal.get(); - return () -> threadLocal.set(context); + return () -> { + if (preserveResponseHeaders && threadLocal.get() != context) { + threadLocal.set(context.putResponseHeaders(threadLocal.get().responseHeaders)); + } else { + threadLocal.set(context); + } + }; + } + + /** + * Returns a supplier that gathers a {@link #newStoredContext(boolean)} and restores it once the + * returned supplier is invoked. The context returned from the supplier is a stored version of the + * suppliers callers context that should be restored once the originally gathered context is not needed anymore. + * For instance this method should be used like this: + * + *
+     *     Supplier<ThreadContext.StoredContext> restorable = context.newRestorableContext(true);
+     *     new Thread() {
+     *         public void run() {
+     *             try (ThreadContext.StoredContext ctx = restorable.get()) {
+     *                 // execute with the parents context and restore the threads context afterwards
+     *             }
+     *         }
+     *
+     *     }.start();
+     * 
+ * + * @param preserveResponseHeaders if set to true the response headers of the restore thread will be preserved. + * @return a restorable context supplier + */ + public Supplier newRestorableContext(boolean preserveResponseHeaders) { + return wrapRestorable(newStoredContext(preserveResponseHeaders)); + } + + /** + * Same as {@link #newRestorableContext(boolean)} but wraps an existing context to restore. + * @param storedContext the context to restore + */ + public Supplier wrapRestorable(StoredContext storedContext) { + return () -> { + StoredContext context = newStoredContext(false); + storedContext.restore(); + return context; + }; } @Override @@ -327,6 +376,26 @@ public final class ThreadContext implements Closeable, Writeable { } } + private ThreadContextStruct putResponseHeaders(Map> headers) { + assert headers != null; + if (headers.isEmpty()) { + return this; + } + final Map> newResponseHeaders = new HashMap<>(this.responseHeaders); + for (Map.Entry> entry : headers.entrySet()) { + String key = entry.getKey(); + final List existingValues = newResponseHeaders.get(key); + if (existingValues != null) { + List newValues = Stream.concat(entry.getValue().stream(), + existingValues.stream()).distinct().collect(Collectors.toList()); + newResponseHeaders.put(key, Collections.unmodifiableList(newValues)); + } else { + newResponseHeaders.put(key, entry.getValue()); + } + } + return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders); + } + private ThreadContextStruct putResponse(String key, String value) { assert value != null; @@ -445,7 +514,7 @@ public final class ThreadContext implements Closeable, Writeable { private final ThreadContext.StoredContext ctx; private ContextPreservingRunnable(Runnable in) { - ctx = newStoredContext(); + ctx = newStoredContext(false); this.in = in; } @@ -487,7 +556,7 @@ public final class ThreadContext implements Closeable, Writeable { private ThreadContext.StoredContext threadsOriginalContext = null; private ContextPreservingAbstractRunnable(AbstractRunnable in) { - creatorsContext = newStoredContext(); + creatorsContext = newStoredContext(false); this.in = in; } diff --git a/core/src/main/java/org/elasticsearch/transport/TransportService.java b/core/src/main/java/org/elasticsearch/transport/TransportService.java index 8c816d12be1..6ba0abb6503 100644 --- a/core/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/core/src/main/java/org/elasticsearch/transport/TransportService.java @@ -543,8 +543,8 @@ public class TransportService extends AbstractLifecycleComponent { } else { timeoutHandler = new TimeoutHandler(requestId); } - TransportResponseHandler responseHandler = - new ContextRestoreResponseHandler<>(threadPool.getThreadContext().newStoredContext(), handler); + Supplier storedContextSupplier = threadPool.getThreadContext().newRestorableContext(true); + TransportResponseHandler responseHandler = new ContextRestoreResponseHandler<>(storedContextSupplier, handler); clientHandlers.put(requestId, new RequestHolder<>(responseHandler, connection.getNode(), action, timeoutHandler)); if (lifecycle.stoppedOrClosed()) { // if we are not started the exception handling will remove the RequestHolder again and calls the handler to notify @@ -1000,14 +1000,14 @@ public class TransportService extends AbstractLifecycleComponent { * This handler wrapper ensures that the response thread executes with the correct thread context. Before any of the4 handle methods * are invoked we restore the context. */ - private static final class ContextRestoreResponseHandler implements TransportResponseHandler { + public static final class ContextRestoreResponseHandler implements TransportResponseHandler { private final TransportResponseHandler delegate; - private final ThreadContext.StoredContext threadContext; + private final Supplier contextSupplier; - private ContextRestoreResponseHandler(ThreadContext.StoredContext threadContext, TransportResponseHandler delegate) { + public ContextRestoreResponseHandler(Supplier contextSupplier, TransportResponseHandler delegate) { this.delegate = delegate; - this.threadContext = threadContext; + this.contextSupplier = contextSupplier; } @Override @@ -1017,14 +1017,16 @@ public class TransportService extends AbstractLifecycleComponent { @Override public void handleResponse(T response) { - threadContext.restore(); - delegate.handleResponse(response); + try (ThreadContext.StoredContext ignore = contextSupplier.get()) { + delegate.handleResponse(response); + } } @Override public void handleException(TransportException exp) { - threadContext.restore(); - delegate.handleException(exp); + try (ThreadContext.StoredContext ignore = contextSupplier.get()) { + delegate.handleException(exp); + } } @Override @@ -1081,13 +1083,7 @@ public class TransportService extends AbstractLifecycleComponent { if (ThreadPool.Names.SAME.equals(executor)) { processResponse(handler, response); } else { - threadPool.executor(executor).execute(new Runnable() { - @SuppressWarnings({"unchecked"}) - @Override - public void run() { - processResponse(handler, response); - } - }); + threadPool.executor(executor).execute(() -> processResponse(handler, response)); } } } diff --git a/core/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java b/core/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java index 3f914f61d48..582f5d58d80 100644 --- a/core/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java +++ b/core/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java @@ -27,6 +27,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; @@ -86,7 +87,7 @@ public class ThreadContextTests extends ESTestCase { assertEquals("bar", threadContext.getHeader("foo")); assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); - ThreadContext.StoredContext storedContext = threadContext.newStoredContext(); + ThreadContext.StoredContext storedContext = threadContext.newStoredContext(false); threadContext.putHeader("foo.bar", "baz"); try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { assertNull(threadContext.getHeader("foo")); @@ -109,6 +110,63 @@ public class ThreadContextTests extends ESTestCase { assertNull(threadContext.getHeader("foo.bar")); } + public void testRestorableContext() { + Settings build = Settings.builder().put("request.headers.default", "1").build(); + ThreadContext threadContext = new ThreadContext(build); + threadContext.putHeader("foo", "bar"); + threadContext.putTransient("ctx.foo", 1); + threadContext.addResponseHeader("resp.header", "baaaam"); + Supplier contextSupplier = threadContext.newRestorableContext(true); + + try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { + assertNull(threadContext.getHeader("foo")); + assertEquals("1", threadContext.getHeader("default")); + threadContext.addResponseHeader("resp.header", "boom"); + try (ThreadContext.StoredContext tmp = contextSupplier.get()) { + assertEquals("bar", threadContext.getHeader("foo")); + assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); + assertEquals("1", threadContext.getHeader("default")); + assertEquals(2, threadContext.getResponseHeaders().get("resp.header").size()); + assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0)); + assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(1)); + } + assertNull(threadContext.getHeader("foo")); + assertNull(threadContext.getTransient("ctx.foo")); + assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size()); + assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0)); + } + assertEquals("bar", threadContext.getHeader("foo")); + assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); + assertEquals("1", threadContext.getHeader("default")); + assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size()); + assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0)); + + contextSupplier = threadContext.newRestorableContext(false); + + try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { + assertNull(threadContext.getHeader("foo")); + assertEquals("1", threadContext.getHeader("default")); + threadContext.addResponseHeader("resp.header", "boom"); + try (ThreadContext.StoredContext tmp = contextSupplier.get()) { + assertEquals("bar", threadContext.getHeader("foo")); + assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); + assertEquals("1", threadContext.getHeader("default")); + assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size()); + assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0)); + } + assertNull(threadContext.getHeader("foo")); + assertNull(threadContext.getTransient("ctx.foo")); + assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size()); + assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0)); + } + + assertEquals("bar", threadContext.getHeader("foo")); + assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); + assertEquals("1", threadContext.getHeader("default")); + assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size()); + assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0)); + } + public void testResponseHeaders() { final boolean expectThird = randomBoolean(); diff --git a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSource.java b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSource.java index d4dfe69a727..28425d9145c 100644 --- a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSource.java +++ b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSource.java @@ -142,7 +142,7 @@ public class RemoteScrollableHitSource extends ScrollableHitSource { private void execute(String method, String uri, Map params, HttpEntity entity, BiFunction parser, Consumer listener) { // Preserve the thread context so headers survive after the call - ThreadContext.StoredContext ctx = threadPool.getThreadContext().newStoredContext(); + java.util.function.Supplier contextSupplier = threadPool.getThreadContext().newRestorableContext(true); class RetryHelper extends AbstractRunnable { private final Iterator retries = backoffPolicy.iterator(); @@ -152,63 +152,68 @@ public class RemoteScrollableHitSource extends ScrollableHitSource { @Override public void onSuccess(org.elasticsearch.client.Response response) { // Restore the thread context to get the precious headers - ctx.restore(); - T parsedResponse; - try { - HttpEntity responseEntity = response.getEntity(); - InputStream content = responseEntity.getContent(); - XContentType xContentType = null; - if (responseEntity.getContentType() != null) { - xContentType = XContentType.fromMediaTypeOrFormat(responseEntity.getContentType().getValue()); - } - if (xContentType == null) { - try { - throw new ElasticsearchException( - "Response didn't include Content-Type: " + bodyMessage(response.getEntity())); - } catch (IOException e) { - ElasticsearchException ee = new ElasticsearchException("Error extracting body from response"); - ee.addSuppressed(e); - throw ee; + try (ThreadContext.StoredContext ctx = contextSupplier.get()) { + assert ctx != null; // eliminates compiler warning + T parsedResponse; + try { + HttpEntity responseEntity = response.getEntity(); + InputStream content = responseEntity.getContent(); + XContentType xContentType = null; + if (responseEntity.getContentType() != null) { + xContentType = XContentType.fromMediaTypeOrFormat(responseEntity.getContentType().getValue()); } - } - // EMPTY is safe here because we don't call namedObject - try (XContentParser xContentParser = xContentType.xContent().createParser(NamedXContentRegistry.EMPTY, + if (xContentType == null) { + try { + throw new ElasticsearchException( + "Response didn't include Content-Type: " + bodyMessage(response.getEntity())); + } catch (IOException e) { + ElasticsearchException ee = new ElasticsearchException("Error extracting body from response"); + ee.addSuppressed(e); + throw ee; + } + } + // EMPTY is safe here because we don't call namedObject + try (XContentParser xContentParser = xContentType.xContent().createParser(NamedXContentRegistry.EMPTY, content)) { - parsedResponse = parser.apply(xContentParser, null); - } catch (ParsingException e) { + parsedResponse = parser.apply(xContentParser, null); + } catch (ParsingException e) { /* Because we're streaming the response we can't get a copy of it here. The best we can do is hint that it * is totally wrong and we're probably not talking to Elasticsearch. */ - throw new ElasticsearchException( + throw new ElasticsearchException( "Error parsing the response, remote is likely not an Elasticsearch instance", e); + } + } catch (IOException e) { + throw new ElasticsearchException( + "Error deserializing response, remote is likely not an Elasticsearch instance", e); } - } catch (IOException e) { - throw new ElasticsearchException("Error deserializing response, remote is likely not an Elasticsearch instance", - e); + listener.accept(parsedResponse); } - listener.accept(parsedResponse); } @Override public void onFailure(Exception e) { - if (e instanceof ResponseException) { - ResponseException re = (ResponseException) e; - if (RestStatus.TOO_MANY_REQUESTS.getStatus() == re.getResponse().getStatusLine().getStatusCode()) { - if (retries.hasNext()) { - TimeValue delay = retries.next(); - logger.trace( - (Supplier) () -> new ParameterizedMessage("retrying rejected search after [{}]", delay), e); - countSearchRetry.run(); - threadPool.schedule(delay, ThreadPool.Names.SAME, RetryHelper.this); - return; + try (ThreadContext.StoredContext ctx = contextSupplier.get()) { + assert ctx != null; // eliminates compiler warning + if (e instanceof ResponseException) { + ResponseException re = (ResponseException) e; + if (RestStatus.TOO_MANY_REQUESTS.getStatus() == re.getResponse().getStatusLine().getStatusCode()) { + if (retries.hasNext()) { + TimeValue delay = retries.next(); + logger.trace( + (Supplier) () -> new ParameterizedMessage("retrying rejected search after [{}]", delay), e); + countSearchRetry.run(); + threadPool.schedule(delay, ThreadPool.Names.SAME, RetryHelper.this); + return; + } } - } - e = wrapExceptionToPreserveStatus(re.getResponse().getStatusLine().getStatusCode(), + e = wrapExceptionToPreserveStatus(re.getResponse().getStatusLine().getStatusCode(), re.getResponse().getEntity(), re); - } else if (e instanceof ContentTooLongException) { - e = new IllegalArgumentException( + } else if (e instanceof ContentTooLongException) { + e = new IllegalArgumentException( "Remote responded with a chunk that was too large. Use a smaller batch size.", e); + } + fail.accept(e); } - fail.accept(e); } }); } diff --git a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java index e54d7c72efe..de33688a4c8 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java @@ -150,7 +150,6 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { client.close(); } client = new MyMockClient(new NoOpClient(threadPool)); - client.threadPool().getThreadContext().newStoredContext(); client.threadPool().getThreadContext().putHeader(expectedHeaders); } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 9cb4ccba39a..e7c0020c0bf 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -21,6 +21,7 @@ package org.elasticsearch.transport; import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.util.Supplier; +import org.apache.lucene.util.CollectionUtil; import org.apache.lucene.util.Constants; import org.apache.lucene.util.IOUtils; import org.elasticsearch.ExceptionsHelper; @@ -1932,4 +1933,68 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { t.join(); } } + + public void testResponseHeadersArePreserved() throws InterruptedException { + List executors = new ArrayList<>(ThreadPool.THREAD_POOL_TYPES.keySet()); + CollectionUtil.timSort(executors); // makes sure it's reproducible + serviceA.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME, + (request, channel) -> { + + threadPool.getThreadContext().putTransient("boom", new Object()); + threadPool.getThreadContext().addResponseHeader("foo.bar", "baz"); + if ("fail".equals(request.info)) { + throw new RuntimeException("boom"); + } else { + channel.sendResponse(TransportResponse.Empty.INSTANCE); + } + }); + + CountDownLatch latch = new CountDownLatch(2); + + TransportResponseHandler transportResponseHandler = new TransportResponseHandler() { + @Override + public TransportResponse newInstance() { + return TransportResponse.Empty.INSTANCE; + } + + @Override + public void handleResponse(TransportResponse response) { + try { + assertSame(response, TransportResponse.Empty.INSTANCE); + assertTrue(threadPool.getThreadContext().getResponseHeaders().containsKey("foo.bar")); + assertEquals(1, threadPool.getThreadContext().getResponseHeaders().get("foo.bar").size()); + assertEquals("baz", threadPool.getThreadContext().getResponseHeaders().get("foo.bar").get(0)); + assertNull(threadPool.getThreadContext().getTransient("boom")); + } finally { + latch.countDown(); + } + + } + + @Override + public void handleException(TransportException exp) { + try { + assertTrue(threadPool.getThreadContext().getResponseHeaders().containsKey("foo.bar")); + assertEquals(1, threadPool.getThreadContext().getResponseHeaders().get("foo.bar").size()); + assertEquals("baz", threadPool.getThreadContext().getResponseHeaders().get("foo.bar").get(0)); + assertNull(threadPool.getThreadContext().getTransient("boom")); + } finally { + latch.countDown(); + } + } + + @Override + public String executor() { + if (1 == 1) + return "same"; + + return randomFrom(executors); + } + }; + + serviceB.sendRequest(nodeA, "action", new TestRequest(randomFrom("fail", "pass")), transportResponseHandler); + serviceA.sendRequest(nodeA, "action", new TestRequest(randomFrom("fail", "pass")), transportResponseHandler); + latch.await(); + } + }