Streamline foreign stored context restore and allow to perserve response headers (#22677)

Today we do not preserve response headers if they are present on a transport protocol
response. While preserving these headers is not always desired, in the most cases we
should pass on these headers to have consistent results for depreciation headers etc.
yet, this hasn't been much of a problem since most of the deprecations are detected early
ie. on the coordinating node such that this bug wasn't uncovered until #22647

This commit allow to optionally preserve headers when a context is restored and also streamlines
the context restore since it leaked frequently into the callers thread context when the callers
context wasn't restored again.
This commit is contained in:
Simon Willnauer 2017-01-18 16:17:54 +01:00 committed by GitHub
parent 8a0a1140a9
commit 24e2847af2
9 changed files with 275 additions and 87 deletions

View File

@ -377,11 +377,9 @@ public class TransportBulkAction extends HandledTransportAction<BulkRequest, Bul
onFailure(failure);
return;
}
final ThreadContext.StoredContext context = threadPool.getThreadContext().newStoredContext();
observer.waitForNextChange(new ClusterStateObserver.Listener() {
@Override
public void onNewClusterState(ClusterState state) {
context.close();
run();
}
@ -392,7 +390,6 @@ public class TransportBulkAction extends HandledTransportAction<BulkRequest, Bul
@Override
public void onTimeout(TimeValue timeout) {
context.close();
// Try one more time...
run();
}

View File

@ -514,16 +514,15 @@ public abstract class TransportReplicationAction<
request),
e);
request.onRetry();
final ThreadContext.StoredContext context = threadPool.getThreadContext().newStoredContext();
observer.waitForNextChange(new ClusterStateObserver.Listener() {
@Override
public void onNewClusterState(ClusterState state) {
context.close();
// Forking a thread on local node via transport service so that custom transport service have an
// opportunity to execute custom logic before the replica operation begins
String extraMessage = "action [" + transportReplicaAction + "], request[" + request + "]";
TransportChannelResponseHandler<TransportResponse.Empty> handler =
new TransportChannelResponseHandler<>(logger, channel, extraMessage, () -> TransportResponse.Empty.INSTANCE);
new TransportChannelResponseHandler<>(logger, channel, extraMessage,
() -> TransportResponse.Empty.INSTANCE);
transportService.sendRequest(clusterService.localNode(), transportReplicaAction,
new ConcreteShardRequest<>(request, targetAllocationID),
handler);
@ -809,11 +808,9 @@ public abstract class TransportReplicationAction<
}
setPhase(task, "waiting_for_retry");
request.onRetry();
final ThreadContext.StoredContext context = threadPool.getThreadContext().newStoredContext();
observer.waitForNextChange(new ClusterStateObserver.Listener() {
@Override
public void onNewClusterState(ClusterState state) {
context.close();
run();
}
@ -824,7 +821,6 @@ public abstract class TransportReplicationAction<
@Override
public void onTimeout(TimeValue timeout) {
context.close();
// Try one more time...
run();
}

View File

@ -29,6 +29,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;
import java.util.function.Supplier;
/**
* A utility class which simplifies interacting with the cluster state in cases where
@ -118,7 +119,7 @@ public class ClusterStateObserver {
* @param timeOutValue a timeout for waiting. If null the global observer timeout will be used.
*/
public void waitForNextChange(Listener listener, Predicate<ClusterState> statePredicate, @Nullable TimeValue timeOutValue) {
listener = new ContextPreservingListener(listener, contextHolder.newRestorableContext(false));
if (observingContext.get() != null) {
throw new ElasticsearchException("already waiting for a cluster state change");
}
@ -157,8 +158,7 @@ public class ClusterStateObserver {
listener.onNewClusterState(newState);
} else {
logger.trace("observer: sampled state rejected by predicate ({}). adding listener to ClusterService", newState);
ObservingContext context =
new ObservingContext(new ContextPreservingListener(listener, contextHolder.newStoredContext()), statePredicate);
final ObservingContext context = new ObservingContext(listener, statePredicate);
if (!observingContext.compareAndSet(null, context)) {
throw new ElasticsearchException("already waiting for a cluster state change");
}
@ -279,30 +279,33 @@ public class ClusterStateObserver {
private static final class ContextPreservingListener implements Listener {
private final Listener delegate;
private final ThreadContext.StoredContext tempContext;
private final Supplier<ThreadContext.StoredContext> contextSupplier;
private ContextPreservingListener(Listener delegate, ThreadContext.StoredContext storedContext) {
this.tempContext = storedContext;
private ContextPreservingListener(Listener delegate, Supplier<ThreadContext.StoredContext> contextSupplier) {
this.contextSupplier = contextSupplier;
this.delegate = delegate;
}
@Override
public void onNewClusterState(ClusterState state) {
tempContext.restore();
delegate.onNewClusterState(state);
try (ThreadContext.StoredContext context = contextSupplier.get()) {
delegate.onNewClusterState(state);
}
}
@Override
public void onClusterServiceClose() {
tempContext.restore();
delegate.onClusterServiceClose();
try (ThreadContext.StoredContext context = contextSupplier.get()) {
delegate.onClusterServiceClose();
}
}
@Override
public void onTimeout(TimeValue timeout) {
tempContext.restore();
delegate.onTimeout(timeout);
try (ThreadContext.StoredContext context = contextSupplier.get()) {
delegate.onTimeout(timeout);
}
}
}
}

View File

@ -25,6 +25,7 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Setting.Property;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.store.Store;
import java.io.Closeable;
import java.io.IOException;
@ -34,6 +35,9 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* A ThreadContext is a map of string headers and a transient map of keyed objects that are associated with
@ -115,12 +119,57 @@ public final class ThreadContext implements Closeable, Writeable {
return () -> threadLocal.set(context);
}
/**
* Just like {@link #stashContext()} but no default context is set.
* @param preserveResponseHeaders if set to <code>true</code> the response headers of the restore thread will be preserved.
*/
public StoredContext newStoredContext() {
public StoredContext newStoredContext(boolean preserveResponseHeaders) {
final ThreadContextStruct context = threadLocal.get();
return () -> threadLocal.set(context);
return () -> {
if (preserveResponseHeaders && threadLocal.get() != context) {
threadLocal.set(context.putResponseHeaders(threadLocal.get().responseHeaders));
} else {
threadLocal.set(context);
}
};
}
/**
* Returns a supplier that gathers a {@link #newStoredContext(boolean)} and restores it once the
* returned supplier is invoked. The context returned from the supplier is a stored version of the
* suppliers callers context that should be restored once the originally gathered context is not needed anymore.
* For instance this method should be used like this:
*
* <pre>
* Supplier&lt;ThreadContext.StoredContext&gt; restorable = context.newRestorableContext(true);
* new Thread() {
* public void run() {
* try (ThreadContext.StoredContext ctx = restorable.get()) {
* // execute with the parents context and restore the threads context afterwards
* }
* }
*
* }.start();
* </pre>
*
* @param preserveResponseHeaders if set to <code>true</code> the response headers of the restore thread will be preserved.
* @return a restorable context supplier
*/
public Supplier<StoredContext> newRestorableContext(boolean preserveResponseHeaders) {
return wrapRestorable(newStoredContext(preserveResponseHeaders));
}
/**
* Same as {@link #newRestorableContext(boolean)} but wraps an existing context to restore.
* @param storedContext the context to restore
*/
public Supplier<StoredContext> wrapRestorable(StoredContext storedContext) {
return () -> {
StoredContext context = newStoredContext(false);
storedContext.restore();
return context;
};
}
@Override
@ -327,6 +376,26 @@ public final class ThreadContext implements Closeable, Writeable {
}
}
private ThreadContextStruct putResponseHeaders(Map<String, List<String>> headers) {
assert headers != null;
if (headers.isEmpty()) {
return this;
}
final Map<String, List<String>> newResponseHeaders = new HashMap<>(this.responseHeaders);
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
String key = entry.getKey();
final List<String> existingValues = newResponseHeaders.get(key);
if (existingValues != null) {
List<String> newValues = Stream.concat(entry.getValue().stream(),
existingValues.stream()).distinct().collect(Collectors.toList());
newResponseHeaders.put(key, Collections.unmodifiableList(newValues));
} else {
newResponseHeaders.put(key, entry.getValue());
}
}
return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders);
}
private ThreadContextStruct putResponse(String key, String value) {
assert value != null;
@ -445,7 +514,7 @@ public final class ThreadContext implements Closeable, Writeable {
private final ThreadContext.StoredContext ctx;
private ContextPreservingRunnable(Runnable in) {
ctx = newStoredContext();
ctx = newStoredContext(false);
this.in = in;
}
@ -487,7 +556,7 @@ public final class ThreadContext implements Closeable, Writeable {
private ThreadContext.StoredContext threadsOriginalContext = null;
private ContextPreservingAbstractRunnable(AbstractRunnable in) {
creatorsContext = newStoredContext();
creatorsContext = newStoredContext(false);
this.in = in;
}

View File

@ -543,8 +543,8 @@ public class TransportService extends AbstractLifecycleComponent {
} else {
timeoutHandler = new TimeoutHandler(requestId);
}
TransportResponseHandler<T> responseHandler =
new ContextRestoreResponseHandler<>(threadPool.getThreadContext().newStoredContext(), handler);
Supplier<ThreadContext.StoredContext> storedContextSupplier = threadPool.getThreadContext().newRestorableContext(true);
TransportResponseHandler<T> responseHandler = new ContextRestoreResponseHandler<>(storedContextSupplier, handler);
clientHandlers.put(requestId, new RequestHolder<>(responseHandler, connection.getNode(), action, timeoutHandler));
if (lifecycle.stoppedOrClosed()) {
// if we are not started the exception handling will remove the RequestHolder again and calls the handler to notify
@ -1000,14 +1000,14 @@ public class TransportService extends AbstractLifecycleComponent {
* This handler wrapper ensures that the response thread executes with the correct thread context. Before any of the4 handle methods
* are invoked we restore the context.
*/
private static final class ContextRestoreResponseHandler<T extends TransportResponse> implements TransportResponseHandler<T> {
public static final class ContextRestoreResponseHandler<T extends TransportResponse> implements TransportResponseHandler<T> {
private final TransportResponseHandler<T> delegate;
private final ThreadContext.StoredContext threadContext;
private final Supplier<ThreadContext.StoredContext> contextSupplier;
private ContextRestoreResponseHandler(ThreadContext.StoredContext threadContext, TransportResponseHandler<T> delegate) {
public ContextRestoreResponseHandler(Supplier<ThreadContext.StoredContext> contextSupplier, TransportResponseHandler<T> delegate) {
this.delegate = delegate;
this.threadContext = threadContext;
this.contextSupplier = contextSupplier;
}
@Override
@ -1017,14 +1017,16 @@ public class TransportService extends AbstractLifecycleComponent {
@Override
public void handleResponse(T response) {
threadContext.restore();
delegate.handleResponse(response);
try (ThreadContext.StoredContext ignore = contextSupplier.get()) {
delegate.handleResponse(response);
}
}
@Override
public void handleException(TransportException exp) {
threadContext.restore();
delegate.handleException(exp);
try (ThreadContext.StoredContext ignore = contextSupplier.get()) {
delegate.handleException(exp);
}
}
@Override
@ -1081,13 +1083,7 @@ public class TransportService extends AbstractLifecycleComponent {
if (ThreadPool.Names.SAME.equals(executor)) {
processResponse(handler, response);
} else {
threadPool.executor(executor).execute(new Runnable() {
@SuppressWarnings({"unchecked"})
@Override
public void run() {
processResponse(handler, response);
}
});
threadPool.executor(executor).execute(() -> processResponse(handler, response));
}
}
}

View File

@ -27,6 +27,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
@ -86,7 +87,7 @@ public class ThreadContextTests extends ESTestCase {
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
assertEquals("1", threadContext.getHeader("default"));
ThreadContext.StoredContext storedContext = threadContext.newStoredContext();
ThreadContext.StoredContext storedContext = threadContext.newStoredContext(false);
threadContext.putHeader("foo.bar", "baz");
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
assertNull(threadContext.getHeader("foo"));
@ -109,6 +110,63 @@ public class ThreadContextTests extends ESTestCase {
assertNull(threadContext.getHeader("foo.bar"));
}
public void testRestorableContext() {
Settings build = Settings.builder().put("request.headers.default", "1").build();
ThreadContext threadContext = new ThreadContext(build);
threadContext.putHeader("foo", "bar");
threadContext.putTransient("ctx.foo", 1);
threadContext.addResponseHeader("resp.header", "baaaam");
Supplier<ThreadContext.StoredContext> contextSupplier = threadContext.newRestorableContext(true);
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
assertNull(threadContext.getHeader("foo"));
assertEquals("1", threadContext.getHeader("default"));
threadContext.addResponseHeader("resp.header", "boom");
try (ThreadContext.StoredContext tmp = contextSupplier.get()) {
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
assertEquals("1", threadContext.getHeader("default"));
assertEquals(2, threadContext.getResponseHeaders().get("resp.header").size());
assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0));
assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(1));
}
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("ctx.foo"));
assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0));
}
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
assertEquals("1", threadContext.getHeader("default"));
assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0));
contextSupplier = threadContext.newRestorableContext(false);
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
assertNull(threadContext.getHeader("foo"));
assertEquals("1", threadContext.getHeader("default"));
threadContext.addResponseHeader("resp.header", "boom");
try (ThreadContext.StoredContext tmp = contextSupplier.get()) {
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
assertEquals("1", threadContext.getHeader("default"));
assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0));
}
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("ctx.foo"));
assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0));
}
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
assertEquals("1", threadContext.getHeader("default"));
assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0));
}
public void testResponseHeaders() {
final boolean expectThird = randomBoolean();

View File

@ -142,7 +142,7 @@ public class RemoteScrollableHitSource extends ScrollableHitSource {
private <T> void execute(String method, String uri, Map<String, String> params, HttpEntity entity,
BiFunction<XContentParser, Void, T> parser, Consumer<? super T> listener) {
// Preserve the thread context so headers survive after the call
ThreadContext.StoredContext ctx = threadPool.getThreadContext().newStoredContext();
java.util.function.Supplier<ThreadContext.StoredContext> contextSupplier = threadPool.getThreadContext().newRestorableContext(true);
class RetryHelper extends AbstractRunnable {
private final Iterator<TimeValue> retries = backoffPolicy.iterator();
@ -152,63 +152,68 @@ public class RemoteScrollableHitSource extends ScrollableHitSource {
@Override
public void onSuccess(org.elasticsearch.client.Response response) {
// Restore the thread context to get the precious headers
ctx.restore();
T parsedResponse;
try {
HttpEntity responseEntity = response.getEntity();
InputStream content = responseEntity.getContent();
XContentType xContentType = null;
if (responseEntity.getContentType() != null) {
xContentType = XContentType.fromMediaTypeOrFormat(responseEntity.getContentType().getValue());
}
if (xContentType == null) {
try {
throw new ElasticsearchException(
"Response didn't include Content-Type: " + bodyMessage(response.getEntity()));
} catch (IOException e) {
ElasticsearchException ee = new ElasticsearchException("Error extracting body from response");
ee.addSuppressed(e);
throw ee;
try (ThreadContext.StoredContext ctx = contextSupplier.get()) {
assert ctx != null; // eliminates compiler warning
T parsedResponse;
try {
HttpEntity responseEntity = response.getEntity();
InputStream content = responseEntity.getContent();
XContentType xContentType = null;
if (responseEntity.getContentType() != null) {
xContentType = XContentType.fromMediaTypeOrFormat(responseEntity.getContentType().getValue());
}
}
// EMPTY is safe here because we don't call namedObject
try (XContentParser xContentParser = xContentType.xContent().createParser(NamedXContentRegistry.EMPTY,
if (xContentType == null) {
try {
throw new ElasticsearchException(
"Response didn't include Content-Type: " + bodyMessage(response.getEntity()));
} catch (IOException e) {
ElasticsearchException ee = new ElasticsearchException("Error extracting body from response");
ee.addSuppressed(e);
throw ee;
}
}
// EMPTY is safe here because we don't call namedObject
try (XContentParser xContentParser = xContentType.xContent().createParser(NamedXContentRegistry.EMPTY,
content)) {
parsedResponse = parser.apply(xContentParser, null);
} catch (ParsingException e) {
parsedResponse = parser.apply(xContentParser, null);
} catch (ParsingException e) {
/* Because we're streaming the response we can't get a copy of it here. The best we can do is hint that it
* is totally wrong and we're probably not talking to Elasticsearch. */
throw new ElasticsearchException(
throw new ElasticsearchException(
"Error parsing the response, remote is likely not an Elasticsearch instance", e);
}
} catch (IOException e) {
throw new ElasticsearchException(
"Error deserializing response, remote is likely not an Elasticsearch instance", e);
}
} catch (IOException e) {
throw new ElasticsearchException("Error deserializing response, remote is likely not an Elasticsearch instance",
e);
listener.accept(parsedResponse);
}
listener.accept(parsedResponse);
}
@Override
public void onFailure(Exception e) {
if (e instanceof ResponseException) {
ResponseException re = (ResponseException) e;
if (RestStatus.TOO_MANY_REQUESTS.getStatus() == re.getResponse().getStatusLine().getStatusCode()) {
if (retries.hasNext()) {
TimeValue delay = retries.next();
logger.trace(
(Supplier<?>) () -> new ParameterizedMessage("retrying rejected search after [{}]", delay), e);
countSearchRetry.run();
threadPool.schedule(delay, ThreadPool.Names.SAME, RetryHelper.this);
return;
try (ThreadContext.StoredContext ctx = contextSupplier.get()) {
assert ctx != null; // eliminates compiler warning
if (e instanceof ResponseException) {
ResponseException re = (ResponseException) e;
if (RestStatus.TOO_MANY_REQUESTS.getStatus() == re.getResponse().getStatusLine().getStatusCode()) {
if (retries.hasNext()) {
TimeValue delay = retries.next();
logger.trace(
(Supplier<?>) () -> new ParameterizedMessage("retrying rejected search after [{}]", delay), e);
countSearchRetry.run();
threadPool.schedule(delay, ThreadPool.Names.SAME, RetryHelper.this);
return;
}
}
}
e = wrapExceptionToPreserveStatus(re.getResponse().getStatusLine().getStatusCode(),
e = wrapExceptionToPreserveStatus(re.getResponse().getStatusLine().getStatusCode(),
re.getResponse().getEntity(), re);
} else if (e instanceof ContentTooLongException) {
e = new IllegalArgumentException(
} else if (e instanceof ContentTooLongException) {
e = new IllegalArgumentException(
"Remote responded with a chunk that was too large. Use a smaller batch size.", e);
}
fail.accept(e);
}
fail.accept(e);
}
});
}

View File

@ -150,7 +150,6 @@ public class AsyncBulkByScrollActionTests extends ESTestCase {
client.close();
}
client = new MyMockClient(new NoOpClient(threadPool));
client.threadPool().getThreadContext().newStoredContext();
client.threadPool().getThreadContext().putHeader(expectedHeaders);
}

View File

@ -21,6 +21,7 @@ package org.elasticsearch.transport;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.logging.log4j.util.Supplier;
import org.apache.lucene.util.CollectionUtil;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.IOUtils;
import org.elasticsearch.ExceptionsHelper;
@ -1932,4 +1933,68 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
t.join();
}
}
public void testResponseHeadersArePreserved() throws InterruptedException {
List<String> executors = new ArrayList<>(ThreadPool.THREAD_POOL_TYPES.keySet());
CollectionUtil.timSort(executors); // makes sure it's reproducible
serviceA.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME,
(request, channel) -> {
threadPool.getThreadContext().putTransient("boom", new Object());
threadPool.getThreadContext().addResponseHeader("foo.bar", "baz");
if ("fail".equals(request.info)) {
throw new RuntimeException("boom");
} else {
channel.sendResponse(TransportResponse.Empty.INSTANCE);
}
});
CountDownLatch latch = new CountDownLatch(2);
TransportResponseHandler<TransportResponse> transportResponseHandler = new TransportResponseHandler<TransportResponse>() {
@Override
public TransportResponse newInstance() {
return TransportResponse.Empty.INSTANCE;
}
@Override
public void handleResponse(TransportResponse response) {
try {
assertSame(response, TransportResponse.Empty.INSTANCE);
assertTrue(threadPool.getThreadContext().getResponseHeaders().containsKey("foo.bar"));
assertEquals(1, threadPool.getThreadContext().getResponseHeaders().get("foo.bar").size());
assertEquals("baz", threadPool.getThreadContext().getResponseHeaders().get("foo.bar").get(0));
assertNull(threadPool.getThreadContext().getTransient("boom"));
} finally {
latch.countDown();
}
}
@Override
public void handleException(TransportException exp) {
try {
assertTrue(threadPool.getThreadContext().getResponseHeaders().containsKey("foo.bar"));
assertEquals(1, threadPool.getThreadContext().getResponseHeaders().get("foo.bar").size());
assertEquals("baz", threadPool.getThreadContext().getResponseHeaders().get("foo.bar").get(0));
assertNull(threadPool.getThreadContext().getTransient("boom"));
} finally {
latch.countDown();
}
}
@Override
public String executor() {
if (1 == 1)
return "same";
return randomFrom(executors);
}
};
serviceB.sendRequest(nodeA, "action", new TestRequest(randomFrom("fail", "pass")), transportResponseHandler);
serviceA.sendRequest(nodeA, "action", new TestRequest(randomFrom("fail", "pass")), transportResponseHandler);
latch.await();
}
}