diff --git a/core/src/main/java/org/elasticsearch/index/shard/IndexShardOperationsLock.java b/core/src/main/java/org/elasticsearch/index/shard/IndexShardOperationsLock.java index cde14dec173..70e2037664f 100644 --- a/core/src/main/java/org/elasticsearch/index/shard/IndexShardOperationsLock.java +++ b/core/src/main/java/org/elasticsearch/index/shard/IndexShardOperationsLock.java @@ -20,9 +20,11 @@ package org.elasticsearch.index.shard; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.util.concurrent.ThreadContext.StoredContext; import org.elasticsearch.threadpool.ThreadPool; import java.io.Closeable; @@ -32,6 +34,7 @@ import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; public class IndexShardOperationsLock implements Closeable { private final ShardId shardId; @@ -126,11 +129,13 @@ public class IndexShardOperationsLock implements Closeable { if (delayedOperations == null) { delayedOperations = new ArrayList<>(); } + final Supplier contextSupplier = threadPool.getThreadContext().newRestorableContext(false); if (executorOnDelay != null) { delayedOperations.add( - new ThreadedActionListener<>(logger, threadPool, executorOnDelay, onAcquired, forceExecution)); + new ThreadedActionListener<>(logger, threadPool, executorOnDelay, + new ContextPreservingActionListener<>(contextSupplier, onAcquired), forceExecution)); } else { - delayedOperations.add(onAcquired); + delayedOperations.add(new ContextPreservingActionListener<>(contextSupplier, onAcquired)); } return; } diff --git a/core/src/test/java/org/elasticsearch/index/shard/IndexShardOperationsLockTests.java b/core/src/test/java/org/elasticsearch/index/shard/IndexShardOperationsLockTests.java index c9bb9e19866..d3df93513d0 100644 --- a/core/src/test/java/org/elasticsearch/index/shard/IndexShardOperationsLockTests.java +++ b/core/src/test/java/org/elasticsearch/index/shard/IndexShardOperationsLockTests.java @@ -18,9 +18,10 @@ */ package org.elasticsearch.index.shard; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.inject.internal.Nullable; import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; @@ -35,7 +36,8 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import java.util.function.Supplier; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -146,7 +148,57 @@ public class IndexShardOperationsLockTests extends ESTestCase { block.acquire(future, ThreadPool.Names.GENERIC, true); assertFalse(future.isDone()); } - future.get(1, TimeUnit.MINUTES).close(); + future.get(1, TimeUnit.HOURS).close(); + } + + /** + * Tests that the ThreadContext is restored when a operation is executed after it has been delayed due to a block + */ + public void testThreadContextPreservedIfBlock() throws ExecutionException, InterruptedException, TimeoutException { + final ThreadContext context = threadPool.getThreadContext(); + final Function, Boolean> contextChecker = (listener) -> { + if ("bar".equals(context.getHeader("foo")) == false) { + listener.onFailure(new IllegalStateException("context did not have value [bar] for header [foo]. Actual value [" + + context.getHeader("foo") + "]")); + } else if ("baz".equals(context.getTransient("bar")) == false) { + listener.onFailure(new IllegalStateException("context did not have value [baz] for transient [bar]. Actual value [" + + context.getTransient("bar") + "]")); + } else { + return true; + } + return false; + }; + PlainActionFuture future = new PlainActionFuture() { + @Override + public void onResponse(Releasable releasable) { + if (contextChecker.apply(this)) { + super.onResponse(releasable); + } + } + }; + PlainActionFuture future2 = new PlainActionFuture() { + @Override + public void onResponse(Releasable releasable) { + if (contextChecker.apply(this)) { + super.onResponse(releasable); + } + } + }; + + try (Releasable releasable = blockAndWait()) { + // we preserve the thread context here so that we have a different context in the call to acquire than the context present + // when the releasable is closed + try (ThreadContext.StoredContext ignore = context.newStoredContext(false)) { + context.putHeader("foo", "bar"); + context.putTransient("bar", "baz"); + // test both with and without a executor name + block.acquire(future, ThreadPool.Names.GENERIC, true); + block.acquire(future2, null, true); + } + assertFalse(future.isDone()); + } + future.get(1, TimeUnit.HOURS).close(); + future2.get(1, TimeUnit.HOURS).close(); } protected Releasable blockAndWait() throws InterruptedException {