Improve async search's tasks cancellation (#53799)
This commit adds an explicit cancellation of the search task if the initial async search submit task is cancelled (connection closed by the user). This was previously done through the cancellation of the parent task but we don't handle grand-children cancellation yet so we have to manually cancel the search task in order to ensure that shard actions are cancelled too. This change can be considered as a workaround until #50990 is fixed.
This commit is contained in:
parent
52062565a9
commit
0330bef409
|
@ -35,6 +35,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.function.BooleanSupplier;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
|
@ -42,6 +43,7 @@ import java.util.function.Supplier;
|
|||
* Task that tracks the progress of a currently running {@link SearchRequest}.
|
||||
*/
|
||||
final class AsyncSearchTask extends SearchTask {
|
||||
private final BooleanSupplier checkSubmitCancellation;
|
||||
private final AsyncSearchId searchId;
|
||||
private final Client client;
|
||||
private final ThreadPool threadPool;
|
||||
|
@ -68,6 +70,7 @@ final class AsyncSearchTask extends SearchTask {
|
|||
* @param type The type of the task.
|
||||
* @param action The action name.
|
||||
* @param parentTaskId The parent task id.
|
||||
* @param checkSubmitCancellation A boolean supplier that checks if the submit task has been cancelled.
|
||||
* @param originHeaders All the request context headers.
|
||||
* @param taskHeaders The filtered request headers for the task.
|
||||
* @param searchId The {@link AsyncSearchId} of the task.
|
||||
|
@ -78,6 +81,7 @@ final class AsyncSearchTask extends SearchTask {
|
|||
String type,
|
||||
String action,
|
||||
TaskId parentTaskId,
|
||||
BooleanSupplier checkSubmitCancellation,
|
||||
TimeValue keepAlive,
|
||||
Map<String, String> originHeaders,
|
||||
Map<String, String> taskHeaders,
|
||||
|
@ -86,6 +90,7 @@ final class AsyncSearchTask extends SearchTask {
|
|||
ThreadPool threadPool,
|
||||
Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier) {
|
||||
super(id, type, action, "async_search", parentTaskId, taskHeaders);
|
||||
this.checkSubmitCancellation = checkSubmitCancellation;
|
||||
this.expirationTimeMillis = getStartTime() + keepAlive.getMillis();
|
||||
this.originHeaders = originHeaders;
|
||||
this.searchId = searchId;
|
||||
|
@ -212,13 +217,13 @@ final class AsyncSearchTask extends SearchTask {
|
|||
|
||||
final Cancellable cancellable;
|
||||
try {
|
||||
cancellable = threadPool.schedule(() -> {
|
||||
cancellable = threadPool.schedule(threadPool.preserveContext(() -> {
|
||||
if (hasRun.compareAndSet(false, true)) {
|
||||
// timeout occurred before completion
|
||||
removeCompletionListener(id);
|
||||
listener.onResponse(getResponse());
|
||||
}
|
||||
}, waitForCompletion, "generic");
|
||||
}), waitForCompletion, "generic");
|
||||
} catch (EsRejectedExecutionException exc) {
|
||||
listener.onFailure(exc);
|
||||
return;
|
||||
|
@ -291,10 +296,13 @@ final class AsyncSearchTask extends SearchTask {
|
|||
return searchResponse.get().toAsyncSearchResponse(this, expirationTimeMillis);
|
||||
}
|
||||
|
||||
// cancels the task if it expired
|
||||
private void checkExpiration() {
|
||||
// checks if the search task should be cancelled
|
||||
private void checkCancellation() {
|
||||
long now = System.currentTimeMillis();
|
||||
if (expirationTimeMillis < now) {
|
||||
if (expirationTimeMillis < now || checkSubmitCancellation.getAsBoolean()) {
|
||||
// we cancel the search task if the initial submit task was cancelled,
|
||||
// this is needed because the task cancellation mechanism doesn't
|
||||
// handle the cancellation of grand-children.
|
||||
cancelTask(() -> {});
|
||||
}
|
||||
}
|
||||
|
@ -302,30 +310,31 @@ final class AsyncSearchTask extends SearchTask {
|
|||
class Listener extends SearchProgressActionListener {
|
||||
@Override
|
||||
protected void onQueryResult(int shardIndex) {
|
||||
checkExpiration();
|
||||
checkCancellation();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void onFetchResult(int shardIndex) {
|
||||
checkExpiration();
|
||||
checkCancellation();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
|
||||
// best effort to cancel expired tasks
|
||||
checkExpiration();
|
||||
searchResponse.get().addShardFailure(shardIndex, new ShardSearchFailure(exc, shardTarget));
|
||||
checkCancellation();
|
||||
searchResponse.get().addShardFailure(shardIndex,
|
||||
new ShardSearchFailure(exc, shardTarget.getNodeId() != null ? shardTarget : null));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void onFetchFailure(int shardIndex, Exception exc) {
|
||||
checkExpiration();
|
||||
checkCancellation();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void onListShards(List<SearchShard> shards, List<SearchShard> skipped, Clusters clusters, boolean fetchPhase) {
|
||||
// best effort to cancel expired tasks
|
||||
checkExpiration();
|
||||
checkCancellation();
|
||||
searchResponse.compareAndSet(null,
|
||||
new MutableSearchResponse(shards.size() + skipped.size(), skipped.size(), clusters, aggReduceContextSupplier));
|
||||
executeInitListeners();
|
||||
|
@ -334,7 +343,7 @@ final class AsyncSearchTask extends SearchTask {
|
|||
@Override
|
||||
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
|
||||
// best effort to cancel expired tasks
|
||||
checkExpiration();
|
||||
checkCancellation();
|
||||
searchResponse.get().updatePartialResponse(shards.size(),
|
||||
new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs,
|
||||
null, null, false, null, reducePhase), aggs == null);
|
||||
|
@ -343,7 +352,7 @@ final class AsyncSearchTask extends SearchTask {
|
|||
@Override
|
||||
public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
|
||||
// best effort to cancel expired tasks
|
||||
checkExpiration();
|
||||
checkCancellation();
|
||||
searchResponse.get().updatePartialResponse(shards.size(),
|
||||
new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs,
|
||||
null, null, false, null, reducePhase), true);
|
||||
|
|
|
@ -66,7 +66,7 @@ public class TransportSubmitAsyncSearchAction extends HandledTransportAction<Sub
|
|||
@Override
|
||||
protected void doExecute(Task task, SubmitAsyncSearchRequest request, ActionListener<AsyncSearchResponse> submitListener) {
|
||||
CancellableTask submitTask = (CancellableTask) task;
|
||||
final SearchRequest searchRequest = createSearchRequest(request, submitTask.getId(), request.getKeepAlive());
|
||||
final SearchRequest searchRequest = createSearchRequest(request, submitTask, request.getKeepAlive());
|
||||
AsyncSearchTask searchTask = (AsyncSearchTask) taskManager.register("transport", SearchAction.INSTANCE.name(), searchRequest);
|
||||
searchAction.execute(searchTask, searchRequest, searchTask.getSearchProgressActionListener());
|
||||
searchTask.addCompletionListener(
|
||||
|
@ -81,7 +81,7 @@ public class TransportSubmitAsyncSearchAction extends HandledTransportAction<Sub
|
|||
// the user cancelled the submit so we don't store anything
|
||||
// and propagate the failure
|
||||
Exception cause = new TaskCancelledException(submitTask.getReasonCancelled());
|
||||
onFatalFailure(searchTask, cause, false, submitListener);
|
||||
onFatalFailure(searchTask, cause, searchResponse.isRunning(), submitListener);
|
||||
} else {
|
||||
final String docId = searchTask.getSearchId().getDocId();
|
||||
// creates the fallback response if the node crashes/restarts in the middle of the request
|
||||
|
@ -129,7 +129,7 @@ public class TransportSubmitAsyncSearchAction extends HandledTransportAction<Sub
|
|||
}, request.getWaitForCompletion());
|
||||
}
|
||||
|
||||
private SearchRequest createSearchRequest(SubmitAsyncSearchRequest request, long parentTaskId, TimeValue keepAlive) {
|
||||
private SearchRequest createSearchRequest(SubmitAsyncSearchRequest request, CancellableTask submitTask, TimeValue keepAlive) {
|
||||
String docID = UUIDs.randomBase64UUID();
|
||||
Map<String, String> originHeaders = nodeClient.threadPool().getThreadContext().getHeaders();
|
||||
SearchRequest searchRequest = new SearchRequest(request.getSearchRequest()) {
|
||||
|
@ -138,16 +138,17 @@ public class TransportSubmitAsyncSearchAction extends HandledTransportAction<Sub
|
|||
AsyncSearchId searchId = new AsyncSearchId(docID, new TaskId(nodeClient.getLocalNodeId(), id));
|
||||
Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier =
|
||||
() -> requestToAggReduceContextBuilder.apply(request.getSearchRequest());
|
||||
return new AsyncSearchTask(id, type, action, parentTaskId, keepAlive, originHeaders, taskHeaders, searchId,
|
||||
store.getClient(), nodeClient.threadPool(), aggReduceContextSupplier);
|
||||
return new AsyncSearchTask(id, type, action, parentTaskId,
|
||||
() -> submitTask.isCancelled(), keepAlive, originHeaders, taskHeaders, searchId, store.getClient(),
|
||||
nodeClient.threadPool(), aggReduceContextSupplier);
|
||||
}
|
||||
};
|
||||
searchRequest.setParentTask(new TaskId(nodeClient.getLocalNodeId(), parentTaskId));
|
||||
searchRequest.setParentTask(new TaskId(nodeClient.getLocalNodeId(), submitTask.getId()));
|
||||
return searchRequest;
|
||||
}
|
||||
|
||||
private void onFatalFailure(AsyncSearchTask task, Exception error, boolean shouldCancel, ActionListener<AsyncSearchResponse> listener) {
|
||||
if (shouldCancel) {
|
||||
if (shouldCancel && task.isCancelled() == false) {
|
||||
task.cancelTask(() -> {
|
||||
try {
|
||||
task.addCompletionListener(finalResponse -> taskManager.unregister(task));
|
||||
|
|
|
@ -251,4 +251,28 @@ public class AsyncSearchActionTests extends AsyncSearchIntegTestCase {
|
|||
ElasticsearchException exc = response.getFailure();
|
||||
assertThat(exc.getMessage(), containsString("no such index"));
|
||||
}
|
||||
|
||||
public void testCancellation() throws Exception {
|
||||
SubmitAsyncSearchRequest request = new SubmitAsyncSearchRequest(indexName);
|
||||
request.getSearchRequest().source(
|
||||
new SearchSourceBuilder().aggregation(new CancellingAggregationBuilder("test"))
|
||||
);
|
||||
request.setWaitForCompletion(TimeValue.timeValueMillis(1));
|
||||
AsyncSearchResponse response = submitAsyncSearch(request);
|
||||
assertNotNull(response.getSearchResponse());
|
||||
assertTrue(response.isRunning());
|
||||
assertThat(response.getSearchResponse().getTotalShards(), equalTo(numShards));
|
||||
assertThat(response.getSearchResponse().getSuccessfulShards(), equalTo(0));
|
||||
assertThat(response.getSearchResponse().getFailedShards(), equalTo(0));
|
||||
|
||||
response = getAsyncSearch(response.getId());
|
||||
assertNotNull(response.getSearchResponse());
|
||||
assertTrue(response.isRunning());
|
||||
assertThat(response.getSearchResponse().getTotalShards(), equalTo(numShards));
|
||||
assertThat(response.getSearchResponse().getSuccessfulShards(), equalTo(0));
|
||||
assertThat(response.getSearchResponse().getFailedShards(), equalTo(0));
|
||||
|
||||
deleteAsyncSearch(response.getId());
|
||||
ensureTaskRemoval(response.getId());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,11 +5,7 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.search;
|
||||
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.ScoreMode;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.apache.lucene.search.Weight;
|
||||
import org.elasticsearch.ResourceNotFoundException;
|
||||
import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskResponse;
|
||||
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
|
||||
|
@ -18,22 +14,12 @@ import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
|
|||
import org.elasticsearch.action.get.GetResponse;
|
||||
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
||||
import org.elasticsearch.cluster.node.DiscoveryNode;
|
||||
import org.elasticsearch.common.ParsingException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.lucene.search.Queries;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.AbstractQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryShardContext;
|
||||
import org.elasticsearch.index.reindex.ReindexPlugin;
|
||||
import org.elasticsearch.index.shard.ShardId;
|
||||
import org.elasticsearch.plugins.Plugin;
|
||||
import org.elasticsearch.plugins.PluginsService;
|
||||
import org.elasticsearch.plugins.SearchPlugin;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.tasks.TaskId;
|
||||
|
@ -50,15 +36,12 @@ import org.elasticsearch.xpack.core.XPackSettings;
|
|||
import org.elasticsearch.xpack.ilm.IndexLifecycle;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
@ -75,7 +58,7 @@ public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
|
|||
@Override
|
||||
protected Collection<Class<? extends Plugin>> nodePlugins() {
|
||||
return Arrays.asList(LocalStateCompositeXPackPlugin.class, AsyncSearch.class, IndexLifecycle.class,
|
||||
QueryBlockPlugin.class, ReindexPlugin.class);
|
||||
SearchTestPlugin.class, ReindexPlugin.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -165,14 +148,14 @@ public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
|
|||
.collect(
|
||||
Collectors.toMap(
|
||||
Function.identity(),
|
||||
id -> new ShardIdLatch(id, new CountDownLatch(1), failures.decrementAndGet() >= 0)
|
||||
id -> new ShardIdLatch(id, failures.decrementAndGet() >= 0)
|
||||
)
|
||||
);
|
||||
ShardIdLatch[] shardLatchArray = shardLatchMap.values().stream()
|
||||
.sorted(Comparator.comparing(ShardIdLatch::shard))
|
||||
.sorted(Comparator.comparing(ShardIdLatch::shardId))
|
||||
.toArray(ShardIdLatch[]::new);
|
||||
resetPluginsLatch(shardLatchMap);
|
||||
request.getSearchRequest().source().query(new BlockQueryBuilder(shardLatchMap));
|
||||
request.getSearchRequest().source().query(new BlockingQueryBuilder(shardLatchMap));
|
||||
|
||||
final AsyncSearchResponse initial = client().execute(SubmitAsyncSearchAction.INSTANCE, request).get();
|
||||
|
||||
|
@ -210,7 +193,7 @@ public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
|
|||
int step = shardIndex == 0 ? progressStep+1 : progressStep-1;
|
||||
int index = 0;
|
||||
while (index < step && shardIndex < shardLatchArray.length) {
|
||||
if (shardLatchArray[shardIndex].shouldFail == false) {
|
||||
if (shardLatchArray[shardIndex].shouldFail() == false) {
|
||||
++index;
|
||||
}
|
||||
shardLatchArray[shardIndex++].countDown();
|
||||
|
@ -255,8 +238,8 @@ public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
|
|||
@Override
|
||||
public void close() {
|
||||
Arrays.stream(shardLatchArray).forEach(shard -> {
|
||||
if (shard.latch.getCount() == 1) {
|
||||
shard.latch.countDown();
|
||||
if (shard.getCount() == 1) {
|
||||
shard.countDown();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -265,143 +248,7 @@ public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
|
|||
|
||||
private void resetPluginsLatch(Map<ShardId, ShardIdLatch> newLatch) {
|
||||
for (PluginsService pluginsService : internalCluster().getDataNodeInstances(PluginsService.class)) {
|
||||
pluginsService.filterPlugins(QueryBlockPlugin.class).forEach(p -> p.reset(newLatch));
|
||||
}
|
||||
}
|
||||
|
||||
public static class QueryBlockPlugin extends Plugin implements SearchPlugin {
|
||||
private Map<ShardId, ShardIdLatch> shardsLatch;
|
||||
|
||||
public QueryBlockPlugin() {
|
||||
this.shardsLatch = null;
|
||||
}
|
||||
|
||||
public void reset(Map<ShardId, ShardIdLatch> newLatch) {
|
||||
shardsLatch = newLatch;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<QuerySpec<?>> getQueries() {
|
||||
return Collections.singletonList(
|
||||
new QuerySpec<>("block_match_all",
|
||||
in -> new BlockQueryBuilder(in, shardsLatch),
|
||||
p -> BlockQueryBuilder.fromXContent(p, shardsLatch))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private static class BlockQueryBuilder extends AbstractQueryBuilder<BlockQueryBuilder> {
|
||||
public static final String NAME = "block_match_all";
|
||||
private final Map<ShardId, ShardIdLatch> shardsLatch;
|
||||
|
||||
private BlockQueryBuilder(Map<ShardId, ShardIdLatch> shardsLatch) {
|
||||
super();
|
||||
this.shardsLatch = shardsLatch;
|
||||
}
|
||||
|
||||
BlockQueryBuilder(StreamInput in, Map<ShardId, ShardIdLatch> shardsLatch) throws IOException {
|
||||
super(in);
|
||||
this.shardsLatch = shardsLatch;
|
||||
}
|
||||
|
||||
private BlockQueryBuilder() {
|
||||
this.shardsLatch = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doWriteTo(StreamOutput out) {}
|
||||
|
||||
@Override
|
||||
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject(NAME);
|
||||
builder.endObject();
|
||||
}
|
||||
|
||||
private static final ObjectParser<BlockQueryBuilder, Void> PARSER = new ObjectParser<>(NAME, BlockQueryBuilder::new);
|
||||
|
||||
public static BlockQueryBuilder fromXContent(XContentParser parser, Map<ShardId, ShardIdLatch> shardsLatch) {
|
||||
try {
|
||||
PARSER.apply(parser, null);
|
||||
return new BlockQueryBuilder(shardsLatch);
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Query doToQuery(QueryShardContext context) {
|
||||
final Query delegate = Queries.newMatchAllQuery();
|
||||
return new Query() {
|
||||
@Override
|
||||
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
|
||||
if (shardsLatch != null) {
|
||||
try {
|
||||
final ShardIdLatch latch = shardsLatch.get(new ShardId(context.index(), context.getShardId()));
|
||||
latch.await();
|
||||
if (latch.shouldFail) {
|
||||
throw new IOException("boum");
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
return delegate.createWeight(searcher, scoreMode, boost);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(String field) {
|
||||
return delegate.toString(field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean doEquals(BlockQueryBuilder other) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int doHashCode() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
}
|
||||
|
||||
private static class ShardIdLatch {
|
||||
private final ShardId shard;
|
||||
private final CountDownLatch latch;
|
||||
private final boolean shouldFail;
|
||||
|
||||
private ShardIdLatch(ShardId shard, CountDownLatch latch, boolean shouldFail) {
|
||||
this.shard = shard;
|
||||
this.latch = latch;
|
||||
this.shouldFail = shouldFail;
|
||||
}
|
||||
|
||||
ShardId shard() {
|
||||
return shard;
|
||||
}
|
||||
|
||||
void countDown() {
|
||||
latch.countDown();
|
||||
}
|
||||
|
||||
void await() throws InterruptedException {
|
||||
latch.await();
|
||||
pluginsService.filterPlugins(SearchTestPlugin.class).forEach(p -> p.resetQueryLatch(newLatch));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.index.shard.ShardId;
|
|||
import org.elasticsearch.search.SearchHits;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregations;
|
||||
import org.elasticsearch.search.internal.InternalSearchResponse;
|
||||
import org.elasticsearch.tasks.CancellableTask;
|
||||
import org.elasticsearch.tasks.TaskId;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.test.client.NoOpClient;
|
||||
|
@ -27,6 +28,7 @@ import org.junit.Before;
|
|||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
|
@ -35,6 +37,23 @@ import static org.hamcrest.Matchers.equalTo;
|
|||
public class AsyncSearchTaskTests extends ESTestCase {
|
||||
private ThreadPool threadPool;
|
||||
|
||||
private static class TestTask extends CancellableTask {
|
||||
private TestTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers) {
|
||||
super(id, type, action, description, parentTaskId, headers);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean shouldCancelChildrenOnCancellation() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private static TestTask createSubmitTask() {
|
||||
return new TestTask(0L, "", "", "test", new TaskId("node1", 0), Collections.emptyMap());
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Before
|
||||
public void beforeTest() {
|
||||
threadPool = new TestThreadPool(getTestName());
|
||||
|
@ -46,7 +65,7 @@ public class AsyncSearchTaskTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testWaitForInit() throws InterruptedException {
|
||||
AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), TimeValue.timeValueHours(1),
|
||||
AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), () -> false, TimeValue.timeValueHours(1),
|
||||
Collections.emptyMap(), Collections.emptyMap(), new AsyncSearchId("0", new TaskId("node1", 1)),
|
||||
new NoOpClient(threadPool), threadPool, null);
|
||||
int numShards = randomIntBetween(0, 10);
|
||||
|
@ -86,7 +105,7 @@ public class AsyncSearchTaskTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testWithFailure() throws InterruptedException {
|
||||
AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), TimeValue.timeValueHours(1),
|
||||
AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), () -> false, TimeValue.timeValueHours(1),
|
||||
Collections.emptyMap(), Collections.emptyMap(), new AsyncSearchId("0", new TaskId("node1", 1)),
|
||||
new NoOpClient(threadPool), threadPool, null);
|
||||
int numThreads = randomIntBetween(1, 10);
|
||||
|
@ -114,7 +133,7 @@ public class AsyncSearchTaskTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testWaitForCompletion() throws InterruptedException {
|
||||
AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), TimeValue.timeValueHours(1),
|
||||
AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), () -> false, TimeValue.timeValueHours(1),
|
||||
Collections.emptyMap(), Collections.emptyMap(), new AsyncSearchId("0", new TaskId("node1", 1)),
|
||||
new NoOpClient(threadPool), threadPool, null);
|
||||
int numShards = randomIntBetween(0, 10);
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.search;
|
||||
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.ScoreMode;
|
||||
import org.apache.lucene.search.Weight;
|
||||
import org.elasticsearch.common.ParsingException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.lucene.search.Queries;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.AbstractQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryShardContext;
|
||||
import org.elasticsearch.index.shard.ShardId;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* A query builder that blocks shard execution based on the provided {@link ShardIdLatch}.
|
||||
*/
|
||||
class BlockingQueryBuilder extends AbstractQueryBuilder<BlockingQueryBuilder> {
|
||||
public static final String NAME = "block";
|
||||
private final Map<ShardId, ShardIdLatch> shardsLatch;
|
||||
|
||||
BlockingQueryBuilder(Map<ShardId, ShardIdLatch> shardsLatch) {
|
||||
super();
|
||||
this.shardsLatch = shardsLatch;
|
||||
}
|
||||
|
||||
BlockingQueryBuilder(StreamInput in, Map<ShardId, ShardIdLatch> shardsLatch) throws IOException {
|
||||
super(in);
|
||||
this.shardsLatch = shardsLatch;
|
||||
}
|
||||
|
||||
BlockingQueryBuilder() {
|
||||
this.shardsLatch = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doWriteTo(StreamOutput out) {}
|
||||
|
||||
@Override
|
||||
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject(NAME);
|
||||
builder.endObject();
|
||||
}
|
||||
|
||||
private static final ObjectParser<BlockingQueryBuilder, Void> PARSER = new ObjectParser<>(NAME, BlockingQueryBuilder::new);
|
||||
|
||||
public static BlockingQueryBuilder fromXContent(XContentParser parser, Map<ShardId, ShardIdLatch> shardsLatch) {
|
||||
try {
|
||||
PARSER.apply(parser, null);
|
||||
return new BlockingQueryBuilder(shardsLatch);
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Query doToQuery(QueryShardContext context) {
|
||||
final Query delegate = Queries.newMatchAllQuery();
|
||||
return new Query() {
|
||||
@Override
|
||||
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
|
||||
if (shardsLatch != null) {
|
||||
try {
|
||||
final ShardIdLatch latch = shardsLatch.get(new ShardId(context.index(), context.getShardId()));
|
||||
latch.await();
|
||||
if (latch.shouldFail()) {
|
||||
throw new IOException("boum");
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
return delegate.createWeight(searcher, scoreMode, boost);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(String field) {
|
||||
return delegate.toString(field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean doEquals(BlockingQueryBuilder other) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int doHashCode() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,102 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.search;
|
||||
|
||||
import org.elasticsearch.common.ParsingException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.index.query.QueryShardContext;
|
||||
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.Aggregator;
|
||||
import org.elasticsearch.search.aggregations.AggregatorFactories;
|
||||
import org.elasticsearch.search.aggregations.AggregatorFactory;
|
||||
import org.elasticsearch.search.aggregations.bucket.filter.FilterAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.search.internal.SearchContext;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* An aggregation builder that blocks shard search action until the task is cancelled.
|
||||
*/
|
||||
public class CancellingAggregationBuilder extends AbstractAggregationBuilder<CancellingAggregationBuilder> {
|
||||
static final String NAME = "cancel";
|
||||
static final int SLEEP_TIME = 10;
|
||||
|
||||
public CancellingAggregationBuilder(String name) {
|
||||
super(name);
|
||||
}
|
||||
|
||||
public CancellingAggregationBuilder(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map<String, Object> metaData) {
|
||||
return new CancellingAggregationBuilder(name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getType() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doWriteTo(StreamOutput out) throws IOException {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
static final ConstructingObjectParser<CancellingAggregationBuilder, String> PARSER =
|
||||
new ConstructingObjectParser<>(NAME, false, (args, name) -> new CancellingAggregationBuilder(name));
|
||||
|
||||
|
||||
static CancellingAggregationBuilder fromXContent(String aggName, XContentParser parser) {
|
||||
try {
|
||||
return PARSER.apply(parser, aggName);
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
protected AggregatorFactory doBuild(QueryShardContext queryShardContext, AggregatorFactory parent,
|
||||
AggregatorFactories.Builder subfactoriesBuilder) throws IOException {
|
||||
final FilterAggregationBuilder filterAgg = new FilterAggregationBuilder(name, QueryBuilders.matchAllQuery());
|
||||
filterAgg.subAggregations(subfactoriesBuilder);
|
||||
final AggregatorFactory factory = filterAgg.build(queryShardContext, parent);
|
||||
return new AggregatorFactory(name, queryShardContext, parent, subfactoriesBuilder, metaData) {
|
||||
@Override
|
||||
protected Aggregator createInternal(SearchContext searchContext,
|
||||
Aggregator parent,
|
||||
boolean collectsFromSingleBucket,
|
||||
List<PipelineAggregator> pipelineAggregators,
|
||||
Map<String, Object> metaData) throws IOException {
|
||||
while (searchContext.isCancelled() == false) {
|
||||
try {
|
||||
Thread.sleep(SLEEP_TIME);
|
||||
} catch (InterruptedException e) {
|
||||
throw new IOException(e);
|
||||
}
|
||||
}
|
||||
return factory.create(searchContext, parent, collectsFromSingleBucket);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.search;
|
||||
|
||||
import org.elasticsearch.index.shard.ShardId;
|
||||
import org.elasticsearch.plugins.Plugin;
|
||||
import org.elasticsearch.plugins.SearchPlugin;
|
||||
import org.elasticsearch.search.aggregations.bucket.filter.InternalFilter;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class SearchTestPlugin extends Plugin implements SearchPlugin {
|
||||
private Map<ShardId, ShardIdLatch> shardsLatch;
|
||||
|
||||
public SearchTestPlugin() {
|
||||
this.shardsLatch = null;
|
||||
}
|
||||
|
||||
public void resetQueryLatch(Map<ShardId, ShardIdLatch> newLatch) {
|
||||
shardsLatch = newLatch;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<QuerySpec<?>> getQueries() {
|
||||
return Collections.singletonList(
|
||||
new QuerySpec<>(BlockingQueryBuilder.NAME,
|
||||
in -> new BlockingQueryBuilder(in, shardsLatch),
|
||||
p -> BlockingQueryBuilder.fromXContent(p, shardsLatch))
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AggregationSpec> getAggregations() {
|
||||
return Collections.singletonList(new AggregationSpec(CancellingAggregationBuilder.NAME, CancellingAggregationBuilder::new,
|
||||
CancellingAggregationBuilder.PARSER).addResultReader(InternalFilter::new));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.search;
|
||||
|
||||
import org.elasticsearch.index.shard.ShardId;
|
||||
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
class ShardIdLatch extends CountDownLatch {
|
||||
private final ShardId shard;
|
||||
private final boolean shouldFail;
|
||||
|
||||
ShardIdLatch(ShardId shard, boolean shouldFail) {
|
||||
super(1);
|
||||
this.shard = shard;
|
||||
this.shouldFail = shouldFail;
|
||||
}
|
||||
|
||||
ShardId shardId() {
|
||||
return shard;
|
||||
}
|
||||
|
||||
boolean shouldFail() {
|
||||
return shouldFail;
|
||||
}
|
||||
}
|
|
@ -150,7 +150,8 @@ public class SubmitAsyncSearchRequest extends ActionRequest {
|
|||
return new CancellableTask(id, type, action, toString(), parentTaskId, headers) {
|
||||
@Override
|
||||
public boolean shouldCancelChildrenOnCancellation() {
|
||||
return true;
|
||||
// we cancel the underlying search action explicitly in the submit action
|
||||
return false;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue