diff --git a/core/src/main/java/org/elasticsearch/index/shard/IndexShard.java b/core/src/main/java/org/elasticsearch/index/shard/IndexShard.java index dc144c13d50..7b29b2b632e 100644 --- a/core/src/main/java/org/elasticsearch/index/shard/IndexShard.java +++ b/core/src/main/java/org/elasticsearch/index/shard/IndexShard.java @@ -66,6 +66,7 @@ import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.AsyncIOProcessor; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexModule; @@ -2422,7 +2423,7 @@ public class IndexShard extends AbstractIndexShardComponent implements IndicesCl indexSettings::getMaxRefreshListeners, () -> refresh("too_many_listeners"), threadPool.executor(ThreadPool.Names.LISTENER)::execute, - logger); + logger, threadPool.getThreadContext()); } /** diff --git a/core/src/main/java/org/elasticsearch/index/shard/RefreshListeners.java b/core/src/main/java/org/elasticsearch/index/shard/RefreshListeners.java index f0df6e12b8c..17e824eb046 100644 --- a/core/src/main/java/org/elasticsearch/index/shard/RefreshListeners.java +++ b/core/src/main/java/org/elasticsearch/index/shard/RefreshListeners.java @@ -22,6 +22,7 @@ package org.elasticsearch.index.shard; import org.apache.logging.log4j.Logger; import org.apache.lucene.search.ReferenceManager; import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.index.translog.Translog; import java.io.Closeable; @@ -45,6 +46,7 @@ public final class RefreshListeners implements ReferenceManager.RefreshListener, private final Runnable forceRefresh; private final Executor listenerExecutor; private final Logger logger; + private final ThreadContext threadContext; /** * Is this closed? If true then we won't add more listeners and have flushed all pending listeners. @@ -63,11 +65,13 @@ public final class RefreshListeners implements ReferenceManager.RefreshListener, */ private volatile Translog.Location lastRefreshedLocation; - public RefreshListeners(IntSupplier getMaxRefreshListeners, Runnable forceRefresh, Executor listenerExecutor, Logger logger) { + public RefreshListeners(IntSupplier getMaxRefreshListeners, Runnable forceRefresh, Executor listenerExecutor, Logger logger, + ThreadContext threadContext) { this.getMaxRefreshListeners = getMaxRefreshListeners; this.forceRefresh = forceRefresh; this.listenerExecutor = listenerExecutor; this.logger = logger; + this.threadContext = threadContext; } /** @@ -98,8 +102,15 @@ public final class RefreshListeners implements ReferenceManager.RefreshListener, refreshListeners = listeners; } if (listeners.size() < getMaxRefreshListeners.getAsInt()) { + ThreadContext.StoredContext storedContext = threadContext.newStoredContext(true); + Consumer contextPreservingListener = forced -> { + try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + storedContext.restore(); + listener.accept(forced); + } + }; // We have a free slot so register the listener - listeners.add(new Tuple<>(location, listener)); + listeners.add(new Tuple<>(location, contextPreservingListener)); return false; } } diff --git a/core/src/test/java/org/elasticsearch/index/shard/RefreshListenersTests.java b/core/src/test/java/org/elasticsearch/index/shard/RefreshListenersTests.java index fcc3c93fc37..53ced098c04 100644 --- a/core/src/test/java/org/elasticsearch/index/shard/RefreshListenersTests.java +++ b/core/src/test/java/org/elasticsearch/index/shard/RefreshListenersTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.common.lucene.uid.Versions; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexSettings; @@ -67,6 +68,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Locale; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -87,16 +89,16 @@ public class RefreshListenersTests extends ESTestCase { public void setupListeners() throws Exception { // Setup dependencies of the listeners maxListeners = randomIntBetween(1, 1000); + // Now setup the InternalEngine which is much more complicated because we aren't mocking anything + threadPool = new TestThreadPool(getTestName()); listeners = new RefreshListeners( () -> maxListeners, () -> engine.refresh("too-many-listeners"), // Immediately run listeners rather than adding them to the listener thread pool like IndexShard does to simplify the test. Runnable::run, - logger - ); + logger, + threadPool.getThreadContext()); - // Now setup the InternalEngine which is much more complicated because we aren't mocking anything - threadPool = new TestThreadPool(getTestName()); IndexSettings indexSettings = IndexSettingsModule.newIndexSettings("index", Settings.EMPTY); ShardId shardId = new ShardId(new Index("index", "_na_"), 1); String allocationId = UUIDs.randomBase64UUID(random()); @@ -161,6 +163,23 @@ public class RefreshListenersTests extends ESTestCase { assertEquals(0, listeners.pendingCount()); } + public void testContextIsPreserved() throws IOException, InterruptedException { + assertEquals(0, listeners.pendingCount()); + Engine.IndexResult index = index("1"); + CountDownLatch latch = new CountDownLatch(1); + try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) { + threadPool.getThreadContext().putHeader("test", "foobar"); + assertFalse(listeners.addOrNotify(index.getTranslogLocation(), forced -> { + assertEquals("foobar", threadPool.getThreadContext().getHeader("test")); + latch.countDown(); + })); + } + assertNull(threadPool.getThreadContext().getHeader("test")); + assertEquals(1, latch.getCount()); + engine.refresh("I said so"); + latch.await(); + } + public void testTooMany() throws Exception { assertEquals(0, listeners.pendingCount()); assertFalse(listeners.refreshNeeded());