diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java index 99f199690e4..fb72bd82eed 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java @@ -65,6 +65,7 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + assert Transports.assertDefaultThreadContext(transport.getThreadPool().getThreadContext()); assert Transports.assertTransportThread(); assert msg instanceof ByteBuf : "Expected message type ByteBuf, found: " + msg.getClass(); @@ -78,6 +79,7 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler { @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + assert Transports.assertDefaultThreadContext(transport.getThreadPool().getThreadContext()); ExceptionsHelper.maybeDieOnAnotherThread(cause); final Throwable unwrapped = ExceptionsHelper.unwrap(cause, ElasticsearchException.class); final Throwable newCause = unwrapped != null ? unwrapped : cause; @@ -92,12 +94,15 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler { @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { assert msg instanceof OutboundHandler.SendContext; + assert Transports.assertDefaultThreadContext(transport.getThreadPool().getThreadContext()); final boolean queued = queuedWrites.offer(new WriteOperation((OutboundHandler.SendContext) msg, promise)); assert queued; + assert Transports.assertDefaultThreadContext(transport.getThreadPool().getThreadContext()); } @Override public void channelWritabilityChanged(ChannelHandlerContext ctx) throws IOException { + assert Transports.assertDefaultThreadContext(transport.getThreadPool().getThreadContext()); if (ctx.channel().isWritable()) { doFlush(ctx); } @@ -106,6 +111,7 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler { @Override public void flush(ChannelHandlerContext ctx) throws IOException { + assert Transports.assertDefaultThreadContext(transport.getThreadPool().getThreadContext()); Channel channel = ctx.channel(); if (channel.isWritable() || channel.isActive() == false) { doFlush(ctx); @@ -114,6 +120,7 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler { @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { + assert Transports.assertDefaultThreadContext(transport.getThreadPool().getThreadContext()); doFlush(ctx); Releasables.closeWhileHandlingException(pipeline); super.channelInactive(ctx); diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/TemplateUpgradeService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/TemplateUpgradeService.java index ec5e20f0154..9cc938d9ec5 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/TemplateUpgradeService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/TemplateUpgradeService.java @@ -37,7 +37,6 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.unit.TimeValue; -import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentType; @@ -129,11 +128,8 @@ public class TemplateUpgradeService implements ClusterStateListener { changes.get().v1().size(), changes.get().v2().size()); - final ThreadContext threadContext = threadPool.getThreadContext(); - try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - threadContext.markAsSystemContext(); - threadPool.generic().execute(() -> upgradeTemplates(changes.get().v1(), changes.get().v2())); - } + assert threadPool.getThreadContext().isSystemContext(); + threadPool.generic().execute(() -> upgradeTemplates(changes.get().v1(), changes.get().v2())); } } } diff --git a/server/src/main/java/org/elasticsearch/cluster/service/ClusterApplierService.java b/server/src/main/java/org/elasticsearch/cluster/service/ClusterApplierService.java index 8f3d7701c4a..4d94cc31d3d 100644 --- a/server/src/main/java/org/elasticsearch/cluster/service/ClusterApplierService.java +++ b/server/src/main/java/org/elasticsearch/cluster/service/ClusterApplierService.java @@ -46,6 +46,7 @@ import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.iterable.Iterables; import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.threadpool.ThreadPool; @@ -347,8 +348,12 @@ public class ClusterApplierService extends AbstractLifecycleComponent implements if (!lifecycle.started()) { return; } - try { - UpdateTask updateTask = new UpdateTask(config.priority(), source, new SafeClusterApplyListener(listener, logger), executor); + final ThreadContext threadContext = threadPool.getThreadContext(); + final Supplier supplier = threadContext.newRestorableContext(true); + try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + threadContext.markAsSystemContext(); + final UpdateTask updateTask = new UpdateTask(config.priority(), source, + new SafeClusterApplyListener(listener, supplier, logger), executor); if (config.timeout() != null) { threadPoolExecutor.execute(updateTask, config.timeout(), () -> threadPool.generic().execute( @@ -534,16 +539,18 @@ public class ClusterApplierService extends AbstractLifecycleComponent implements private static class SafeClusterApplyListener implements ClusterApplyListener { private final ClusterApplyListener listener; + protected final Supplier context; private final Logger logger; - SafeClusterApplyListener(ClusterApplyListener listener, Logger logger) { + SafeClusterApplyListener(ClusterApplyListener listener, Supplier context, Logger logger) { this.listener = listener; + this.context = context; this.logger = logger; } @Override public void onFailure(String source, Exception e) { - try { + try (ThreadContext.StoredContext ignore = context.get()) { listener.onFailure(source, e); } catch (Exception inner) { inner.addSuppressed(e); @@ -554,7 +561,7 @@ public class ClusterApplierService extends AbstractLifecycleComponent implements @Override public void onSuccess(String source) { - try { + try (ThreadContext.StoredContext ignore = context.get()) { listener.onSuccess(source); } catch (Exception e) { logger.error(new ParameterizedMessage( diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java index b2d196418af..52e469bf26a 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java @@ -138,6 +138,14 @@ public final class ThreadContext implements Writeable { }; } + /** + * Captures the current thread context as writeable, allowing it to be serialized out later + */ + public Writeable captureAsWriteable() { + final ThreadContextStruct context = threadLocal.get(); + return out -> context.writeTo(out, defaultHeader); + } + /** * Removes the current context and resets a default context marked with as * originating from the supplied string. The removed context can be @@ -294,6 +302,13 @@ public final class ThreadContext implements Writeable { return Collections.unmodifiableMap(map); } + /** + * Returns the request headers, without the default headers + */ + public Map getRequestHeadersOnly() { + return Collections.unmodifiableMap(new HashMap<>(threadLocal.get().requestHeaders)); + } + /** * Get a copy of all response headers. * @@ -493,7 +508,7 @@ public final class ThreadContext implements Writeable { return new ThreadContextStruct(newRequestHeaders, responseHeaders, transientHeaders, isSystemContext); } - private void putSingleHeader(String key, String value, Map newHeaders) { + private static void putSingleHeader(String key, String value, Map newHeaders) { if (newHeaders.putIfAbsent(key, value) != null) { throw new IllegalArgumentException("value for key [" + key + "] already present"); } diff --git a/server/src/main/java/org/elasticsearch/transport/NetworkMessage.java b/server/src/main/java/org/elasticsearch/transport/NetworkMessage.java index 7d8dbb8a0f1..c46466897e4 100644 --- a/server/src/main/java/org/elasticsearch/transport/NetworkMessage.java +++ b/server/src/main/java/org/elasticsearch/transport/NetworkMessage.java @@ -19,6 +19,7 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.concurrent.ThreadContext; /** @@ -28,15 +29,12 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; public abstract class NetworkMessage { protected final Version version; - protected final ThreadContext threadContext; - protected final ThreadContext.StoredContext storedContext; + protected final Writeable threadContext; protected final long requestId; protected final byte status; NetworkMessage(ThreadContext threadContext, Version version, byte status, long requestId) { - this.threadContext = threadContext; - storedContext = threadContext.stashContext(); - storedContext.restore(); + this.threadContext = threadContext.captureAsWriteable(); this.version = version; this.requestId = requestId; this.status = status; @@ -54,10 +52,6 @@ public abstract class NetworkMessage { return TransportStatus.isCompress(status); } - ThreadContext.StoredContext getStoredContext() { - return storedContext; - } - boolean isResponse() { return TransportStatus.isRequest(status) == false; } diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java index 5fd684c1b04..b1d82d9827a 100644 --- a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java +++ b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java @@ -35,6 +35,7 @@ import org.elasticsearch.common.network.CloseableChannel; import org.elasticsearch.common.transport.NetworkExceptionHelper; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.threadpool.ThreadPool; @@ -121,7 +122,8 @@ public final class OutboundHandler { private void internalSend(TcpChannel channel, SendContext sendContext) { channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); - try { + // stash thread context so that channel event loop is not polluted by thread context + try (ThreadContext.StoredContext existing = threadPool.getThreadContext().stashContext()) { channel.sendMessage(sendContext); } catch (RuntimeException ex) { sendContext.onFailure(ex); diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundMessage.java b/server/src/main/java/org/elasticsearch/transport/OutboundMessage.java index 9ad96a846f2..9cd004f2b51 100644 --- a/server/src/main/java/org/elasticsearch/transport/OutboundMessage.java +++ b/server/src/main/java/org/elasticsearch/transport/OutboundMessage.java @@ -40,37 +40,34 @@ abstract class OutboundMessage extends NetworkMessage { } BytesReference serialize(BytesStreamOutput bytesStream) throws IOException { - try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - storedContext.restore(); - bytesStream.setVersion(version); - bytesStream.skip(TcpHeader.headerSize(version)); + bytesStream.setVersion(version); + bytesStream.skip(TcpHeader.headerSize(version)); - // The compressible bytes stream will not close the underlying bytes stream - BytesReference reference; - int variableHeaderLength = -1; - final long preHeaderPosition = bytesStream.position(); + // The compressible bytes stream will not close the underlying bytes stream + BytesReference reference; + int variableHeaderLength = -1; + final long preHeaderPosition = bytesStream.position(); - if (version.onOrAfter(TcpHeader.VERSION_WITH_HEADER_SIZE)) { - writeVariableHeader(bytesStream); - variableHeaderLength = Math.toIntExact(bytesStream.position() - preHeaderPosition); - } - - try (CompressibleBytesOutputStream stream = - new CompressibleBytesOutputStream(bytesStream, TransportStatus.isCompress(status))) { - stream.setVersion(version); - stream.setFeatures(bytesStream.getFeatures()); - - if (variableHeaderLength == -1) { - writeVariableHeader(stream); - } - reference = writeMessage(stream); - } - - bytesStream.seek(0); - final int contentSize = reference.length() - TcpHeader.headerSize(version); - TcpHeader.writeHeader(bytesStream, requestId, status, version, contentSize, variableHeaderLength); - return reference; + if (version.onOrAfter(TcpHeader.VERSION_WITH_HEADER_SIZE)) { + writeVariableHeader(bytesStream); + variableHeaderLength = Math.toIntExact(bytesStream.position() - preHeaderPosition); } + + try (CompressibleBytesOutputStream stream = + new CompressibleBytesOutputStream(bytesStream, TransportStatus.isCompress(status))) { + stream.setVersion(version); + stream.setFeatures(bytesStream.getFeatures()); + + if (variableHeaderLength == -1) { + writeVariableHeader(stream); + } + reference = writeMessage(stream); + } + + bytesStream.seek(0); + final int contentSize = reference.length() - TcpHeader.headerSize(version); + TcpHeader.writeHeader(bytesStream, requestId, status, version, contentSize, variableHeaderLength); + return reference; } protected void writeVariableHeader(StreamOutput stream) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/transport/Transports.java b/server/src/main/java/org/elasticsearch/transport/Transports.java index d6968a85536..2106c194aab 100644 --- a/server/src/main/java/org/elasticsearch/transport/Transports.java +++ b/server/src/main/java/org/elasticsearch/transport/Transports.java @@ -19,7 +19,9 @@ package org.elasticsearch.transport; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.http.HttpServerTransport; +import org.elasticsearch.tasks.Task; import java.util.Arrays; @@ -58,4 +60,11 @@ public enum Transports { assert isTransportThread(t) == false : "Expected current thread [" + t + "] to not be a transport thread. Reason: [" + reason + "]"; return true; } + + public static boolean assertDefaultThreadContext(ThreadContext threadContext) { + assert threadContext.getRequestHeadersOnly().isEmpty() || + threadContext.getRequestHeadersOnly().size() == 1 && threadContext.getRequestHeadersOnly().containsKey(Task.X_OPAQUE_ID) : + "expected empty context but was " + threadContext.getRequestHeadersOnly() + " on " + Thread.currentThread().getName(); + return true; + } } diff --git a/server/src/test/java/org/elasticsearch/cluster/service/ClusterApplierServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/service/ClusterApplierServiceTests.java index 9838d9d05c1..5cb74615a6f 100644 --- a/server/src/test/java/org/elasticsearch/cluster/service/ClusterApplierServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/service/ClusterApplierServiceTests.java @@ -38,6 +38,7 @@ import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.MockLogAppender; import org.elasticsearch.test.junit.annotations.TestLogging; @@ -48,6 +49,9 @@ import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; +import java.util.Collections; +import java.util.List; +import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -494,6 +498,48 @@ public class ClusterApplierServiceTests extends ESTestCase { assertTrue(applierCalled.get()); } + public void testThreadContext() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + + try (ThreadContext.StoredContext ignored = threadPool.getThreadContext().stashContext()) { + final Map expectedHeaders = Collections.singletonMap("test", "test"); + final Map> expectedResponseHeaders = Collections.singletonMap("testResponse", + Collections.singletonList("testResponse")); + threadPool.getThreadContext().putHeader(expectedHeaders); + + clusterApplierService.onNewClusterState("test", () -> { + assertTrue(threadPool.getThreadContext().isSystemContext()); + assertEquals(Collections.emptyMap(), threadPool.getThreadContext().getHeaders()); + threadPool.getThreadContext().addResponseHeader("testResponse", "testResponse"); + assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders()); + if (randomBoolean()) { + return ClusterState.builder(clusterApplierService.state()).build(); + } else { + throw new IllegalArgumentException("mock failure"); + } + }, new ClusterApplyListener() { + + @Override + public void onSuccess(String source) { + assertFalse(threadPool.getThreadContext().isSystemContext()); + assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders()); + assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders()); + latch.countDown(); + } + + @Override + public void onFailure(String source, Exception e) { + assertFalse(threadPool.getThreadContext().isSystemContext()); + assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders()); + assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders()); + latch.countDown(); + } + }); + } + + latch.await(); + } + static class TimedClusterApplierService extends ClusterApplierService { final ClusterSettings clusterSettings; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/TemplateUpgraderTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/TemplateUpgraderTests.java index 42b70751f4d..be24a891131 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/TemplateUpgraderTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/TemplateUpgraderTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.metadata.IndexTemplateMetadata; import org.elasticsearch.cluster.metadata.TemplateUpgradeService; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.test.ESIntegTestCase.ClusterScope; import org.elasticsearch.test.ESIntegTestCase.Scope; import org.elasticsearch.test.SecurityIntegTestCase; @@ -59,7 +60,11 @@ public class TemplateUpgraderTests extends SecurityIntegTestCase { // ensure the cluster listener gets triggered ClusterChangedEvent event = new ClusterChangedEvent("testing", clusterService.state(), clusterService.state()); - templateUpgradeService.clusterChanged(event); + final ThreadContext threadContext = threadPool.getThreadContext(); + try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + threadContext.markAsSystemContext(); + templateUpgradeService.clusterChanged(event); + } assertBusy(() -> assertTemplates("added-template", "removed-template")); }