Use clean thread context for transport and applier service (#57792) (#57914)

Adds assertions to Netty to make sure that its threads are not polluted by thread contexts (and
also that thread contexts are not leaked). Moves the ClusterApplierService to use the system
context (same as we do for the MasterService), which allows us to remove a hack from
TemplateUpgradeService and makes it clearer that applying cluster state updates fully executes
under the system context.
Yannick Welsch, 2020-06-10 10:30:28 +02:00, committed by GitHub
parent fe85bdbe6f
commit 80f221e920
10 changed files with 129 additions and 51 deletions
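
The context handling in this change follows a single pattern throughout: stash the current thread context, mark the fresh context as the system context, do the internal work, and let try-with-resources restore the caller's context. A minimal sketch using only the ThreadContext API touched by this commit (the class name and header below are illustrative):

    import org.elasticsearch.common.settings.Settings;
    import org.elasticsearch.common.util.concurrent.ThreadContext;

    // Minimal sketch (illustrative class name) of the stash/mark/restore pattern this change relies on.
    public class SystemContextSketch {
        public static void main(String[] args) {
            ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
            threadContext.putHeader("authorization", "user-credentials"); // pretend caller context

            // Stash the caller's context, run internal work under the system context,
            // and restore the caller's context automatically when the try block exits.
            try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
                threadContext.markAsSystemContext();
                assert threadContext.isSystemContext();
                assert threadContext.getHeaders().containsKey("authorization") == false;
                // ... submit cluster state application, template upgrades, etc. ...
            }

            // Outside the try block the original caller context is visible again.
            assert threadContext.getHeaders().containsKey("authorization");
        }
    }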

File: Netty4MessageChannelHandler.java

@ -65,6 +65,7 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
assert Transports.assertDefaultThreadContext(transport.getThreadPool().getThreadContext());
assert Transports.assertTransportThread();
assert msg instanceof ByteBuf : "Expected message type ByteBuf, found: " + msg.getClass();
@ -78,6 +79,7 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
assert Transports.assertDefaultThreadContext(transport.getThreadPool().getThreadContext());
ExceptionsHelper.maybeDieOnAnotherThread(cause);
final Throwable unwrapped = ExceptionsHelper.unwrap(cause, ElasticsearchException.class);
final Throwable newCause = unwrapped != null ? unwrapped : cause;
@ -92,12 +94,15 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
assert msg instanceof OutboundHandler.SendContext;
assert Transports.assertDefaultThreadContext(transport.getThreadPool().getThreadContext());
final boolean queued = queuedWrites.offer(new WriteOperation((OutboundHandler.SendContext) msg, promise));
assert queued;
assert Transports.assertDefaultThreadContext(transport.getThreadPool().getThreadContext());
}
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws IOException {
assert Transports.assertDefaultThreadContext(transport.getThreadPool().getThreadContext());
if (ctx.channel().isWritable()) {
doFlush(ctx);
}
@ -106,6 +111,7 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler {
@Override
public void flush(ChannelHandlerContext ctx) throws IOException {
assert Transports.assertDefaultThreadContext(transport.getThreadPool().getThreadContext());
Channel channel = ctx.channel();
if (channel.isWritable() || channel.isActive() == false) {
doFlush(ctx);
@ -114,6 +120,7 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler {
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
assert Transports.assertDefaultThreadContext(transport.getThreadPool().getThreadContext());
doFlush(ctx);
Releasables.closeWhileHandlingException(pipeline);
super.channelInactive(ctx);
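
A note on the assertions added above: they call boolean-returning helpers from assert statements, so the checks only run when assertions are enabled (tests, CI) and compile away on production hot paths such as the Netty event loop. A small sketch of that idiom, with illustrative names:

    import java.util.Map;

    // Sketch of the "assert methodThatReturnsTrue()" idiom used by the handler above:
    // the check runs only when the JVM is started with -ea and vanishes otherwise.
    final class InvariantSketch {

        // The helper throws via assert with a descriptive message and returns true so that
        // callers can wrap the call itself in an assert statement.
        static boolean assertCleanHeaders(Map<String, String> headers) {
            assert headers.isEmpty() : "expected no request headers but found " + headers
                + " on " + Thread.currentThread().getName();
            return true;
        }

        static void onNetworkEvent(Map<String, String> headers) {
            assert assertCleanHeaders(headers); // no-op without -ea; verifies the invariant with it
            // ... handle the event ...
        }
    }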

File: TemplateUpgradeService.java

@ -37,7 +37,6 @@ import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.collect.ImmutableOpenMap;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentType;
@ -129,11 +128,8 @@ public class TemplateUpgradeService implements ClusterStateListener {
changes.get().v1().size(),
changes.get().v2().size());
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
threadPool.generic().execute(() -> upgradeTemplates(changes.get().v1(), changes.get().v2()));
}
assert threadPool.getThreadContext().isSystemContext();
threadPool.generic().execute(() -> upgradeTemplates(changes.get().v1(), changes.get().v2()));
}
}
}
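
Because the ClusterApplierService change below runs appliers and listeners under the system context, TemplateUpgradeService can drop its own stash-and-mark workaround and merely assert the invariant. A hypothetical listener written against the same guarantee might look like this (the class name is illustrative):

    import org.elasticsearch.cluster.ClusterChangedEvent;
    import org.elasticsearch.cluster.ClusterStateListener;
    import org.elasticsearch.threadpool.ThreadPool;

    // Sketch: a cluster state listener that relies on being invoked under the system context.
    class SystemContextAwareListener implements ClusterStateListener {
        private final ThreadPool threadPool;

        SystemContextAwareListener(ThreadPool threadPool) {
            this.threadPool = threadPool;
        }

        @Override
        public void clusterChanged(ClusterChangedEvent event) {
            // No stashContext()/markAsSystemContext() boilerplate needed any more; the applier
            // service guarantees the system context, so forked follow-up work inherits it as well.
            assert threadPool.getThreadContext().isSystemContext();
            threadPool.generic().execute(() -> { /* internal follow-up work */ });
        }
    }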

File: ClusterApplierService.java

@ -46,6 +46,7 @@ import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.iterable.Iterables;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool;
@ -347,8 +348,12 @@ public class ClusterApplierService extends AbstractLifecycleComponent implements
if (!lifecycle.started()) {
return;
}
try {
UpdateTask updateTask = new UpdateTask(config.priority(), source, new SafeClusterApplyListener(listener, logger), executor);
final ThreadContext threadContext = threadPool.getThreadContext();
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(true);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
final UpdateTask updateTask = new UpdateTask(config.priority(), source,
new SafeClusterApplyListener(listener, supplier, logger), executor);
if (config.timeout() != null) {
threadPoolExecutor.execute(updateTask, config.timeout(),
() -> threadPool.generic().execute(
@ -534,16 +539,18 @@ public class ClusterApplierService extends AbstractLifecycleComponent implements
private static class SafeClusterApplyListener implements ClusterApplyListener {
private final ClusterApplyListener listener;
protected final Supplier<ThreadContext.StoredContext> context;
private final Logger logger;
SafeClusterApplyListener(ClusterApplyListener listener, Logger logger) {
SafeClusterApplyListener(ClusterApplyListener listener, Supplier<ThreadContext.StoredContext> context, Logger logger) {
this.listener = listener;
this.context = context;
this.logger = logger;
}
@Override
public void onFailure(String source, Exception e) {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.onFailure(source, e);
} catch (Exception inner) {
inner.addSuppressed(e);
@ -554,7 +561,7 @@ public class ClusterApplierService extends AbstractLifecycleComponent implements
@Override
public void onSuccess(String source) {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.onSuccess(source);
} catch (Exception e) {
logger.error(new ParameterizedMessage(
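
The mechanism above is ThreadContext#newRestorableContext: the applier captures the submitting thread's context before switching to the system context, and SafeClusterApplyListener restores that captured context around onSuccess/onFailure. A minimal sketch of the capture-and-restore flow (header values and the class name are illustrative):

    import java.util.function.Supplier;

    import org.elasticsearch.common.settings.Settings;
    import org.elasticsearch.common.util.concurrent.ThreadContext;

    // Sketch: run work under the system context but notify the caller within its own context.
    public class RestorableContextSketch {
        public static void main(String[] args) {
            ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
            threadContext.putHeader("test", "test"); // pretend caller context

            // Capture the caller's context; "true" also preserves response headers added later on.
            Supplier<ThreadContext.StoredContext> restorable = threadContext.newRestorableContext(true);

            try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
                threadContext.markAsSystemContext();
                // ... apply the cluster state under the system context ...

                // Restore the caller's context around the listener callback, then drop it again.
                try (ThreadContext.StoredContext restored = restorable.get()) {
                    assert "test".equals(threadContext.getHeaders().get("test"));
                    // listener.onSuccess(source);
                }
            }
        }
    }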

File: ThreadContext.java

@ -138,6 +138,14 @@ public final class ThreadContext implements Writeable {
};
}
/**
* Captures the current thread context as writeable, allowing it to be serialized out later
*/
public Writeable captureAsWriteable() {
final ThreadContextStruct context = threadLocal.get();
return out -> context.writeTo(out, defaultHeader);
}
/**
* Removes the current context and resets a default context marked with as
* originating from the supplied string. The removed context can be
@ -294,6 +302,13 @@ public final class ThreadContext implements Writeable {
return Collections.unmodifiableMap(map);
}
/**
* Returns the request headers, without the default headers
*/
public Map<String, String> getRequestHeadersOnly() {
return Collections.unmodifiableMap(new HashMap<>(threadLocal.get().requestHeaders));
}
/**
* Get a copy of all <em>response</em> headers.
*
@ -493,7 +508,7 @@ public final class ThreadContext implements Writeable {
return new ThreadContextStruct(newRequestHeaders, responseHeaders, transientHeaders, isSystemContext);
}
private void putSingleHeader(String key, String value, Map<String, String> newHeaders) {
private static void putSingleHeader(String key, String value, Map<String, String> newHeaders) {
if (newHeaders.putIfAbsent(key, value) != null) {
throw new IllegalArgumentException("value for key [" + key + "] already present");
}
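
The two new accessors support the transport changes further below: captureAsWriteable lets an outbound message snapshot the sending thread's context for later serialization (see NetworkMessage), and getRequestHeadersOnly backs the default-context assertion added to Transports. A small sketch of capturing and later serializing the context, assuming an Elasticsearch StreamOutput such as BytesStreamOutput (class name and header are illustrative):

    import java.io.IOException;

    import org.elasticsearch.common.io.stream.BytesStreamOutput;
    import org.elasticsearch.common.io.stream.Writeable;
    import org.elasticsearch.common.settings.Settings;
    import org.elasticsearch.common.util.concurrent.ThreadContext;

    // Sketch: snapshot the request headers now, serialize them later on a different thread.
    public class CaptureContextSketch {
        public static void main(String[] args) throws IOException {
            ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
            threadContext.putHeader("test", "test");

            // Captured on the calling thread; nothing needs to be restored on the writing thread.
            Writeable captured = threadContext.captureAsWriteable();

            try (BytesStreamOutput out = new BytesStreamOutput()) {
                captured.writeTo(out); // may run later, e.g. on a network thread
                // out.bytes() now holds the serialized context headers for the wire message
            }
        }
    }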

File: NetworkMessage.java

@ -19,6 +19,7 @@
package org.elasticsearch.transport;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.ThreadContext;
/**
@ -28,15 +29,12 @@ import org.elasticsearch.common.util.concurrent.ThreadContext;
public abstract class NetworkMessage {
protected final Version version;
protected final ThreadContext threadContext;
protected final ThreadContext.StoredContext storedContext;
protected final Writeable threadContext;
protected final long requestId;
protected final byte status;
NetworkMessage(ThreadContext threadContext, Version version, byte status, long requestId) {
this.threadContext = threadContext;
storedContext = threadContext.stashContext();
storedContext.restore();
this.threadContext = threadContext.captureAsWriteable();
this.version = version;
this.requestId = requestId;
this.status = status;
@ -54,10 +52,6 @@ public abstract class NetworkMessage {
return TransportStatus.isCompress(status);
}
ThreadContext.StoredContext getStoredContext() {
return storedContext;
}
boolean isResponse() {
return TransportStatus.isRequest(status) == false;
}

File: OutboundHandler.java

@ -35,6 +35,7 @@ import org.elasticsearch.common.network.CloseableChannel;
import org.elasticsearch.common.transport.NetworkExceptionHelper;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.threadpool.ThreadPool;
@ -121,7 +122,8 @@ public final class OutboundHandler {
private void internalSend(TcpChannel channel, SendContext sendContext) {
channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
try {
// stash thread context so that channel event loop is not polluted by thread context
try (ThreadContext.StoredContext existing = threadPool.getThreadContext().stashContext()) {
channel.sendMessage(sendContext);
} catch (RuntimeException ex) {
sendContext.onFailure(ex);

File: OutboundMessage.java

@ -40,37 +40,34 @@ abstract class OutboundMessage extends NetworkMessage {
}
BytesReference serialize(BytesStreamOutput bytesStream) throws IOException {
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
storedContext.restore();
bytesStream.setVersion(version);
bytesStream.skip(TcpHeader.headerSize(version));
bytesStream.setVersion(version);
bytesStream.skip(TcpHeader.headerSize(version));
// The compressible bytes stream will not close the underlying bytes stream
BytesReference reference;
int variableHeaderLength = -1;
final long preHeaderPosition = bytesStream.position();
// The compressible bytes stream will not close the underlying bytes stream
BytesReference reference;
int variableHeaderLength = -1;
final long preHeaderPosition = bytesStream.position();
if (version.onOrAfter(TcpHeader.VERSION_WITH_HEADER_SIZE)) {
writeVariableHeader(bytesStream);
variableHeaderLength = Math.toIntExact(bytesStream.position() - preHeaderPosition);
}
try (CompressibleBytesOutputStream stream =
new CompressibleBytesOutputStream(bytesStream, TransportStatus.isCompress(status))) {
stream.setVersion(version);
stream.setFeatures(bytesStream.getFeatures());
if (variableHeaderLength == -1) {
writeVariableHeader(stream);
}
reference = writeMessage(stream);
}
bytesStream.seek(0);
final int contentSize = reference.length() - TcpHeader.headerSize(version);
TcpHeader.writeHeader(bytesStream, requestId, status, version, contentSize, variableHeaderLength);
return reference;
if (version.onOrAfter(TcpHeader.VERSION_WITH_HEADER_SIZE)) {
writeVariableHeader(bytesStream);
variableHeaderLength = Math.toIntExact(bytesStream.position() - preHeaderPosition);
}
try (CompressibleBytesOutputStream stream =
new CompressibleBytesOutputStream(bytesStream, TransportStatus.isCompress(status))) {
stream.setVersion(version);
stream.setFeatures(bytesStream.getFeatures());
if (variableHeaderLength == -1) {
writeVariableHeader(stream);
}
reference = writeMessage(stream);
}
bytesStream.seek(0);
final int contentSize = reference.length() - TcpHeader.headerSize(version);
TcpHeader.writeHeader(bytesStream, requestId, status, version, contentSize, variableHeaderLength);
return reference;
}
protected void writeVariableHeader(StreamOutput stream) throws IOException {

File: Transports.java

@ -19,7 +19,9 @@
package org.elasticsearch.transport;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.tasks.Task;
import java.util.Arrays;
@ -58,4 +60,11 @@ public enum Transports {
assert isTransportThread(t) == false : "Expected current thread [" + t + "] to not be a transport thread. Reason: [" + reason + "]";
return true;
}
public static boolean assertDefaultThreadContext(ThreadContext threadContext) {
assert threadContext.getRequestHeadersOnly().isEmpty() ||
threadContext.getRequestHeadersOnly().size() == 1 && threadContext.getRequestHeadersOnly().containsKey(Task.X_OPAQUE_ID) :
"expected empty context but was " + threadContext.getRequestHeadersOnly() + " on " + Thread.currentThread().getName();
return true;
}
}
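
The check tolerates a lone X-Opaque-Id request header, which appears to be the one client-supplied identifier the assertion deliberately permits rather than treating it as context pollution. A sketch of how a transport component might invoke the assertion (the surrounding class is illustrative):

    import org.elasticsearch.common.util.concurrent.ThreadContext;
    import org.elasticsearch.threadpool.ThreadPool;
    import org.elasticsearch.transport.Transports;

    // Sketch: guard a transport entry point with the new default-context assertion.
    class TransportEntryPointSketch {
        private final ThreadPool threadPool;

        TransportEntryPointSketch(ThreadPool threadPool) {
            this.threadPool = threadPool;
        }

        void onChannelEvent() {
            ThreadContext threadContext = threadPool.getThreadContext();
            // Fails (with -ea) if any request header other than X-Opaque-Id leaked onto this thread.
            assert Transports.assertDefaultThreadContext(threadContext);
            // ... proceed with transport work ...
        }
    }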

File: ClusterApplierServiceTests.java

@ -38,6 +38,7 @@ import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.MockLogAppender;
import org.elasticsearch.test.junit.annotations.TestLogging;
@ -48,6 +49,9 @@ import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
@ -494,6 +498,48 @@ public class ClusterApplierServiceTests extends ESTestCase {
assertTrue(applierCalled.get());
}
public void testThreadContext() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(1);
try (ThreadContext.StoredContext ignored = threadPool.getThreadContext().stashContext()) {
final Map<String, String> expectedHeaders = Collections.singletonMap("test", "test");
final Map<String, List<String>> expectedResponseHeaders = Collections.singletonMap("testResponse",
Collections.singletonList("testResponse"));
threadPool.getThreadContext().putHeader(expectedHeaders);
clusterApplierService.onNewClusterState("test", () -> {
assertTrue(threadPool.getThreadContext().isSystemContext());
assertEquals(Collections.emptyMap(), threadPool.getThreadContext().getHeaders());
threadPool.getThreadContext().addResponseHeader("testResponse", "testResponse");
assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders());
if (randomBoolean()) {
return ClusterState.builder(clusterApplierService.state()).build();
} else {
throw new IllegalArgumentException("mock failure");
}
}, new ClusterApplyListener() {
@Override
public void onSuccess(String source) {
assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders());
latch.countDown();
}
@Override
public void onFailure(String source, Exception e) {
assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders());
latch.countDown();
}
});
}
latch.await();
}
static class TimedClusterApplierService extends ClusterApplierService {
final ClusterSettings clusterSettings;

File: TemplateUpgraderTests.java

@ -12,6 +12,7 @@ import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.metadata.IndexTemplateMetadata;
import org.elasticsearch.cluster.metadata.TemplateUpgradeService;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESIntegTestCase.ClusterScope;
import org.elasticsearch.test.ESIntegTestCase.Scope;
import org.elasticsearch.test.SecurityIntegTestCase;
@ -59,7 +60,11 @@ public class TemplateUpgraderTests extends SecurityIntegTestCase {
// ensure the cluster listener gets triggered
ClusterChangedEvent event = new ClusterChangedEvent("testing", clusterService.state(), clusterService.state());
templateUpgradeService.clusterChanged(event);
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
templateUpgradeService.clusterChanged(event);
}
assertBusy(() -> assertTemplates("added-template", "removed-template"));
}