Improve async search's tasks cancellation (#53799)

This commit adds an explicit cancellation of the search task if
the initial async search submit task is cancelled (connection closed by the user).
This was previously done through the cancellation of the parent task but we don't
handle grand-children cancellation yet so we have to manually cancel the search task
in order to ensure that shard actions are cancelled too.
This change can be considered as a workaround until #50990 is fixed.
This commit is contained in:
Jim Ferenczi 2020-03-24 12:31:07 +01:00 committed by jimczi
parent 52062565a9
commit 0330bef409
10 changed files with 377 additions and 185 deletions

View File

@ -35,6 +35,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -42,6 +43,7 @@ import java.util.function.Supplier;
* Task that tracks the progress of a currently running {@link SearchRequest}. * Task that tracks the progress of a currently running {@link SearchRequest}.
*/ */
final class AsyncSearchTask extends SearchTask { final class AsyncSearchTask extends SearchTask {
private final BooleanSupplier checkSubmitCancellation;
private final AsyncSearchId searchId; private final AsyncSearchId searchId;
private final Client client; private final Client client;
private final ThreadPool threadPool; private final ThreadPool threadPool;
@ -68,6 +70,7 @@ final class AsyncSearchTask extends SearchTask {
* @param type The type of the task. * @param type The type of the task.
* @param action The action name. * @param action The action name.
* @param parentTaskId The parent task id. * @param parentTaskId The parent task id.
* @param checkSubmitCancellation A boolean supplier that checks if the submit task has been cancelled.
* @param originHeaders All the request context headers. * @param originHeaders All the request context headers.
* @param taskHeaders The filtered request headers for the task. * @param taskHeaders The filtered request headers for the task.
* @param searchId The {@link AsyncSearchId} of the task. * @param searchId The {@link AsyncSearchId} of the task.
@ -78,6 +81,7 @@ final class AsyncSearchTask extends SearchTask {
String type, String type,
String action, String action,
TaskId parentTaskId, TaskId parentTaskId,
BooleanSupplier checkSubmitCancellation,
TimeValue keepAlive, TimeValue keepAlive,
Map<String, String> originHeaders, Map<String, String> originHeaders,
Map<String, String> taskHeaders, Map<String, String> taskHeaders,
@ -86,6 +90,7 @@ final class AsyncSearchTask extends SearchTask {
ThreadPool threadPool, ThreadPool threadPool,
Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier) { Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier) {
super(id, type, action, "async_search", parentTaskId, taskHeaders); super(id, type, action, "async_search", parentTaskId, taskHeaders);
this.checkSubmitCancellation = checkSubmitCancellation;
this.expirationTimeMillis = getStartTime() + keepAlive.getMillis(); this.expirationTimeMillis = getStartTime() + keepAlive.getMillis();
this.originHeaders = originHeaders; this.originHeaders = originHeaders;
this.searchId = searchId; this.searchId = searchId;
@ -212,13 +217,13 @@ final class AsyncSearchTask extends SearchTask {
final Cancellable cancellable; final Cancellable cancellable;
try { try {
cancellable = threadPool.schedule(() -> { cancellable = threadPool.schedule(threadPool.preserveContext(() -> {
if (hasRun.compareAndSet(false, true)) { if (hasRun.compareAndSet(false, true)) {
// timeout occurred before completion // timeout occurred before completion
removeCompletionListener(id); removeCompletionListener(id);
listener.onResponse(getResponse()); listener.onResponse(getResponse());
} }
}, waitForCompletion, "generic"); }), waitForCompletion, "generic");
} catch (EsRejectedExecutionException exc) { } catch (EsRejectedExecutionException exc) {
listener.onFailure(exc); listener.onFailure(exc);
return; return;
@ -291,10 +296,13 @@ final class AsyncSearchTask extends SearchTask {
return searchResponse.get().toAsyncSearchResponse(this, expirationTimeMillis); return searchResponse.get().toAsyncSearchResponse(this, expirationTimeMillis);
} }
// cancels the task if it expired // checks if the search task should be cancelled
private void checkExpiration() { private void checkCancellation() {
long now = System.currentTimeMillis(); long now = System.currentTimeMillis();
if (expirationTimeMillis < now) { if (expirationTimeMillis < now || checkSubmitCancellation.getAsBoolean()) {
// we cancel the search task if the initial submit task was cancelled,
// this is needed because the task cancellation mechanism doesn't
// handle the cancellation of grand-children.
cancelTask(() -> {}); cancelTask(() -> {});
} }
} }
@ -302,30 +310,31 @@ final class AsyncSearchTask extends SearchTask {
class Listener extends SearchProgressActionListener { class Listener extends SearchProgressActionListener {
@Override @Override
protected void onQueryResult(int shardIndex) { protected void onQueryResult(int shardIndex) {
checkExpiration(); checkCancellation();
} }
@Override @Override
protected void onFetchResult(int shardIndex) { protected void onFetchResult(int shardIndex) {
checkExpiration(); checkCancellation();
} }
@Override @Override
protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) { protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
// best effort to cancel expired tasks // best effort to cancel expired tasks
checkExpiration(); checkCancellation();
searchResponse.get().addShardFailure(shardIndex, new ShardSearchFailure(exc, shardTarget)); searchResponse.get().addShardFailure(shardIndex,
new ShardSearchFailure(exc, shardTarget.getNodeId() != null ? shardTarget : null));
} }
@Override @Override
protected void onFetchFailure(int shardIndex, Exception exc) { protected void onFetchFailure(int shardIndex, Exception exc) {
checkExpiration(); checkCancellation();
} }
@Override @Override
protected void onListShards(List<SearchShard> shards, List<SearchShard> skipped, Clusters clusters, boolean fetchPhase) { protected void onListShards(List<SearchShard> shards, List<SearchShard> skipped, Clusters clusters, boolean fetchPhase) {
// best effort to cancel expired tasks // best effort to cancel expired tasks
checkExpiration(); checkCancellation();
searchResponse.compareAndSet(null, searchResponse.compareAndSet(null,
new MutableSearchResponse(shards.size() + skipped.size(), skipped.size(), clusters, aggReduceContextSupplier)); new MutableSearchResponse(shards.size() + skipped.size(), skipped.size(), clusters, aggReduceContextSupplier));
executeInitListeners(); executeInitListeners();
@ -334,7 +343,7 @@ final class AsyncSearchTask extends SearchTask {
@Override @Override
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
// best effort to cancel expired tasks // best effort to cancel expired tasks
checkExpiration(); checkCancellation();
searchResponse.get().updatePartialResponse(shards.size(), searchResponse.get().updatePartialResponse(shards.size(),
new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs, new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs,
null, null, false, null, reducePhase), aggs == null); null, null, false, null, reducePhase), aggs == null);
@ -343,7 +352,7 @@ final class AsyncSearchTask extends SearchTask {
@Override @Override
public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
// best effort to cancel expired tasks // best effort to cancel expired tasks
checkExpiration(); checkCancellation();
searchResponse.get().updatePartialResponse(shards.size(), searchResponse.get().updatePartialResponse(shards.size(),
new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs, new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs,
null, null, false, null, reducePhase), true); null, null, false, null, reducePhase), true);

View File

@ -66,7 +66,7 @@ public class TransportSubmitAsyncSearchAction extends HandledTransportAction<Sub
@Override @Override
protected void doExecute(Task task, SubmitAsyncSearchRequest request, ActionListener<AsyncSearchResponse> submitListener) { protected void doExecute(Task task, SubmitAsyncSearchRequest request, ActionListener<AsyncSearchResponse> submitListener) {
CancellableTask submitTask = (CancellableTask) task; CancellableTask submitTask = (CancellableTask) task;
final SearchRequest searchRequest = createSearchRequest(request, submitTask.getId(), request.getKeepAlive()); final SearchRequest searchRequest = createSearchRequest(request, submitTask, request.getKeepAlive());
AsyncSearchTask searchTask = (AsyncSearchTask) taskManager.register("transport", SearchAction.INSTANCE.name(), searchRequest); AsyncSearchTask searchTask = (AsyncSearchTask) taskManager.register("transport", SearchAction.INSTANCE.name(), searchRequest);
searchAction.execute(searchTask, searchRequest, searchTask.getSearchProgressActionListener()); searchAction.execute(searchTask, searchRequest, searchTask.getSearchProgressActionListener());
searchTask.addCompletionListener( searchTask.addCompletionListener(
@ -81,7 +81,7 @@ public class TransportSubmitAsyncSearchAction extends HandledTransportAction<Sub
// the user cancelled the submit so we don't store anything // the user cancelled the submit so we don't store anything
// and propagate the failure // and propagate the failure
Exception cause = new TaskCancelledException(submitTask.getReasonCancelled()); Exception cause = new TaskCancelledException(submitTask.getReasonCancelled());
onFatalFailure(searchTask, cause, false, submitListener); onFatalFailure(searchTask, cause, searchResponse.isRunning(), submitListener);
} else { } else {
final String docId = searchTask.getSearchId().getDocId(); final String docId = searchTask.getSearchId().getDocId();
// creates the fallback response if the node crashes/restarts in the middle of the request // creates the fallback response if the node crashes/restarts in the middle of the request
@ -129,7 +129,7 @@ public class TransportSubmitAsyncSearchAction extends HandledTransportAction<Sub
}, request.getWaitForCompletion()); }, request.getWaitForCompletion());
} }
private SearchRequest createSearchRequest(SubmitAsyncSearchRequest request, long parentTaskId, TimeValue keepAlive) { private SearchRequest createSearchRequest(SubmitAsyncSearchRequest request, CancellableTask submitTask, TimeValue keepAlive) {
String docID = UUIDs.randomBase64UUID(); String docID = UUIDs.randomBase64UUID();
Map<String, String> originHeaders = nodeClient.threadPool().getThreadContext().getHeaders(); Map<String, String> originHeaders = nodeClient.threadPool().getThreadContext().getHeaders();
SearchRequest searchRequest = new SearchRequest(request.getSearchRequest()) { SearchRequest searchRequest = new SearchRequest(request.getSearchRequest()) {
@ -138,16 +138,17 @@ public class TransportSubmitAsyncSearchAction extends HandledTransportAction<Sub
AsyncSearchId searchId = new AsyncSearchId(docID, new TaskId(nodeClient.getLocalNodeId(), id)); AsyncSearchId searchId = new AsyncSearchId(docID, new TaskId(nodeClient.getLocalNodeId(), id));
Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier = Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier =
() -> requestToAggReduceContextBuilder.apply(request.getSearchRequest()); () -> requestToAggReduceContextBuilder.apply(request.getSearchRequest());
return new AsyncSearchTask(id, type, action, parentTaskId, keepAlive, originHeaders, taskHeaders, searchId, return new AsyncSearchTask(id, type, action, parentTaskId,
store.getClient(), nodeClient.threadPool(), aggReduceContextSupplier); () -> submitTask.isCancelled(), keepAlive, originHeaders, taskHeaders, searchId, store.getClient(),
nodeClient.threadPool(), aggReduceContextSupplier);
} }
}; };
searchRequest.setParentTask(new TaskId(nodeClient.getLocalNodeId(), parentTaskId)); searchRequest.setParentTask(new TaskId(nodeClient.getLocalNodeId(), submitTask.getId()));
return searchRequest; return searchRequest;
} }
private void onFatalFailure(AsyncSearchTask task, Exception error, boolean shouldCancel, ActionListener<AsyncSearchResponse> listener) { private void onFatalFailure(AsyncSearchTask task, Exception error, boolean shouldCancel, ActionListener<AsyncSearchResponse> listener) {
if (shouldCancel) { if (shouldCancel && task.isCancelled() == false) {
task.cancelTask(() -> { task.cancelTask(() -> {
try { try {
task.addCompletionListener(finalResponse -> taskManager.unregister(task)); task.addCompletionListener(finalResponse -> taskManager.unregister(task));

View File

@ -251,4 +251,28 @@ public class AsyncSearchActionTests extends AsyncSearchIntegTestCase {
ElasticsearchException exc = response.getFailure(); ElasticsearchException exc = response.getFailure();
assertThat(exc.getMessage(), containsString("no such index")); assertThat(exc.getMessage(), containsString("no such index"));
} }
public void testCancellation() throws Exception {
SubmitAsyncSearchRequest request = new SubmitAsyncSearchRequest(indexName);
request.getSearchRequest().source(
new SearchSourceBuilder().aggregation(new CancellingAggregationBuilder("test"))
);
request.setWaitForCompletion(TimeValue.timeValueMillis(1));
AsyncSearchResponse response = submitAsyncSearch(request);
assertNotNull(response.getSearchResponse());
assertTrue(response.isRunning());
assertThat(response.getSearchResponse().getTotalShards(), equalTo(numShards));
assertThat(response.getSearchResponse().getSuccessfulShards(), equalTo(0));
assertThat(response.getSearchResponse().getFailedShards(), equalTo(0));
response = getAsyncSearch(response.getId());
assertNotNull(response.getSearchResponse());
assertTrue(response.isRunning());
assertThat(response.getSearchResponse().getTotalShards(), equalTo(numShards));
assertThat(response.getSearchResponse().getSuccessfulShards(), equalTo(0));
assertThat(response.getSearchResponse().getFailedShards(), equalTo(0));
deleteAsyncSearch(response.getId());
ensureTaskRemoval(response.getId());
}
} }

View File

@ -5,11 +5,7 @@
*/ */
package org.elasticsearch.xpack.search; package org.elasticsearch.xpack.search;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskResponse; import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskResponse;
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup; import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
@ -18,22 +14,12 @@ import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.index.reindex.ReindexPlugin; import org.elasticsearch.index.reindex.ReindexPlugin;
import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.PluginsService; import org.elasticsearch.plugins.PluginsService;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskId;
@ -50,15 +36,12 @@ import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.ilm.IndexLifecycle; import org.elasticsearch.xpack.ilm.IndexLifecycle;
import java.io.Closeable; import java.io.Closeable;
import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.Iterator; import java.util.Iterator;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -75,7 +58,7 @@ public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
@Override @Override
protected Collection<Class<? extends Plugin>> nodePlugins() { protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(LocalStateCompositeXPackPlugin.class, AsyncSearch.class, IndexLifecycle.class, return Arrays.asList(LocalStateCompositeXPackPlugin.class, AsyncSearch.class, IndexLifecycle.class,
QueryBlockPlugin.class, ReindexPlugin.class); SearchTestPlugin.class, ReindexPlugin.class);
} }
@Override @Override
@ -165,14 +148,14 @@ public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
.collect( .collect(
Collectors.toMap( Collectors.toMap(
Function.identity(), Function.identity(),
id -> new ShardIdLatch(id, new CountDownLatch(1), failures.decrementAndGet() >= 0) id -> new ShardIdLatch(id, failures.decrementAndGet() >= 0)
) )
); );
ShardIdLatch[] shardLatchArray = shardLatchMap.values().stream() ShardIdLatch[] shardLatchArray = shardLatchMap.values().stream()
.sorted(Comparator.comparing(ShardIdLatch::shard)) .sorted(Comparator.comparing(ShardIdLatch::shardId))
.toArray(ShardIdLatch[]::new); .toArray(ShardIdLatch[]::new);
resetPluginsLatch(shardLatchMap); resetPluginsLatch(shardLatchMap);
request.getSearchRequest().source().query(new BlockQueryBuilder(shardLatchMap)); request.getSearchRequest().source().query(new BlockingQueryBuilder(shardLatchMap));
final AsyncSearchResponse initial = client().execute(SubmitAsyncSearchAction.INSTANCE, request).get(); final AsyncSearchResponse initial = client().execute(SubmitAsyncSearchAction.INSTANCE, request).get();
@ -210,7 +193,7 @@ public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
int step = shardIndex == 0 ? progressStep+1 : progressStep-1; int step = shardIndex == 0 ? progressStep+1 : progressStep-1;
int index = 0; int index = 0;
while (index < step && shardIndex < shardLatchArray.length) { while (index < step && shardIndex < shardLatchArray.length) {
if (shardLatchArray[shardIndex].shouldFail == false) { if (shardLatchArray[shardIndex].shouldFail() == false) {
++index; ++index;
} }
shardLatchArray[shardIndex++].countDown(); shardLatchArray[shardIndex++].countDown();
@ -255,8 +238,8 @@ public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
@Override @Override
public void close() { public void close() {
Arrays.stream(shardLatchArray).forEach(shard -> { Arrays.stream(shardLatchArray).forEach(shard -> {
if (shard.latch.getCount() == 1) { if (shard.getCount() == 1) {
shard.latch.countDown(); shard.countDown();
} }
}); });
} }
@ -265,143 +248,7 @@ public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
private void resetPluginsLatch(Map<ShardId, ShardIdLatch> newLatch) { private void resetPluginsLatch(Map<ShardId, ShardIdLatch> newLatch) {
for (PluginsService pluginsService : internalCluster().getDataNodeInstances(PluginsService.class)) { for (PluginsService pluginsService : internalCluster().getDataNodeInstances(PluginsService.class)) {
pluginsService.filterPlugins(QueryBlockPlugin.class).forEach(p -> p.reset(newLatch)); pluginsService.filterPlugins(SearchTestPlugin.class).forEach(p -> p.resetQueryLatch(newLatch));
}
}
public static class QueryBlockPlugin extends Plugin implements SearchPlugin {
private Map<ShardId, ShardIdLatch> shardsLatch;
public QueryBlockPlugin() {
this.shardsLatch = null;
}
public void reset(Map<ShardId, ShardIdLatch> newLatch) {
shardsLatch = newLatch;
}
@Override
public List<QuerySpec<?>> getQueries() {
return Collections.singletonList(
new QuerySpec<>("block_match_all",
in -> new BlockQueryBuilder(in, shardsLatch),
p -> BlockQueryBuilder.fromXContent(p, shardsLatch))
);
}
}
private static class BlockQueryBuilder extends AbstractQueryBuilder<BlockQueryBuilder> {
public static final String NAME = "block_match_all";
private final Map<ShardId, ShardIdLatch> shardsLatch;
private BlockQueryBuilder(Map<ShardId, ShardIdLatch> shardsLatch) {
super();
this.shardsLatch = shardsLatch;
}
BlockQueryBuilder(StreamInput in, Map<ShardId, ShardIdLatch> shardsLatch) throws IOException {
super(in);
this.shardsLatch = shardsLatch;
}
private BlockQueryBuilder() {
this.shardsLatch = null;
}
@Override
protected void doWriteTo(StreamOutput out) {}
@Override
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
builder.endObject();
}
private static final ObjectParser<BlockQueryBuilder, Void> PARSER = new ObjectParser<>(NAME, BlockQueryBuilder::new);
public static BlockQueryBuilder fromXContent(XContentParser parser, Map<ShardId, ShardIdLatch> shardsLatch) {
try {
PARSER.apply(parser, null);
return new BlockQueryBuilder(shardsLatch);
} catch (IllegalArgumentException e) {
throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
}
}
@Override
protected Query doToQuery(QueryShardContext context) {
final Query delegate = Queries.newMatchAllQuery();
return new Query() {
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
if (shardsLatch != null) {
try {
final ShardIdLatch latch = shardsLatch.get(new ShardId(context.index(), context.getShardId()));
latch.await();
if (latch.shouldFail) {
throw new IOException("boum");
}
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
return delegate.createWeight(searcher, scoreMode, boost);
}
@Override
public String toString(String field) {
return delegate.toString(field);
}
@Override
public boolean equals(Object obj) {
return false;
}
@Override
public int hashCode() {
return 0;
}
};
}
@Override
protected boolean doEquals(BlockQueryBuilder other) {
return false;
}
@Override
protected int doHashCode() {
return 0;
}
@Override
public String getWriteableName() {
return NAME;
}
}
private static class ShardIdLatch {
private final ShardId shard;
private final CountDownLatch latch;
private final boolean shouldFail;
private ShardIdLatch(ShardId shard, CountDownLatch latch, boolean shouldFail) {
this.shard = shard;
this.latch = latch;
this.shouldFail = shouldFail;
}
ShardId shard() {
return shard;
}
void countDown() {
latch.countDown();
}
void await() throws InterruptedException {
latch.await();
} }
} }
} }

View File

@ -15,6 +15,7 @@ import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.internal.InternalSearchResponse; import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient; import org.elasticsearch.test.client.NoOpClient;
@ -27,6 +28,7 @@ import org.junit.Before;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -35,6 +37,23 @@ import static org.hamcrest.Matchers.equalTo;
public class AsyncSearchTaskTests extends ESTestCase { public class AsyncSearchTaskTests extends ESTestCase {
private ThreadPool threadPool; private ThreadPool threadPool;
private static class TestTask extends CancellableTask {
private TestTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers) {
super(id, type, action, description, parentTaskId, headers);
}
@Override
public boolean shouldCancelChildrenOnCancellation() {
return false;
}
}
private static TestTask createSubmitTask() {
return new TestTask(0L, "", "", "test", new TaskId("node1", 0), Collections.emptyMap());
}
@Before @Before
public void beforeTest() { public void beforeTest() {
threadPool = new TestThreadPool(getTestName()); threadPool = new TestThreadPool(getTestName());
@ -46,7 +65,7 @@ public class AsyncSearchTaskTests extends ESTestCase {
} }
public void testWaitForInit() throws InterruptedException { public void testWaitForInit() throws InterruptedException {
AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), TimeValue.timeValueHours(1), AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), () -> false, TimeValue.timeValueHours(1),
Collections.emptyMap(), Collections.emptyMap(), new AsyncSearchId("0", new TaskId("node1", 1)), Collections.emptyMap(), Collections.emptyMap(), new AsyncSearchId("0", new TaskId("node1", 1)),
new NoOpClient(threadPool), threadPool, null); new NoOpClient(threadPool), threadPool, null);
int numShards = randomIntBetween(0, 10); int numShards = randomIntBetween(0, 10);
@ -86,7 +105,7 @@ public class AsyncSearchTaskTests extends ESTestCase {
} }
public void testWithFailure() throws InterruptedException { public void testWithFailure() throws InterruptedException {
AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), TimeValue.timeValueHours(1), AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), () -> false, TimeValue.timeValueHours(1),
Collections.emptyMap(), Collections.emptyMap(), new AsyncSearchId("0", new TaskId("node1", 1)), Collections.emptyMap(), Collections.emptyMap(), new AsyncSearchId("0", new TaskId("node1", 1)),
new NoOpClient(threadPool), threadPool, null); new NoOpClient(threadPool), threadPool, null);
int numThreads = randomIntBetween(1, 10); int numThreads = randomIntBetween(1, 10);
@ -114,7 +133,7 @@ public class AsyncSearchTaskTests extends ESTestCase {
} }
public void testWaitForCompletion() throws InterruptedException { public void testWaitForCompletion() throws InterruptedException {
AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), TimeValue.timeValueHours(1), AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), () -> false, TimeValue.timeValueHours(1),
Collections.emptyMap(), Collections.emptyMap(), new AsyncSearchId("0", new TaskId("node1", 1)), Collections.emptyMap(), Collections.emptyMap(), new AsyncSearchId("0", new TaskId("node1", 1)),
new NoOpClient(threadPool), threadPool, null); new NoOpClient(threadPool), threadPool, null);
int numShards = randomIntBetween(0, 10); int numShards = randomIntBetween(0, 10);

View File

@ -0,0 +1,118 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.search;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.index.shard.ShardId;
import java.io.IOException;
import java.util.Map;
/**
* A query builder that blocks shard execution based on the provided {@link ShardIdLatch}.
*/
class BlockingQueryBuilder extends AbstractQueryBuilder<BlockingQueryBuilder> {
public static final String NAME = "block";
private final Map<ShardId, ShardIdLatch> shardsLatch;
BlockingQueryBuilder(Map<ShardId, ShardIdLatch> shardsLatch) {
super();
this.shardsLatch = shardsLatch;
}
BlockingQueryBuilder(StreamInput in, Map<ShardId, ShardIdLatch> shardsLatch) throws IOException {
super(in);
this.shardsLatch = shardsLatch;
}
BlockingQueryBuilder() {
this.shardsLatch = null;
}
@Override
protected void doWriteTo(StreamOutput out) {}
@Override
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
builder.endObject();
}
private static final ObjectParser<BlockingQueryBuilder, Void> PARSER = new ObjectParser<>(NAME, BlockingQueryBuilder::new);
public static BlockingQueryBuilder fromXContent(XContentParser parser, Map<ShardId, ShardIdLatch> shardsLatch) {
try {
PARSER.apply(parser, null);
return new BlockingQueryBuilder(shardsLatch);
} catch (IllegalArgumentException e) {
throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
}
}
@Override
protected Query doToQuery(QueryShardContext context) {
final Query delegate = Queries.newMatchAllQuery();
return new Query() {
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
if (shardsLatch != null) {
try {
final ShardIdLatch latch = shardsLatch.get(new ShardId(context.index(), context.getShardId()));
latch.await();
if (latch.shouldFail()) {
throw new IOException("boum");
}
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
return delegate.createWeight(searcher, scoreMode, boost);
}
@Override
public String toString(String field) {
return delegate.toString(field);
}
@Override
public boolean equals(Object obj) {
return false;
}
@Override
public int hashCode() {
return 0;
}
};
}
@Override
protected boolean doEquals(BlockingQueryBuilder other) {
return false;
}
@Override
protected int doHashCode() {
return 0;
}
@Override
public String getWriteableName() {
return NAME;
}
}

View File

@ -0,0 +1,102 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.search;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.bucket.filter.FilterAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.List;
import java.util.Map;
/**
* An aggregation builder that blocks shard search action until the task is cancelled.
*/
public class CancellingAggregationBuilder extends AbstractAggregationBuilder<CancellingAggregationBuilder> {
static final String NAME = "cancel";
static final int SLEEP_TIME = 10;
public CancellingAggregationBuilder(String name) {
super(name);
}
public CancellingAggregationBuilder(StreamInput in) throws IOException {
super(in);
}
@Override
protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map<String, Object> metaData) {
return new CancellingAggregationBuilder(name);
}
@Override
public String getType() {
return NAME;
}
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
}
@Override
protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.endObject();
return builder;
}
static final ConstructingObjectParser<CancellingAggregationBuilder, String> PARSER =
new ConstructingObjectParser<>(NAME, false, (args, name) -> new CancellingAggregationBuilder(name));
static CancellingAggregationBuilder fromXContent(String aggName, XContentParser parser) {
try {
return PARSER.apply(parser, aggName);
} catch (IllegalArgumentException e) {
throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
}
}
@Override
@SuppressWarnings("unchecked")
protected AggregatorFactory doBuild(QueryShardContext queryShardContext, AggregatorFactory parent,
AggregatorFactories.Builder subfactoriesBuilder) throws IOException {
final FilterAggregationBuilder filterAgg = new FilterAggregationBuilder(name, QueryBuilders.matchAllQuery());
filterAgg.subAggregations(subfactoriesBuilder);
final AggregatorFactory factory = filterAgg.build(queryShardContext, parent);
return new AggregatorFactory(name, queryShardContext, parent, subfactoriesBuilder, metaData) {
@Override
protected Aggregator createInternal(SearchContext searchContext,
Aggregator parent,
boolean collectsFromSingleBucket,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
while (searchContext.isCancelled() == false) {
try {
Thread.sleep(SLEEP_TIME);
} catch (InterruptedException e) {
throw new IOException(e);
}
}
return factory.create(searchContext, parent, collectsFromSingleBucket);
}
};
}
}

View File

@ -0,0 +1,42 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.search;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.search.aggregations.bucket.filter.InternalFilter;
import java.util.Collections;
import java.util.List;
import java.util.Map;
public class SearchTestPlugin extends Plugin implements SearchPlugin {
private Map<ShardId, ShardIdLatch> shardsLatch;
public SearchTestPlugin() {
this.shardsLatch = null;
}
public void resetQueryLatch(Map<ShardId, ShardIdLatch> newLatch) {
shardsLatch = newLatch;
}
@Override
public List<QuerySpec<?>> getQueries() {
return Collections.singletonList(
new QuerySpec<>(BlockingQueryBuilder.NAME,
in -> new BlockingQueryBuilder(in, shardsLatch),
p -> BlockingQueryBuilder.fromXContent(p, shardsLatch))
);
}
@Override
public List<AggregationSpec> getAggregations() {
return Collections.singletonList(new AggregationSpec(CancellingAggregationBuilder.NAME, CancellingAggregationBuilder::new,
CancellingAggregationBuilder.PARSER).addResultReader(InternalFilter::new));
}
}

View File

@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.search;
import org.elasticsearch.index.shard.ShardId;
import java.util.concurrent.CountDownLatch;
class ShardIdLatch extends CountDownLatch {
private final ShardId shard;
private final boolean shouldFail;
ShardIdLatch(ShardId shard, boolean shouldFail) {
super(1);
this.shard = shard;
this.shouldFail = shouldFail;
}
ShardId shardId() {
return shard;
}
boolean shouldFail() {
return shouldFail;
}
}

View File

@ -150,7 +150,8 @@ public class SubmitAsyncSearchRequest extends ActionRequest {
return new CancellableTask(id, type, action, toString(), parentTaskId, headers) { return new CancellableTask(id, type, action, toString(), parentTaskId, headers) {
@Override @Override
public boolean shouldCancelChildrenOnCancellation() { public boolean shouldCancelChildrenOnCancellation() {
return true; // we cancel the underlying search action explicitly in the submit action
return false;
} }
}; };
} }