Streamline foreign stored context restore and allow to perserve response headers (#22677)
Today we do not preserve response headers if they are present on a transport protocol response. While preserving these headers is not always desired, in the most cases we should pass on these headers to have consistent results for depreciation headers etc. yet, this hasn't been much of a problem since most of the deprecations are detected early ie. on the coordinating node such that this bug wasn't uncovered until #22647 This commit allow to optionally preserve headers when a context is restored and also streamlines the context restore since it leaked frequently into the callers thread context when the callers context wasn't restored again.
This commit is contained in:
parent
8a0a1140a9
commit
24e2847af2
|
@ -377,11 +377,9 @@ public class TransportBulkAction extends HandledTransportAction<BulkRequest, Bul
|
|||
onFailure(failure);
|
||||
return;
|
||||
}
|
||||
final ThreadContext.StoredContext context = threadPool.getThreadContext().newStoredContext();
|
||||
observer.waitForNextChange(new ClusterStateObserver.Listener() {
|
||||
@Override
|
||||
public void onNewClusterState(ClusterState state) {
|
||||
context.close();
|
||||
run();
|
||||
}
|
||||
|
||||
|
@ -392,7 +390,6 @@ public class TransportBulkAction extends HandledTransportAction<BulkRequest, Bul
|
|||
|
||||
@Override
|
||||
public void onTimeout(TimeValue timeout) {
|
||||
context.close();
|
||||
// Try one more time...
|
||||
run();
|
||||
}
|
||||
|
|
|
@ -514,16 +514,15 @@ public abstract class TransportReplicationAction<
|
|||
request),
|
||||
e);
|
||||
request.onRetry();
|
||||
final ThreadContext.StoredContext context = threadPool.getThreadContext().newStoredContext();
|
||||
observer.waitForNextChange(new ClusterStateObserver.Listener() {
|
||||
@Override
|
||||
public void onNewClusterState(ClusterState state) {
|
||||
context.close();
|
||||
// Forking a thread on local node via transport service so that custom transport service have an
|
||||
// opportunity to execute custom logic before the replica operation begins
|
||||
String extraMessage = "action [" + transportReplicaAction + "], request[" + request + "]";
|
||||
TransportChannelResponseHandler<TransportResponse.Empty> handler =
|
||||
new TransportChannelResponseHandler<>(logger, channel, extraMessage, () -> TransportResponse.Empty.INSTANCE);
|
||||
new TransportChannelResponseHandler<>(logger, channel, extraMessage,
|
||||
() -> TransportResponse.Empty.INSTANCE);
|
||||
transportService.sendRequest(clusterService.localNode(), transportReplicaAction,
|
||||
new ConcreteShardRequest<>(request, targetAllocationID),
|
||||
handler);
|
||||
|
@ -809,11 +808,9 @@ public abstract class TransportReplicationAction<
|
|||
}
|
||||
setPhase(task, "waiting_for_retry");
|
||||
request.onRetry();
|
||||
final ThreadContext.StoredContext context = threadPool.getThreadContext().newStoredContext();
|
||||
observer.waitForNextChange(new ClusterStateObserver.Listener() {
|
||||
@Override
|
||||
public void onNewClusterState(ClusterState state) {
|
||||
context.close();
|
||||
run();
|
||||
}
|
||||
|
||||
|
@ -824,7 +821,6 @@ public abstract class TransportReplicationAction<
|
|||
|
||||
@Override
|
||||
public void onTimeout(TimeValue timeout) {
|
||||
context.close();
|
||||
// Try one more time...
|
||||
run();
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext;
|
|||
import java.util.Objects;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
/**
|
||||
* A utility class which simplifies interacting with the cluster state in cases where
|
||||
|
@ -118,7 +119,7 @@ public class ClusterStateObserver {
|
|||
* @param timeOutValue a timeout for waiting. If null the global observer timeout will be used.
|
||||
*/
|
||||
public void waitForNextChange(Listener listener, Predicate<ClusterState> statePredicate, @Nullable TimeValue timeOutValue) {
|
||||
|
||||
listener = new ContextPreservingListener(listener, contextHolder.newRestorableContext(false));
|
||||
if (observingContext.get() != null) {
|
||||
throw new ElasticsearchException("already waiting for a cluster state change");
|
||||
}
|
||||
|
@ -157,8 +158,7 @@ public class ClusterStateObserver {
|
|||
listener.onNewClusterState(newState);
|
||||
} else {
|
||||
logger.trace("observer: sampled state rejected by predicate ({}). adding listener to ClusterService", newState);
|
||||
ObservingContext context =
|
||||
new ObservingContext(new ContextPreservingListener(listener, contextHolder.newStoredContext()), statePredicate);
|
||||
final ObservingContext context = new ObservingContext(listener, statePredicate);
|
||||
if (!observingContext.compareAndSet(null, context)) {
|
||||
throw new ElasticsearchException("already waiting for a cluster state change");
|
||||
}
|
||||
|
@ -279,30 +279,33 @@ public class ClusterStateObserver {
|
|||
|
||||
private static final class ContextPreservingListener implements Listener {
|
||||
private final Listener delegate;
|
||||
private final ThreadContext.StoredContext tempContext;
|
||||
private final Supplier<ThreadContext.StoredContext> contextSupplier;
|
||||
|
||||
|
||||
private ContextPreservingListener(Listener delegate, ThreadContext.StoredContext storedContext) {
|
||||
this.tempContext = storedContext;
|
||||
private ContextPreservingListener(Listener delegate, Supplier<ThreadContext.StoredContext> contextSupplier) {
|
||||
this.contextSupplier = contextSupplier;
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNewClusterState(ClusterState state) {
|
||||
tempContext.restore();
|
||||
try (ThreadContext.StoredContext context = contextSupplier.get()) {
|
||||
delegate.onNewClusterState(state);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClusterServiceClose() {
|
||||
tempContext.restore();
|
||||
try (ThreadContext.StoredContext context = contextSupplier.get()) {
|
||||
delegate.onClusterServiceClose();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onTimeout(TimeValue timeout) {
|
||||
tempContext.restore();
|
||||
try (ThreadContext.StoredContext context = contextSupplier.get()) {
|
||||
delegate.onTimeout(timeout);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.elasticsearch.common.io.stream.Writeable;
|
|||
import org.elasticsearch.common.settings.Setting;
|
||||
import org.elasticsearch.common.settings.Setting.Property;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.index.store.Store;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
|
@ -34,6 +35,9 @@ import java.util.HashMap;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
* A ThreadContext is a map of string headers and a transient map of keyed objects that are associated with
|
||||
|
@ -115,12 +119,57 @@ public final class ThreadContext implements Closeable, Writeable {
|
|||
return () -> threadLocal.set(context);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Just like {@link #stashContext()} but no default context is set.
|
||||
* @param preserveResponseHeaders if set to <code>true</code> the response headers of the restore thread will be preserved.
|
||||
*/
|
||||
public StoredContext newStoredContext() {
|
||||
public StoredContext newStoredContext(boolean preserveResponseHeaders) {
|
||||
final ThreadContextStruct context = threadLocal.get();
|
||||
return () -> threadLocal.set(context);
|
||||
return () -> {
|
||||
if (preserveResponseHeaders && threadLocal.get() != context) {
|
||||
threadLocal.set(context.putResponseHeaders(threadLocal.get().responseHeaders));
|
||||
} else {
|
||||
threadLocal.set(context);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a supplier that gathers a {@link #newStoredContext(boolean)} and restores it once the
|
||||
* returned supplier is invoked. The context returned from the supplier is a stored version of the
|
||||
* suppliers callers context that should be restored once the originally gathered context is not needed anymore.
|
||||
* For instance this method should be used like this:
|
||||
*
|
||||
* <pre>
|
||||
* Supplier<ThreadContext.StoredContext> restorable = context.newRestorableContext(true);
|
||||
* new Thread() {
|
||||
* public void run() {
|
||||
* try (ThreadContext.StoredContext ctx = restorable.get()) {
|
||||
* // execute with the parents context and restore the threads context afterwards
|
||||
* }
|
||||
* }
|
||||
*
|
||||
* }.start();
|
||||
* </pre>
|
||||
*
|
||||
* @param preserveResponseHeaders if set to <code>true</code> the response headers of the restore thread will be preserved.
|
||||
* @return a restorable context supplier
|
||||
*/
|
||||
public Supplier<StoredContext> newRestorableContext(boolean preserveResponseHeaders) {
|
||||
return wrapRestorable(newStoredContext(preserveResponseHeaders));
|
||||
}
|
||||
|
||||
/**
|
||||
* Same as {@link #newRestorableContext(boolean)} but wraps an existing context to restore.
|
||||
* @param storedContext the context to restore
|
||||
*/
|
||||
public Supplier<StoredContext> wrapRestorable(StoredContext storedContext) {
|
||||
return () -> {
|
||||
StoredContext context = newStoredContext(false);
|
||||
storedContext.restore();
|
||||
return context;
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -327,6 +376,26 @@ public final class ThreadContext implements Closeable, Writeable {
|
|||
}
|
||||
}
|
||||
|
||||
private ThreadContextStruct putResponseHeaders(Map<String, List<String>> headers) {
|
||||
assert headers != null;
|
||||
if (headers.isEmpty()) {
|
||||
return this;
|
||||
}
|
||||
final Map<String, List<String>> newResponseHeaders = new HashMap<>(this.responseHeaders);
|
||||
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
|
||||
String key = entry.getKey();
|
||||
final List<String> existingValues = newResponseHeaders.get(key);
|
||||
if (existingValues != null) {
|
||||
List<String> newValues = Stream.concat(entry.getValue().stream(),
|
||||
existingValues.stream()).distinct().collect(Collectors.toList());
|
||||
newResponseHeaders.put(key, Collections.unmodifiableList(newValues));
|
||||
} else {
|
||||
newResponseHeaders.put(key, entry.getValue());
|
||||
}
|
||||
}
|
||||
return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders);
|
||||
}
|
||||
|
||||
private ThreadContextStruct putResponse(String key, String value) {
|
||||
assert value != null;
|
||||
|
||||
|
@ -445,7 +514,7 @@ public final class ThreadContext implements Closeable, Writeable {
|
|||
private final ThreadContext.StoredContext ctx;
|
||||
|
||||
private ContextPreservingRunnable(Runnable in) {
|
||||
ctx = newStoredContext();
|
||||
ctx = newStoredContext(false);
|
||||
this.in = in;
|
||||
}
|
||||
|
||||
|
@ -487,7 +556,7 @@ public final class ThreadContext implements Closeable, Writeable {
|
|||
private ThreadContext.StoredContext threadsOriginalContext = null;
|
||||
|
||||
private ContextPreservingAbstractRunnable(AbstractRunnable in) {
|
||||
creatorsContext = newStoredContext();
|
||||
creatorsContext = newStoredContext(false);
|
||||
this.in = in;
|
||||
}
|
||||
|
||||
|
|
|
@ -543,8 +543,8 @@ public class TransportService extends AbstractLifecycleComponent {
|
|||
} else {
|
||||
timeoutHandler = new TimeoutHandler(requestId);
|
||||
}
|
||||
TransportResponseHandler<T> responseHandler =
|
||||
new ContextRestoreResponseHandler<>(threadPool.getThreadContext().newStoredContext(), handler);
|
||||
Supplier<ThreadContext.StoredContext> storedContextSupplier = threadPool.getThreadContext().newRestorableContext(true);
|
||||
TransportResponseHandler<T> responseHandler = new ContextRestoreResponseHandler<>(storedContextSupplier, handler);
|
||||
clientHandlers.put(requestId, new RequestHolder<>(responseHandler, connection.getNode(), action, timeoutHandler));
|
||||
if (lifecycle.stoppedOrClosed()) {
|
||||
// if we are not started the exception handling will remove the RequestHolder again and calls the handler to notify
|
||||
|
@ -1000,14 +1000,14 @@ public class TransportService extends AbstractLifecycleComponent {
|
|||
* This handler wrapper ensures that the response thread executes with the correct thread context. Before any of the4 handle methods
|
||||
* are invoked we restore the context.
|
||||
*/
|
||||
private static final class ContextRestoreResponseHandler<T extends TransportResponse> implements TransportResponseHandler<T> {
|
||||
public static final class ContextRestoreResponseHandler<T extends TransportResponse> implements TransportResponseHandler<T> {
|
||||
|
||||
private final TransportResponseHandler<T> delegate;
|
||||
private final ThreadContext.StoredContext threadContext;
|
||||
private final Supplier<ThreadContext.StoredContext> contextSupplier;
|
||||
|
||||
private ContextRestoreResponseHandler(ThreadContext.StoredContext threadContext, TransportResponseHandler<T> delegate) {
|
||||
public ContextRestoreResponseHandler(Supplier<ThreadContext.StoredContext> contextSupplier, TransportResponseHandler<T> delegate) {
|
||||
this.delegate = delegate;
|
||||
this.threadContext = threadContext;
|
||||
this.contextSupplier = contextSupplier;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -1017,15 +1017,17 @@ public class TransportService extends AbstractLifecycleComponent {
|
|||
|
||||
@Override
|
||||
public void handleResponse(T response) {
|
||||
threadContext.restore();
|
||||
try (ThreadContext.StoredContext ignore = contextSupplier.get()) {
|
||||
delegate.handleResponse(response);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleException(TransportException exp) {
|
||||
threadContext.restore();
|
||||
try (ThreadContext.StoredContext ignore = contextSupplier.get()) {
|
||||
delegate.handleException(exp);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String executor() {
|
||||
|
@ -1081,13 +1083,7 @@ public class TransportService extends AbstractLifecycleComponent {
|
|||
if (ThreadPool.Names.SAME.equals(executor)) {
|
||||
processResponse(handler, response);
|
||||
} else {
|
||||
threadPool.executor(executor).execute(new Runnable() {
|
||||
@SuppressWarnings({"unchecked"})
|
||||
@Override
|
||||
public void run() {
|
||||
processResponse(handler, response);
|
||||
}
|
||||
});
|
||||
threadPool.executor(executor).execute(() -> processResponse(handler, response));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ import java.util.Collections;
|
|||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasItem;
|
||||
|
@ -86,7 +87,7 @@ public class ThreadContextTests extends ESTestCase {
|
|||
assertEquals("bar", threadContext.getHeader("foo"));
|
||||
assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
|
||||
assertEquals("1", threadContext.getHeader("default"));
|
||||
ThreadContext.StoredContext storedContext = threadContext.newStoredContext();
|
||||
ThreadContext.StoredContext storedContext = threadContext.newStoredContext(false);
|
||||
threadContext.putHeader("foo.bar", "baz");
|
||||
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
|
||||
assertNull(threadContext.getHeader("foo"));
|
||||
|
@ -109,6 +110,63 @@ public class ThreadContextTests extends ESTestCase {
|
|||
assertNull(threadContext.getHeader("foo.bar"));
|
||||
}
|
||||
|
||||
public void testRestorableContext() {
|
||||
Settings build = Settings.builder().put("request.headers.default", "1").build();
|
||||
ThreadContext threadContext = new ThreadContext(build);
|
||||
threadContext.putHeader("foo", "bar");
|
||||
threadContext.putTransient("ctx.foo", 1);
|
||||
threadContext.addResponseHeader("resp.header", "baaaam");
|
||||
Supplier<ThreadContext.StoredContext> contextSupplier = threadContext.newRestorableContext(true);
|
||||
|
||||
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
|
||||
assertNull(threadContext.getHeader("foo"));
|
||||
assertEquals("1", threadContext.getHeader("default"));
|
||||
threadContext.addResponseHeader("resp.header", "boom");
|
||||
try (ThreadContext.StoredContext tmp = contextSupplier.get()) {
|
||||
assertEquals("bar", threadContext.getHeader("foo"));
|
||||
assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
|
||||
assertEquals("1", threadContext.getHeader("default"));
|
||||
assertEquals(2, threadContext.getResponseHeaders().get("resp.header").size());
|
||||
assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0));
|
||||
assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(1));
|
||||
}
|
||||
assertNull(threadContext.getHeader("foo"));
|
||||
assertNull(threadContext.getTransient("ctx.foo"));
|
||||
assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
|
||||
assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0));
|
||||
}
|
||||
assertEquals("bar", threadContext.getHeader("foo"));
|
||||
assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
|
||||
assertEquals("1", threadContext.getHeader("default"));
|
||||
assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
|
||||
assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0));
|
||||
|
||||
contextSupplier = threadContext.newRestorableContext(false);
|
||||
|
||||
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
|
||||
assertNull(threadContext.getHeader("foo"));
|
||||
assertEquals("1", threadContext.getHeader("default"));
|
||||
threadContext.addResponseHeader("resp.header", "boom");
|
||||
try (ThreadContext.StoredContext tmp = contextSupplier.get()) {
|
||||
assertEquals("bar", threadContext.getHeader("foo"));
|
||||
assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
|
||||
assertEquals("1", threadContext.getHeader("default"));
|
||||
assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
|
||||
assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0));
|
||||
}
|
||||
assertNull(threadContext.getHeader("foo"));
|
||||
assertNull(threadContext.getTransient("ctx.foo"));
|
||||
assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
|
||||
assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0));
|
||||
}
|
||||
|
||||
assertEquals("bar", threadContext.getHeader("foo"));
|
||||
assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
|
||||
assertEquals("1", threadContext.getHeader("default"));
|
||||
assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
|
||||
assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0));
|
||||
}
|
||||
|
||||
public void testResponseHeaders() {
|
||||
final boolean expectThird = randomBoolean();
|
||||
|
||||
|
|
|
@ -142,7 +142,7 @@ public class RemoteScrollableHitSource extends ScrollableHitSource {
|
|||
private <T> void execute(String method, String uri, Map<String, String> params, HttpEntity entity,
|
||||
BiFunction<XContentParser, Void, T> parser, Consumer<? super T> listener) {
|
||||
// Preserve the thread context so headers survive after the call
|
||||
ThreadContext.StoredContext ctx = threadPool.getThreadContext().newStoredContext();
|
||||
java.util.function.Supplier<ThreadContext.StoredContext> contextSupplier = threadPool.getThreadContext().newRestorableContext(true);
|
||||
class RetryHelper extends AbstractRunnable {
|
||||
private final Iterator<TimeValue> retries = backoffPolicy.iterator();
|
||||
|
||||
|
@ -152,7 +152,8 @@ public class RemoteScrollableHitSource extends ScrollableHitSource {
|
|||
@Override
|
||||
public void onSuccess(org.elasticsearch.client.Response response) {
|
||||
// Restore the thread context to get the precious headers
|
||||
ctx.restore();
|
||||
try (ThreadContext.StoredContext ctx = contextSupplier.get()) {
|
||||
assert ctx != null; // eliminates compiler warning
|
||||
T parsedResponse;
|
||||
try {
|
||||
HttpEntity responseEntity = response.getEntity();
|
||||
|
@ -182,14 +183,17 @@ public class RemoteScrollableHitSource extends ScrollableHitSource {
|
|||
"Error parsing the response, remote is likely not an Elasticsearch instance", e);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new ElasticsearchException("Error deserializing response, remote is likely not an Elasticsearch instance",
|
||||
e);
|
||||
throw new ElasticsearchException(
|
||||
"Error deserializing response, remote is likely not an Elasticsearch instance", e);
|
||||
}
|
||||
listener.accept(parsedResponse);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
try (ThreadContext.StoredContext ctx = contextSupplier.get()) {
|
||||
assert ctx != null; // eliminates compiler warning
|
||||
if (e instanceof ResponseException) {
|
||||
ResponseException re = (ResponseException) e;
|
||||
if (RestStatus.TOO_MANY_REQUESTS.getStatus() == re.getResponse().getStatusLine().getStatusCode()) {
|
||||
|
@ -210,6 +214,7 @@ public class RemoteScrollableHitSource extends ScrollableHitSource {
|
|||
}
|
||||
fail.accept(e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -150,7 +150,6 @@ public class AsyncBulkByScrollActionTests extends ESTestCase {
|
|||
client.close();
|
||||
}
|
||||
client = new MyMockClient(new NoOpClient(threadPool));
|
||||
client.threadPool().getThreadContext().newStoredContext();
|
||||
client.threadPool().getThreadContext().putHeader(expectedHeaders);
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ package org.elasticsearch.transport;
|
|||
|
||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||
import org.apache.logging.log4j.util.Supplier;
|
||||
import org.apache.lucene.util.CollectionUtil;
|
||||
import org.apache.lucene.util.Constants;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.elasticsearch.ExceptionsHelper;
|
||||
|
@ -1932,4 +1933,68 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
|
|||
t.join();
|
||||
}
|
||||
}
|
||||
|
||||
public void testResponseHeadersArePreserved() throws InterruptedException {
|
||||
List<String> executors = new ArrayList<>(ThreadPool.THREAD_POOL_TYPES.keySet());
|
||||
CollectionUtil.timSort(executors); // makes sure it's reproducible
|
||||
serviceA.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME,
|
||||
(request, channel) -> {
|
||||
|
||||
threadPool.getThreadContext().putTransient("boom", new Object());
|
||||
threadPool.getThreadContext().addResponseHeader("foo.bar", "baz");
|
||||
if ("fail".equals(request.info)) {
|
||||
throw new RuntimeException("boom");
|
||||
} else {
|
||||
channel.sendResponse(TransportResponse.Empty.INSTANCE);
|
||||
}
|
||||
});
|
||||
|
||||
CountDownLatch latch = new CountDownLatch(2);
|
||||
|
||||
TransportResponseHandler<TransportResponse> transportResponseHandler = new TransportResponseHandler<TransportResponse>() {
|
||||
@Override
|
||||
public TransportResponse newInstance() {
|
||||
return TransportResponse.Empty.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleResponse(TransportResponse response) {
|
||||
try {
|
||||
assertSame(response, TransportResponse.Empty.INSTANCE);
|
||||
assertTrue(threadPool.getThreadContext().getResponseHeaders().containsKey("foo.bar"));
|
||||
assertEquals(1, threadPool.getThreadContext().getResponseHeaders().get("foo.bar").size());
|
||||
assertEquals("baz", threadPool.getThreadContext().getResponseHeaders().get("foo.bar").get(0));
|
||||
assertNull(threadPool.getThreadContext().getTransient("boom"));
|
||||
} finally {
|
||||
latch.countDown();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleException(TransportException exp) {
|
||||
try {
|
||||
assertTrue(threadPool.getThreadContext().getResponseHeaders().containsKey("foo.bar"));
|
||||
assertEquals(1, threadPool.getThreadContext().getResponseHeaders().get("foo.bar").size());
|
||||
assertEquals("baz", threadPool.getThreadContext().getResponseHeaders().get("foo.bar").get(0));
|
||||
assertNull(threadPool.getThreadContext().getTransient("boom"));
|
||||
} finally {
|
||||
latch.countDown();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String executor() {
|
||||
if (1 == 1)
|
||||
return "same";
|
||||
|
||||
return randomFrom(executors);
|
||||
}
|
||||
};
|
||||
|
||||
serviceB.sendRequest(nodeA, "action", new TestRequest(randomFrom("fail", "pass")), transportResponseHandler);
|
||||
serviceA.sendRequest(nodeA, "action", new TestRequest(randomFrom("fail", "pass")), transportResponseHandler);
|
||||
latch.await();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue