diff --git a/core/src/main/java/org/elasticsearch/action/bulk/Retry.java b/core/src/main/java/org/elasticsearch/action/bulk/Retry.java index 375796ae801..a41bd454979 100644 --- a/core/src/main/java/org/elasticsearch/action/bulk/Retry.java +++ b/core/src/main/java/org/elasticsearch/action/bulk/Retry.java @@ -142,7 +142,9 @@ public class Retry { assert backoff.hasNext(); TimeValue next = backoff.next(); logger.trace("Retry of bulk request scheduled in {} ms.", next.millis()); - scheduledRequestFuture = client.threadPool().schedule(next, ThreadPool.Names.SAME, (() -> this.execute(bulkRequestForRetry))); + Runnable retry = () -> this.execute(bulkRequestForRetry); + retry = client.threadPool().getThreadContext().preserveContext(retry); + scheduledRequestFuture = client.threadPool().schedule(next, ThreadPool.Names.SAME, retry); } private BulkRequest createBulkRequestForRetry(BulkResponse bulkItemResponses) { diff --git a/core/src/test/java/org/elasticsearch/action/bulk/RetryTests.java b/core/src/test/java/org/elasticsearch/action/bulk/RetryTests.java index c0e735ec33c..ea365974995 100644 --- a/core/src/test/java/org/elasticsearch/action/bulk/RetryTests.java +++ b/core/src/test/java/org/elasticsearch/action/bulk/RetryTests.java @@ -31,6 +31,8 @@ import org.elasticsearch.test.client.NoOpClient; import org.junit.After; import org.junit.Before; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; @@ -44,12 +46,21 @@ public class RetryTests extends ESTestCase { private static final int CALLS_TO_FAIL = 5; private MockBulkClient bulkClient; + /** + * Headers that are expected to be sent with all bulk requests. + */ + private Map expectedHeaders = new HashMap<>(); @Override @Before public void setUp() throws Exception { super.setUp(); this.bulkClient = new MockBulkClient(getTestName(), CALLS_TO_FAIL); + // Stash some random headers so we can assert that we preserve them + bulkClient.threadPool().getThreadContext().stashContext(); + expectedHeaders.clear(); + expectedHeaders.put(randomAsciiOfLength(5), randomAsciiOfLength(5)); + bulkClient.threadPool().getThreadContext().putHeader(expectedHeaders); } @Override @@ -178,7 +189,7 @@ public class RetryTests extends ESTestCase { } } - private static class MockBulkClient extends NoOpClient { + private class MockBulkClient extends NoOpClient { private int numberOfCallsToFail; private MockBulkClient(String testName, int numberOfCallsToFail) { @@ -195,6 +206,12 @@ public class RetryTests extends ESTestCase { @Override public void bulk(BulkRequest request, ActionListener listener) { + if (false == expectedHeaders.equals(threadPool().getThreadContext().getHeaders())) { + listener.onFailure( + new RuntimeException("Expected " + expectedHeaders + " but got " + threadPool().getThreadContext().getHeaders())); + return; + } + // do everything synchronously, that's fine for a test boolean shouldFail = numberOfCallsToFail > 0; numberOfCallsToFail--; diff --git a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/AbstractAsyncBulkByScrollAction.java b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/AbstractAsyncBulkByScrollAction.java index d8e935b5022..0328e606d9e 100644 --- a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/AbstractAsyncBulkByScrollAction.java +++ b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/AbstractAsyncBulkByScrollAction.java @@ -30,7 +30,6 @@ import org.elasticsearch.action.bulk.BulkItemResponse.Failure; import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.bulk.Retry; -import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.client.ParentTaskAssigningClient; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; diff --git a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/ClientScrollableHitSource.java b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/ClientScrollableHitSource.java index 09423b2cca8..7b6be85140f 100644 --- a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/ClientScrollableHitSource.java +++ b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/ClientScrollableHitSource.java @@ -79,10 +79,12 @@ public class ClientScrollableHitSource extends ScrollableHitSource { @Override protected void doStartNextScroll(String scrollId, TimeValue extraKeepAlive, Consumer onResponse) { - SearchScrollRequest request = new SearchScrollRequest(); - // Add the wait time into the scroll timeout so it won't timeout while we wait for throttling - request.scrollId(scrollId).scroll(timeValueNanos(firstSearchRequest.scroll().keepAlive().nanos() + extraKeepAlive.nanos())); - searchWithRetry(listener -> client.searchScroll(request, listener), r -> consume(r, onResponse)); + searchWithRetry(listener -> { + SearchScrollRequest request = new SearchScrollRequest(); + // Add the wait time into the scroll timeout so it won't timeout while we wait for throttling + request.scrollId(scrollId).scroll(timeValueNanos(firstSearchRequest.scroll().keepAlive().nanos() + extraKeepAlive.nanos())); + client.searchScroll(request, listener); + }, r -> consume(r, onResponse)); } @Override @@ -126,6 +128,10 @@ public class ClientScrollableHitSource extends ScrollableHitSource { */ class RetryHelper extends AbstractRunnable implements ActionListener { private final Iterator retries = backoffPolicy.iterator(); + /** + * The runnable to run that retries in the same context as the original call. + */ + private Runnable retryWithContext; private volatile int retryCount = 0; @Override @@ -146,7 +152,7 @@ public class ClientScrollableHitSource extends ScrollableHitSource { TimeValue delay = retries.next(); logger.trace((Supplier) () -> new ParameterizedMessage("retrying rejected search after [{}]", delay), e); countSearchRetry.run(); - threadPool.schedule(delay, ThreadPool.Names.SAME, this); + threadPool.schedule(delay, ThreadPool.Names.SAME, retryWithContext); } else { logger.warn( (Supplier) () -> new ParameterizedMessage( @@ -159,7 +165,10 @@ public class ClientScrollableHitSource extends ScrollableHitSource { } } } - new RetryHelper().run(); + RetryHelper helper = new RetryHelper(); + // Wrap the helper in a runnable that preserves the current context so we keep it on retry. + helper.retryWithContext = threadPool.getThreadContext().preserveContext(helper); + helper.run(); } private void consume(SearchResponse response, Consumer onResponse) { diff --git a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java index b14b790340c..932d4de9174 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java @@ -27,9 +27,9 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestBuilder; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.DocWriteResponse; import org.elasticsearch.action.DocWriteResponse.Result; -import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.bulk.BackoffPolicy; import org.elasticsearch.action.bulk.BulkItemResponse; @@ -80,6 +80,7 @@ import org.junit.Before; import java.util.ArrayList; import java.util.HashMap; +import java.util.IdentityHashMap; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -95,8 +96,10 @@ import java.util.function.Consumer; import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; +import static java.util.Collections.newSetFromMap; import static java.util.Collections.singleton; import static java.util.Collections.singletonList; +import static java.util.Collections.synchronizedSet; import static org.apache.lucene.util.TestUtil.randomSimpleString; import static org.elasticsearch.action.bulk.BackoffPolicy.constantBackoff; import static org.elasticsearch.common.unit.TimeValue.timeValueMillis; @@ -114,7 +117,6 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo; public class AsyncBulkByScrollActionTests extends ESTestCase { private MyMockClient client; - private ThreadPool threadPool; private DummyAbstractBulkByScrollRequest testRequest; private SearchRequest firstSearchRequest; private PlainActionFuture listener; @@ -127,8 +129,11 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { @Before public void setupForTest() { - client = new MyMockClient(new NoOpClient(getTestName())); - threadPool = new TestThreadPool(getTestName()); + // Fill the context with something random so we can make sure we inherited it appropriately. + expectedHeaders.clear(); + expectedHeaders.put(randomSimpleString(random()), randomSimpleString(random())); + + setupClient(new TestThreadPool(getTestName())); firstSearchRequest = new SearchRequest(); testRequest = new DummyAbstractBulkByScrollRequest(firstSearchRequest); listener = new PlainActionFuture<>(); @@ -136,19 +141,22 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { taskManager = new TaskManager(Settings.EMPTY); testTask = (WorkingBulkByScrollTask) taskManager.register("don'tcare", "hereeither", testRequest); - // Fill the context with something random so we can make sure we inherited it appropriately. - expectedHeaders.clear(); - expectedHeaders.put(randomSimpleString(random()), randomSimpleString(random())); - threadPool.getThreadContext().newStoredContext(); - threadPool.getThreadContext().putHeader(expectedHeaders); localNode = new DiscoveryNode("thenode", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT); taskId = new TaskId(localNode.getId(), testTask.getId()); } + private void setupClient(ThreadPool threadPool) { + if (client != null) { + client.close(); + } + client = new MyMockClient(new NoOpClient(threadPool)); + client.threadPool().getThreadContext().newStoredContext(); + client.threadPool().getThreadContext().putHeader(expectedHeaders); + } + @After public void tearDownAndVerifyCommonStuff() { client.close(); - threadPool.shutdown(); } /** @@ -237,12 +245,6 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { // Use assert busy because the update happens on another thread final int expectedBatches = batches; assertBusy(() -> assertEquals(expectedBatches, testTask.getStatus().getBatches())); - - /* - * Also while we're here check that we preserved the headers from the last request. assertBusy because no requests might have - * come in yet. - */ - assertBusy(() -> assertEquals(expectedHeaders, client.lastHeaders.get())); } } @@ -304,8 +306,7 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { */ public void testThreadPoolRejectionsAbortRequest() throws Exception { testTask.rethrottle(1); - threadPool.shutdown(); - threadPool = new TestThreadPool(getTestName()) { + setupClient(new TestThreadPool(getTestName()) { @Override public ScheduledFuture schedule(TimeValue delay, String name, Runnable command) { // While we're here we can check that the sleep made it through @@ -314,7 +315,7 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { ((AbstractRunnable) command).onRejection(new EsRejectedExecutionException("test")); return null; } - }; + }); ScrollableHitSource.Response response = new ScrollableHitSource.Response(false, emptyList(), 0, emptyList(), null); simulateScrollResponse(new DummyAbstractAsyncBulkByScrollAction(), timeValueNanos(System.nanoTime()), 10, response); ExecutionException e = expectThrows(ExecutionException.class, () -> listener.get()); @@ -417,15 +418,14 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { */ AtomicReference capturedDelay = new AtomicReference<>(); AtomicReference capturedCommand = new AtomicReference<>(); - threadPool.shutdown(); - threadPool = new TestThreadPool(getTestName()) { + setupClient(new TestThreadPool(getTestName()) { @Override public ScheduledFuture schedule(TimeValue delay, String name, Runnable command) { capturedDelay.set(delay); capturedCommand.set(command); return null; } - }; + }); DummyAbstractAsyncBulkByScrollAction action = new DummyAbstractAsyncBulkByScrollAction(); action.setScroll(scrollId()); @@ -497,7 +497,7 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { assertThat(response.getSearchFailures(), empty()); assertNull(response.getReasonCancelled()); } else { - successLatch.await(10, TimeUnit.SECONDS); + assertTrue(successLatch.await(10, TimeUnit.SECONDS)); } } @@ -593,8 +593,7 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { * Replace the thread pool with one that will cancel the task as soon as anything is scheduled, which reindex tries to do when there * is a delay. */ - threadPool.shutdown(); - threadPool = new TestThreadPool(getTestName()) { + setupClient(new TestThreadPool(getTestName()) { @Override public ScheduledFuture schedule(TimeValue delay, String name, Runnable command) { /* @@ -609,7 +608,7 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { } return super.schedule(delay, name, command); } - }; + }); // Send the scroll response which will trigger the custom thread pool above, canceling the request before running the response DummyAbstractAsyncBulkByScrollAction action = new DummyAbstractAsyncBulkByScrollAction(); @@ -660,7 +659,7 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { extends AbstractAsyncBulkByScrollAction { public DummyAbstractAsyncBulkByScrollAction() { super(testTask, AsyncBulkByScrollActionTests.this.logger, new ParentTaskAssigningClient(client, localNode, testTask), - AsyncBulkByScrollActionTests.this.threadPool, testRequest, listener); + client.threadPool(), testRequest, listener); } @Override @@ -706,7 +705,6 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { private final AtomicInteger bulksAttempts = new AtomicInteger(); private final AtomicInteger searchAttempts = new AtomicInteger(); private final AtomicInteger scrollAttempts = new AtomicInteger(); - private final AtomicReference> lastHeaders = new AtomicReference<>(); private final AtomicReference lastRefreshRequest = new AtomicReference<>(); /** * Last search attempt that wasn't rejected outright. @@ -716,7 +714,10 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { * Last scroll attempt that wasn't rejected outright. */ private final AtomicReference> lastScroll = new AtomicReference<>(); - + /** + * Set of all scrolls we've already used. Used to check that we don't reuse the same request twice. + */ + private final Set usedScolls = synchronizedSet(newSetFromMap(new IdentityHashMap<>())); private int bulksToReject = 0; private int searchesToReject = 0; @@ -731,7 +732,12 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { protected > void doExecute( Action action, Request request, ActionListener listener) { - lastHeaders.set(threadPool.getThreadContext().getHeaders()); + if (false == expectedHeaders.equals(threadPool().getThreadContext().getHeaders())) { + listener.onFailure( + new RuntimeException("Expected " + expectedHeaders + " but got " + threadPool().getThreadContext().getHeaders())); + return; + } + if (request instanceof ClearScrollRequest) { assertEquals(TaskId.EMPTY_TASK_ID, request.getParentTask()); } else { @@ -751,11 +757,14 @@ public class AsyncBulkByScrollActionTests extends ESTestCase { return; } if (request instanceof SearchScrollRequest) { + SearchScrollRequest scroll = (SearchScrollRequest) request; + boolean newRequest = usedScolls.add(scroll); + assertTrue("We can't reuse scroll requests", newRequest); if (scrollAttempts.incrementAndGet() <= scrollsToReject) { listener.onFailure(wrappedRejectedException()); return; } - lastScroll.set(new RequestAndListener<>((SearchScrollRequest) request, (ActionListener) listener)); + lastScroll.set(new RequestAndListener<>(scroll, (ActionListener) listener)); return; } if (request instanceof ClearScrollRequest) { diff --git a/test/framework/src/main/java/org/elasticsearch/test/client/NoOpClient.java b/test/framework/src/main/java/org/elasticsearch/test/client/NoOpClient.java index 5e1e1acd9ab..c393e19653f 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/client/NoOpClient.java +++ b/test/framework/src/main/java/org/elasticsearch/test/client/NoOpClient.java @@ -32,8 +32,20 @@ import org.elasticsearch.threadpool.ThreadPool; import java.util.concurrent.TimeUnit; +/** + * Client that always responds with {@code null} to every request. Override this for testing. + */ public class NoOpClient extends AbstractClient { + /** + * Build with {@link ThreadPool}. This {@linkplain ThreadPool} is terminated on {@link #close()}. + */ + public NoOpClient(ThreadPool threadPool) { + super(Settings.EMPTY, threadPool); + } + /** + * Create a new {@link TestThreadPool} for this client. + */ public NoOpClient(String testName) { super(Settings.EMPTY, new TestThreadPool(testName)); }