EQL: Add cascading search cancellation (#54843)

EQL search cancellation now propagates cancellation to underlying search
operations.

Relates to #49638
This commit is contained in:
Igor Motov 2020-04-14 08:06:02 -04:00 committed by GitHub
parent 16ebbff3b6
commit 8a669dc9b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 491 additions and 25 deletions

View File

@ -20,6 +20,21 @@ archivesBaseName = 'x-pack-eql'
// All integration tests live in qa modules // All integration tests live in qa modules
integTest.enabled = false integTest.enabled = false
task internalClusterTest(type: Test) {
mustRunAfter test
include '**/*IT.class'
/*
* We have to disable setting the number of available processors as tests in the same JVM randomize processors and will step on each
* other if we allow them to set the number of available processors as it's set-once in Netty.
*/
systemProperty 'es.set.netty.runtime.available.processors', 'false'
if (BuildParams.isSnapshotBuild() == false) {
systemProperty 'es.eql_feature_flag_registered', 'true'
}
}
check.dependsOn internalClusterTest
dependencies { dependencies {
compileOnly project(path: xpackModule('core'), configuration: 'default') compileOnly project(path: xpackModule('core'), configuration: 'default')
compileOnly(project(':modules:lang-painless')) { compileOnly(project(':modules:lang-painless')) {

View File

@ -23,7 +23,7 @@ public class EqlSearchTask extends CancellableTask {
@Override @Override
public boolean shouldCancelChildrenOnCancellation() { public boolean shouldCancelChildrenOnCancellation() {
return false; return true;
} }
@Override @Override

View File

@ -15,6 +15,7 @@ import org.elasticsearch.common.time.DateUtils;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.XPackSettings;
@ -50,11 +51,12 @@ public class TransportEqlSearchAction extends HandledTransportAction<EqlSearchRe
@Override @Override
protected void doExecute(Task task, EqlSearchRequest request, ActionListener<EqlSearchResponse> listener) { protected void doExecute(Task task, EqlSearchRequest request, ActionListener<EqlSearchResponse> listener) {
operation(planExecutor, (EqlSearchTask) task, request, username(securityContext), clusterName(clusterService), listener); operation(planExecutor, (EqlSearchTask) task, request, username(securityContext), clusterName(clusterService),
clusterService.localNode().getId(), listener);
} }
public static void operation(PlanExecutor planExecutor, EqlSearchTask task, EqlSearchRequest request, String username, public static void operation(PlanExecutor planExecutor, EqlSearchTask task, EqlSearchRequest request, String username,
String clusterName, ActionListener<EqlSearchResponse> listener) { String clusterName, String nodeId, ActionListener<EqlSearchResponse> listener) {
// TODO: these should be sent by the client // TODO: these should be sent by the client
ZoneId zoneId = DateUtils.of("Z"); ZoneId zoneId = DateUtils.of("Z");
QueryBuilder filter = request.filter(); QueryBuilder filter = request.filter();
@ -68,7 +70,7 @@ public class TransportEqlSearchAction extends HandledTransportAction<EqlSearchRe
.implicitJoinKey(request.implicitJoinKeyField()); .implicitJoinKey(request.implicitJoinKeyField());
Configuration cfg = new Configuration(request.indices(), zoneId, username, clusterName, filter, timeout, request.fetchSize(), Configuration cfg = new Configuration(request.indices(), zoneId, username, clusterName, filter, timeout, request.fetchSize(),
includeFrozen, clientId, task); includeFrozen, clientId, new TaskId(nodeId, task.getId()), task::isCancelled);
planExecutor.eql(cfg, request.query(), params, wrap(r -> listener.onResponse(createResponse(r)), listener::onFailure)); planExecutor.eql(cfg, request.query(), params, wrap(r -> listener.onResponse(createResponse(r)), listener::onFailure));
} }

View File

@ -9,9 +9,10 @@ package org.elasticsearch.xpack.eql.session;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.eql.action.EqlSearchTask; import org.elasticsearch.tasks.TaskId;
import java.time.ZoneId; import java.time.ZoneId;
import java.util.function.Supplier;
public class Configuration extends org.elasticsearch.xpack.ql.session.Configuration { public class Configuration extends org.elasticsearch.xpack.ql.session.Configuration {
@ -20,13 +21,14 @@ public class Configuration extends org.elasticsearch.xpack.ql.session.Configurat
private final int size; private final int size;
private final String clientId; private final String clientId;
private final boolean includeFrozenIndices; private final boolean includeFrozenIndices;
private final EqlSearchTask task; private final Supplier<Boolean> isCancelled;
private final TaskId taskId;
@Nullable @Nullable
private QueryBuilder filter; private final QueryBuilder filter;
public Configuration(String[] indices, ZoneId zi, String username, String clusterName, QueryBuilder filter, TimeValue requestTimeout, public Configuration(String[] indices, ZoneId zi, String username, String clusterName, QueryBuilder filter, TimeValue requestTimeout,
int size, boolean includeFrozen, String clientId, EqlSearchTask task) { int size, boolean includeFrozen, String clientId, TaskId taskId, Supplier<Boolean> isCancelled) {
super(zi, username, clusterName); super(zi, username, clusterName);
@ -36,7 +38,8 @@ public class Configuration extends org.elasticsearch.xpack.ql.session.Configurat
this.size = size; this.size = size;
this.clientId = clientId; this.clientId = clientId;
this.includeFrozenIndices = includeFrozen; this.includeFrozenIndices = includeFrozen;
this.task = task; this.taskId = taskId;
this.isCancelled = isCancelled;
} }
public String[] indices() { public String[] indices() {
@ -64,6 +67,10 @@ public class Configuration extends org.elasticsearch.xpack.ql.session.Configurat
} }
public boolean isCancelled() { public boolean isCancelled() {
return task.isCancelled(); return isCancelled.get();
}
public TaskId getTaskId() {
return taskId;
} }
} }

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.eql.session;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.client.ParentTaskAssigningClient;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.xpack.eql.analysis.Analyzer; import org.elasticsearch.xpack.eql.analysis.Analyzer;
@ -37,7 +38,7 @@ public class EqlSession {
public EqlSession(Client client, Configuration cfg, IndexResolver indexResolver, PreAnalyzer preAnalyzer, Analyzer analyzer, public EqlSession(Client client, Configuration cfg, IndexResolver indexResolver, PreAnalyzer preAnalyzer, Analyzer analyzer,
Optimizer optimizer, Planner planner, PlanExecutor planExecutor) { Optimizer optimizer, Planner planner, PlanExecutor planExecutor) {
this.client = client; this.client = new ParentTaskAssigningClient(client, cfg.getTaskId());
this.configuration = cfg; this.configuration = cfg;
this.indexResolver = indexResolver; this.indexResolver = indexResolver;
this.preAnalyzer = preAnalyzer; this.preAnalyzer = preAnalyzer;

View File

@ -7,6 +7,7 @@
package org.elasticsearch.xpack.eql; package org.elasticsearch.xpack.eql;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.eql.action.EqlSearchAction; import org.elasticsearch.xpack.eql.action.EqlSearchAction;
import org.elasticsearch.xpack.eql.action.EqlSearchTask; import org.elasticsearch.xpack.eql.action.EqlSearchTask;
import org.elasticsearch.xpack.eql.session.Configuration; import org.elasticsearch.xpack.eql.session.Configuration;
@ -27,7 +28,7 @@ public final class EqlTestUtils {
public static final Configuration TEST_CFG = new Configuration(new String[]{"none"}, org.elasticsearch.xpack.ql.util.DateUtils.UTC, public static final Configuration TEST_CFG = new Configuration(new String[]{"none"}, org.elasticsearch.xpack.ql.util.DateUtils.UTC,
"nobody", "cluster", null, TimeValue.timeValueSeconds(30), -1, false, "", "nobody", "cluster", null, TimeValue.timeValueSeconds(30), -1, false, "",
new EqlSearchTask(-1, "", EqlSearchAction.NAME, () -> "", null, Collections.emptyMap())); new TaskId(randomAlphaOfLength(10), randomNonNegativeLong()), () -> false);
public static Configuration randomConfiguration() { public static Configuration randomConfiguration() {
return new Configuration(new String[]{randomAlphaOfLength(16)}, return new Configuration(new String[]{randomAlphaOfLength(16)},
@ -39,7 +40,8 @@ public final class EqlTestUtils {
randomIntBetween(5, 100), randomIntBetween(5, 100),
randomBoolean(), randomBoolean(),
randomAlphaOfLength(16), randomAlphaOfLength(16),
randomTask()); new TaskId(randomAlphaOfLength(10), randomNonNegativeLong()),
() -> false);
} }
public static EqlSearchTask randomTask() { public static EqlSearchTask randomTask() {

View File

@ -0,0 +1,42 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.eql.action;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.license.LicenseService;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.eql.plugin.EqlPlugin;
import java.util.Collection;
import java.util.Collections;
import static org.elasticsearch.test.ESIntegTestCase.Scope.SUITE;
@ESIntegTestCase.ClusterScope(scope = SUITE, numDataNodes = 0, numClientNodes = 0, maxNumDataNodes = 0, transportClientRatio = 0)
public abstract class AbstractEqlIntegTestCase extends ESIntegTestCase {
@Override
protected Settings nodeSettings(int nodeOrdinal) {
Settings.Builder settings = Settings.builder().put(super.nodeSettings(nodeOrdinal));
settings.put(XPackSettings.SECURITY_ENABLED.getKey(), false);
settings.put(XPackSettings.MONITORING_ENABLED.getKey(), false);
settings.put(XPackSettings.WATCHER_ENABLED.getKey(), false);
settings.put(XPackSettings.GRAPH_ENABLED.getKey(), false);
settings.put(XPackSettings.MACHINE_LEARNING_ENABLED.getKey(), false);
settings.put(XPackSettings.SQL_ENABLED.getKey(), false);
settings.put(EqlPlugin.EQL_ENABLED_SETTING.getKey(), true);
settings.put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial");
return settings.build();
}
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singletonList(LocalStateEQLXPackPlugin.class);
}
}

View File

@ -0,0 +1,280 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.eql.action;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesAction;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.support.ActionFilter;
import org.elasticsearch.action.support.ActionFilterChain;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexModule;
import org.elasticsearch.index.shard.SearchOperationListener;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.PluginsService;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskInfo;
import org.junit.After;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
public class EqlCancellationIT extends AbstractEqlIntegTestCase {
private final ExecutorService executorService = Executors.newFixedThreadPool(1);
/**
* Shutdown the executor so we don't leak threads into other test runs.
*/
@After
public void shutdownExec() {
executorService.shutdown();
}
public void testCancellation() throws Exception {
assertAcked(client().admin().indices().prepareCreate("test")
.addMapping("_doc", "val", "type=integer", "event_type", "type=keyword", "@timestamp", "type=date")
.get());
createIndex("idx_unmapped");
int numDocs = randomIntBetween(6, 20);
List<IndexRequestBuilder> builders = new ArrayList<>();
for (int i = 0; i < numDocs; i++) {
int fieldValue = randomIntBetween(0, 10);
builders.add(client().prepareIndex("test", "_doc").setSource(
jsonBuilder().startObject()
.field("val", fieldValue).field("event_type", "my_event").field("@timestamp", "2020-04-09T12:35:48Z")
.endObject()));
}
indexRandom(true, builders);
boolean cancelDuringSearch = randomBoolean();
List<SearchBlockPlugin> plugins = initBlockFactory(cancelDuringSearch, cancelDuringSearch == false);
EqlSearchRequest request = new EqlSearchRequest().indices("test").query("my_event where val=1").eventCategoryField("event_type");
String id = randomAlphaOfLength(10);
logger.trace("Preparing search");
// We might perform field caps on the same thread if it is local client, so we cannot use the standard mechanism
Future<EqlSearchResponse> future = executorService.submit(() ->
client().filterWithHeader(Collections.singletonMap(Task.X_OPAQUE_ID, id)).execute(EqlSearchAction.INSTANCE, request).get()
);
logger.trace("Waiting for block to be established");
if (cancelDuringSearch) {
awaitForBlockedSearches(plugins, "test");
} else {
awaitForBlockedFieldCaps(plugins);
}
logger.trace("Block is established");
ListTasksResponse tasks = client().admin().cluster().prepareListTasks().setActions(EqlSearchAction.NAME).get();
TaskId taskId = null;
for (TaskInfo task : tasks.getTasks()) {
if (id.equals(task.getHeaders().get(Task.X_OPAQUE_ID))) {
taskId = task.getTaskId();
break;
}
}
assertNotNull(taskId);
logger.trace("Cancelling task " + taskId);
CancelTasksResponse response = client().admin().cluster().prepareCancelTasks().setTaskId(taskId).get();
assertThat(response.getTasks(), hasSize(1));
assertThat(response.getTasks().get(0).getAction(), equalTo(EqlSearchAction.NAME));
logger.trace("Task is cancelled " + taskId);
disableBlocks(plugins);
Exception exception = expectThrows(Exception.class, future::get);
Throwable inner = ExceptionsHelper.unwrap(exception, SearchPhaseExecutionException.class);
if (cancelDuringSearch) {
// Make sure we cancelled inside search
assertNotNull(inner);
assertThat(inner, instanceOf(SearchPhaseExecutionException.class));
assertThat(inner.getCause(), instanceOf(TaskCancelledException.class));
} else {
// Make sure we were not cancelled inside search
assertNull(inner);
assertThat(getNumberOfContexts(plugins), equalTo(0));
Throwable cancellationException = ExceptionsHelper.unwrap(exception, TaskCancelledException.class);
assertNotNull(cancellationException);
}
}
private List<SearchBlockPlugin> initBlockFactory(boolean searchBlock, boolean fieldCapsBlock) {
List<SearchBlockPlugin> plugins = new ArrayList<>();
for (PluginsService pluginsService : internalCluster().getDataNodeInstances(PluginsService.class)) {
plugins.addAll(pluginsService.filterPlugins(SearchBlockPlugin.class));
}
for (SearchBlockPlugin plugin : plugins) {
plugin.reset();
if (searchBlock) {
plugin.enableSearchBlock();
}
if (fieldCapsBlock) {
plugin.enableFieldCapBlock();
}
}
return plugins;
}
private void disableBlocks(List<SearchBlockPlugin> plugins) {
for (SearchBlockPlugin plugin : plugins) {
plugin.disableSearchBlock();
plugin.disableFieldCapBlock();
}
}
private void awaitForBlockedSearches(List<SearchBlockPlugin> plugins, String index) throws Exception {
int numberOfShards = getNumShards(index).numPrimaries;
assertBusy(() -> {
int numberOfBlockedPlugins = getNumberOfContexts(plugins);
logger.trace("The plugin blocked on {} out of {} shards", numberOfBlockedPlugins, numberOfShards);
assertThat(numberOfBlockedPlugins, greaterThan(0));
});
}
private int getNumberOfContexts(List<SearchBlockPlugin> plugins) throws Exception {
int count = 0;
for (SearchBlockPlugin plugin : plugins) {
count += plugin.contexts.get();
}
return count;
}
private int getNumberOfFieldCaps(List<SearchBlockPlugin> plugins) throws Exception {
int count = 0;
for (SearchBlockPlugin plugin : plugins) {
count += plugin.fieldCaps.get();
}
return count;
}
private void awaitForBlockedFieldCaps(List<SearchBlockPlugin> plugins) throws Exception {
assertBusy(() -> {
int numberOfBlockedPlugins = getNumberOfFieldCaps(plugins);
logger.trace("The plugin blocked on {} nodes", numberOfBlockedPlugins);
assertThat(numberOfBlockedPlugins, greaterThan(0));
});
}
public static class SearchBlockPlugin extends LocalStateEQLXPackPlugin {
protected final Logger logger = LogManager.getLogger(getClass());
private final AtomicInteger contexts = new AtomicInteger();
private final AtomicInteger fieldCaps = new AtomicInteger();
private final AtomicBoolean shouldBlockOnSearch = new AtomicBoolean(false);
private final AtomicBoolean shouldBlockOnFieldCapabilities = new AtomicBoolean(false);
private final String nodeId;
public void reset() {
contexts.set(0);
fieldCaps.set(0);
}
public void disableSearchBlock() {
shouldBlockOnSearch.set(false);
}
public void enableSearchBlock() {
shouldBlockOnSearch.set(true);
}
public void disableFieldCapBlock() {
shouldBlockOnFieldCapabilities.set(false);
}
public void enableFieldCapBlock() {
shouldBlockOnFieldCapabilities.set(true);
}
public SearchBlockPlugin(Settings settings, Path configPath) throws Exception {
super(settings, configPath);
nodeId = settings.get("node.name");
}
@Override
public void onIndexModule(IndexModule indexModule) {
super.onIndexModule(indexModule);
indexModule.addSearchOperationListener(new SearchOperationListener() {
@Override
public void onNewContext(SearchContext context) {
contexts.incrementAndGet();
try {
logger.trace("blocking search on " + nodeId);
assertBusy(() -> assertFalse(shouldBlockOnSearch.get()));
logger.trace("unblocking search on " + nodeId);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
});
}
@Override
public List<ActionFilter> getActionFilters() {
List<ActionFilter> list = new ArrayList<>(super.getActionFilters());
list.add(new ActionFilter() {
@Override
public int order() {
return 0;
}
@Override
public <Request extends ActionRequest, Response extends ActionResponse> void apply(
Task task, String action, Request request, ActionListener<Response> listener,
ActionFilterChain<Request, Response> chain) {
if (action.equals(FieldCapabilitiesAction.NAME)) {
try {
fieldCaps.incrementAndGet();
logger.trace("blocking field caps on " + nodeId);
assertBusy(() -> assertFalse(shouldBlockOnFieldCapabilities.get()));
logger.trace("unblocking field caps on " + nodeId);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
chain.proceed(task, action, request, listener);
}
});
return list;
}
}
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singletonList(SearchBlockPlugin.class);
}
}

View File

@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.eql.action;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
import org.elasticsearch.xpack.eql.plugin.EqlPlugin;
import org.elasticsearch.xpack.ql.plugin.QlPlugin;
import java.nio.file.Path;
public class LocalStateEQLXPackPlugin extends LocalStateCompositeXPackPlugin {
public LocalStateEQLXPackPlugin(final Settings settings, final Path configPath) {
super(settings, configPath);
LocalStateEQLXPackPlugin thisVar = this;
plugins.add(new EqlPlugin(settings) {
@Override
protected XPackLicenseState getLicenseState() {
return thisVar.getLicenseState();
}
});
plugins.add(new QlPlugin());
}
}

View File

@ -8,9 +8,15 @@ package org.elasticsearch.xpack.eql.analysis;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.fieldcaps.FieldCapabilities; import org.elasticsearch.action.fieldcaps.FieldCapabilities;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.eql.action.EqlSearchRequest; import org.elasticsearch.xpack.eql.action.EqlSearchRequest;
import org.elasticsearch.xpack.eql.action.EqlSearchResponse; import org.elasticsearch.xpack.eql.action.EqlSearchResponse;
@ -19,6 +25,7 @@ import org.elasticsearch.xpack.eql.execution.PlanExecutor;
import org.elasticsearch.xpack.eql.plugin.TransportEqlSearchAction; import org.elasticsearch.xpack.eql.plugin.TransportEqlSearchAction;
import org.elasticsearch.xpack.ql.index.IndexResolver; import org.elasticsearch.xpack.ql.index.IndexResolver;
import org.elasticsearch.xpack.ql.type.DefaultDataTypeRegistry; import org.elasticsearch.xpack.ql.type.DefaultDataTypeRegistry;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import java.util.Collections; import java.util.Collections;
@ -42,13 +49,14 @@ public class CancellationTests extends ESTestCase {
public void testCancellationBeforeFieldCaps() throws InterruptedException { public void testCancellationBeforeFieldCaps() throws InterruptedException {
Client client = mock(Client.class); Client client = mock(Client.class);
when(client.settings()).thenReturn(Settings.EMPTY);
EqlSearchTask task = mock(EqlSearchTask.class); EqlSearchTask task = mock(EqlSearchTask.class);
when(task.isCancelled()).thenReturn(true); when(task.isCancelled()).thenReturn(true);
IndexResolver indexResolver = new IndexResolver(client, randomAlphaOfLength(10), DefaultDataTypeRegistry.INSTANCE); IndexResolver indexResolver = new IndexResolver(client, randomAlphaOfLength(10), DefaultDataTypeRegistry.INSTANCE);
PlanExecutor planExecutor = new PlanExecutor(client, indexResolver, new NamedWriteableRegistry(Collections.emptyList())); PlanExecutor planExecutor = new PlanExecutor(client, indexResolver, new NamedWriteableRegistry(Collections.emptyList()));
CountDownLatch countDownLatch = new CountDownLatch(1); CountDownLatch countDownLatch = new CountDownLatch(1);
TransportEqlSearchAction.operation(planExecutor, task, new EqlSearchRequest().query("foo where blah"), "", "", TransportEqlSearchAction.operation(planExecutor, task, new EqlSearchRequest().query("foo where blah"), "", "", "node_id",
new ActionListener<EqlSearchResponse>() { new ActionListener<EqlSearchResponse>() {
@Override @Override
public void onResponse(EqlSearchResponse eqlSearchResponse) { public void onResponse(EqlSearchResponse eqlSearchResponse) {
@ -64,18 +72,13 @@ public class CancellationTests extends ESTestCase {
}); });
countDownLatch.await(); countDownLatch.await();
verify(task, times(1)).isCancelled(); verify(task, times(1)).isCancelled();
verify(task, times(1)).getId();
verify(client, times(1)).settings();
verify(client, times(1)).threadPool();
verifyNoMoreInteractions(client, task); verifyNoMoreInteractions(client, task);
} }
public void testCancellationBeforeSearch() throws InterruptedException { private Map<String, Map<String, FieldCapabilities>> fields(String[] indices) {
Client client = mock(Client.class);
AtomicBoolean cancelled = new AtomicBoolean(false);
EqlSearchTask task = mock(EqlSearchTask.class);
when(task.isCancelled()).then(invocationOnMock -> cancelled.get());
String[] indices = new String[]{"endgame"};
FieldCapabilities fooField = FieldCapabilities fooField =
new FieldCapabilities("foo", "integer", true, true, indices, null, null, emptyMap()); new FieldCapabilities("foo", "integer", true, true, indices, null, null, emptyMap());
FieldCapabilities categoryField = FieldCapabilities categoryField =
@ -86,10 +89,24 @@ public class CancellationTests extends ESTestCase {
fields.put(fooField.getName(), singletonMap(fooField.getName(), fooField)); fields.put(fooField.getName(), singletonMap(fooField.getName(), fooField));
fields.put(categoryField.getName(), singletonMap(categoryField.getName(), categoryField)); fields.put(categoryField.getName(), singletonMap(categoryField.getName(), categoryField));
fields.put(timestampField.getName(), singletonMap(timestampField.getName(), timestampField)); fields.put(timestampField.getName(), singletonMap(timestampField.getName(), timestampField));
return fields;
}
public void testCancellationBeforeSearch() throws InterruptedException {
Client client = mock(Client.class);
when(client.settings()).thenReturn(Settings.EMPTY);
AtomicBoolean cancelled = new AtomicBoolean(false);
EqlSearchTask task = mock(EqlSearchTask.class);
String nodeId = randomAlphaOfLength(10);
long taskId = randomNonNegativeLong();
when(task.isCancelled()).then(invocationOnMock -> cancelled.get());
when(task.getId()).thenReturn(taskId);
String[] indices = new String[]{"endgame"};
FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class); FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class);
when(fieldCapabilitiesResponse.getIndices()).thenReturn(indices); when(fieldCapabilitiesResponse.getIndices()).thenReturn(indices);
when(fieldCapabilitiesResponse.get()).thenReturn(fields); when(fieldCapabilitiesResponse.get()).thenReturn(fields(indices));
doAnswer((Answer<Void>) invocation -> { doAnswer((Answer<Void>) invocation -> {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
ActionListener<FieldCapabilitiesResponse> listener = (ActionListener<FieldCapabilitiesResponse>) invocation.getArguments()[1]; ActionListener<FieldCapabilitiesResponse> listener = (ActionListener<FieldCapabilitiesResponse>) invocation.getArguments()[1];
@ -103,7 +120,7 @@ public class CancellationTests extends ESTestCase {
PlanExecutor planExecutor = new PlanExecutor(client, indexResolver, new NamedWriteableRegistry(Collections.emptyList())); PlanExecutor planExecutor = new PlanExecutor(client, indexResolver, new NamedWriteableRegistry(Collections.emptyList()));
CountDownLatch countDownLatch = new CountDownLatch(1); CountDownLatch countDownLatch = new CountDownLatch(1);
TransportEqlSearchAction.operation(planExecutor, task, new EqlSearchRequest().indices("endgame") TransportEqlSearchAction.operation(planExecutor, task, new EqlSearchRequest().indices("endgame")
.query("process where foo==3"), "", "", new ActionListener<EqlSearchResponse>() { .query("process where foo==3"), "", "", nodeId, new ActionListener<EqlSearchResponse>() {
@Override @Override
public void onResponse(EqlSearchResponse eqlSearchResponse) { public void onResponse(EqlSearchResponse eqlSearchResponse) {
fail("Shouldn't be here"); fail("Shouldn't be here");
@ -119,6 +136,75 @@ public class CancellationTests extends ESTestCase {
countDownLatch.await(); countDownLatch.await();
verify(client).fieldCaps(any(), any()); verify(client).fieldCaps(any(), any());
verify(task, times(2)).isCancelled(); verify(task, times(2)).isCancelled();
verify(task, times(1)).getId();
verify(client, times(1)).settings();
verify(client, times(1)).threadPool();
verifyNoMoreInteractions(client, task);
}
public void testCancellationDuringSearch() throws InterruptedException {
Client client = mock(Client.class);
when(client.settings()).thenReturn(Settings.EMPTY);
EqlSearchTask task = mock(EqlSearchTask.class);
String nodeId = randomAlphaOfLength(10);
long taskId = randomNonNegativeLong();
when(task.isCancelled()).thenReturn(false);
when(task.getId()).thenReturn(taskId);
String[] indices = new String[]{"endgame"};
// Emulation of field capabilities
FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class);
when(fieldCapabilitiesResponse.getIndices()).thenReturn(indices);
when(fieldCapabilitiesResponse.get()).thenReturn(fields(indices));
doAnswer((Answer<Void>) invocation -> {
@SuppressWarnings("unchecked")
ActionListener<FieldCapabilitiesResponse> listener = (ActionListener<FieldCapabilitiesResponse>) invocation.getArguments()[1];
listener.onResponse(fieldCapabilitiesResponse);
return null;
}).when(client).fieldCaps(any(), any());
// Emulation of search cancellation
ArgumentCaptor<SearchRequest> searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class);
when(client.prepareSearch(any())).thenReturn(new SearchRequestBuilder(client, SearchAction.INSTANCE).setIndices(indices));
doAnswer((Answer<Void>) invocation -> {
@SuppressWarnings("unchecked")
SearchRequest request = (SearchRequest) invocation.getArguments()[1];
TaskId parentTask = request.getParentTask();
assertNotNull(parentTask);
assertEquals(taskId, parentTask.getId());
assertEquals(nodeId, parentTask.getNodeId());
@SuppressWarnings("unchecked")
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocation.getArguments()[2];
listener.onFailure(new TaskCancelledException("cancelled"));
return null;
}).when(client).execute(any(), searchRequestCaptor.capture(), any());
IndexResolver indexResolver = new IndexResolver(client, randomAlphaOfLength(10), DefaultDataTypeRegistry.INSTANCE);
PlanExecutor planExecutor = new PlanExecutor(client, indexResolver, new NamedWriteableRegistry(Collections.emptyList()));
CountDownLatch countDownLatch = new CountDownLatch(1);
TransportEqlSearchAction.operation(planExecutor, task, new EqlSearchRequest().indices("endgame")
.query("process where foo==3"), "", "", nodeId, new ActionListener<EqlSearchResponse>() {
@Override
public void onResponse(EqlSearchResponse eqlSearchResponse) {
fail("Shouldn't be here");
countDownLatch.countDown();
}
@Override
public void onFailure(Exception e) {
assertThat(e, instanceOf(TaskCancelledException.class));
countDownLatch.countDown();
}
});
countDownLatch.await();
// Final verification to ensure no more interaction
verify(client).fieldCaps(any(), any());
verify(client).execute(any(), any(), any());
verify(task, times(2)).isCancelled();
verify(task, times(1)).getId();
verify(client, times(1)).settings();
verify(client, times(1)).threadPool();
verifyNoMoreInteractions(client, task); verifyNoMoreInteractions(client, task);
} }