Support hierarchical task cancellation (#54757)

With this change, when a task is canceled, the task manager will cancel
not only its direct child tasks but also all of its descendant tasks.

Closes #50990
This commit is contained in:
Nhat Nguyen 2020-04-06 12:00:02 -04:00
parent 51c6f69e02
commit 96bb1164f0
6 changed files with 289 additions and 261 deletions

View File

@ -234,7 +234,7 @@ nodes `nodeId1` and `nodeId2`.
`wait_for_completion`:: `wait_for_completion`::
(Optional, boolean) If `true`, the request blocks until the cancellation of the (Optional, boolean) If `true`, the request blocks until the cancellation of the
task and its child tasks is completed. Otherwise, the request can return soon task and its descendant tasks is completed. Otherwise, the request can return soon
after the cancellation is started. Defaults to `false`. after the cancellation is started. Defaults to `false`.
[source,console] [source,console]

View File

@ -42,7 +42,7 @@
}, },
"wait_for_completion": { "wait_for_completion": {
"type":"boolean", "type":"boolean",
"description":"Should the request block until the cancellation of the task and its child tasks is completed. Defaults to false" "description":"Should the request block until the cancellation of the task and its descendant tasks is completed. Defaults to false"
} }
} }
} }

View File

@ -79,7 +79,7 @@ public class CancelTasksRequest extends BaseTasksRequest<CancelTasksRequest> {
} }
/** /**
* If {@code true}, the request blocks until the cancellation of the task and its child tasks is completed. * If {@code true}, the request blocks until the cancellation of the task and its descendant tasks is completed.
* Otherwise, the request can return soon after the cancellation is started. Defaults to {@code false}. * Otherwise, the request can return soon after the cancellation is started. Defaults to {@code false}.
*/ */
public void setWaitForCompletion(boolean waitForCompletion) { public void setWaitForCompletion(boolean waitForCompletion) {

View File

@ -20,11 +20,13 @@
package org.elasticsearch.action.admin.cluster.node.tasks.cancel; package org.elasticsearch.action.admin.cluster.node.tasks.cancel;
import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.StepListener; import org.elasticsearch.action.StepListener;
import org.elasticsearch.action.TaskOperationFailure; import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.GroupedActionListener; import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
@ -104,34 +106,43 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
@Override @Override
protected void taskOperation(CancelTasksRequest request, CancellableTask cancellableTask, ActionListener<TaskInfo> listener) { protected void taskOperation(CancelTasksRequest request, CancellableTask cancellableTask, ActionListener<TaskInfo> listener) {
String nodeId = clusterService.localNode().getId(); String nodeId = clusterService.localNode().getId();
if (cancellableTask.shouldCancelChildrenOnCancellation()) { cancelTaskAndDescendants(cancellableTask, request.getReason(), request.waitForCompletion(),
ActionListener.map(listener, r -> cancellableTask.taskInfo(nodeId, false)));
}
void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
if (task.shouldCancelChildrenOnCancellation()) {
StepListener<Void> completedListener = new StepListener<>(); StepListener<Void> completedListener = new StepListener<>();
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(ActionListener.map(completedListener, r -> null), 3); GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(ActionListener.map(completedListener, r -> null), 3);
Collection<DiscoveryNode> childrenNodes = Collection<DiscoveryNode> childrenNodes =
taskManager.startBanOnChildrenNodes(cancellableTask.getId(), () -> groupedListener.onResponse(null)); taskManager.startBanOnChildrenNodes(task.getId(), () -> groupedListener.onResponse(null));
taskManager.cancel(cancellableTask, request.getReason(), () -> groupedListener.onResponse(null)); taskManager.cancel(task, reason, () -> groupedListener.onResponse(null));
StepListener<Void> banOnNodesListener = new StepListener<>(); StepListener<Void> banOnNodesListener = new StepListener<>();
setBanOnNodes(request.getReason(), cancellableTask, childrenNodes, banOnNodesListener); setBanOnNodes(reason, waitForCompletion, task, childrenNodes, banOnNodesListener);
banOnNodesListener.whenComplete(groupedListener::onResponse, groupedListener::onFailure); banOnNodesListener.whenComplete(groupedListener::onResponse, groupedListener::onFailure);
// We remove bans after all child tasks are completed although in theory we can do it on a per-node basis. // We remove bans after all child tasks are completed although in theory we can do it on a per-node basis.
completedListener.whenComplete( completedListener.whenComplete(r -> removeBanOnNodes(task, childrenNodes), e -> removeBanOnNodes(task, childrenNodes));
r -> removeBanOnNodes(cancellableTask, childrenNodes), // if wait_for_completion is true, then only return when (1) bans are placed on child nodes, (2) child tasks are
e -> removeBanOnNodes(cancellableTask, childrenNodes));
// if wait_for_child_tasks is true, then only return when (1) bans are placed on child nodes, (2) child tasks are
// completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child nodes. // completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child nodes.
if (request.waitForCompletion()) { if (waitForCompletion) {
completedListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure); completedListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
} else { } else {
banOnNodesListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure); banOnNodesListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
} }
} else { } else {
logger.trace("task {} doesn't have any children that should be cancelled", cancellableTask.getId()); logger.trace("task {} doesn't have any children that should be cancelled", task.getId());
taskManager.cancel(cancellableTask, request.getReason(), () -> listener.onResponse(cancellableTask.taskInfo(nodeId, false))); if (waitForCompletion) {
taskManager.cancel(task, reason, () -> listener.onResponse(null));
} else {
taskManager.cancel(task, reason, () -> {});
listener.onResponse(null);
}
} }
} }
private void setBanOnNodes(String reason, CancellableTask task, Collection<DiscoveryNode> childNodes, ActionListener<Void> listener) { private void setBanOnNodes(String reason, boolean waitForCompletion, CancellableTask task,
Collection<DiscoveryNode> childNodes, ActionListener<Void> listener) {
if (childNodes.isEmpty()) { if (childNodes.isEmpty()) {
listener.onResponse(null); listener.onResponse(null);
return; return;
@ -140,7 +151,7 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
GroupedActionListener<Void> groupedListener = GroupedActionListener<Void> groupedListener =
new GroupedActionListener<>(ActionListener.map(listener, r -> null), childNodes.size()); new GroupedActionListener<>(ActionListener.map(listener, r -> null), childNodes.size());
final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest( final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(
new TaskId(clusterService.localNode().getId(), task.getId()), reason); new TaskId(clusterService.localNode().getId(), task.getId()), reason, waitForCompletion);
for (DiscoveryNode node : childNodes) { for (DiscoveryNode node : childNodes) {
transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, banRequest, transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, banRequest,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) { new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@ -171,26 +182,29 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
private final TaskId parentTaskId; private final TaskId parentTaskId;
private final boolean ban; private final boolean ban;
private final boolean waitForCompletion;
private final String reason; private final String reason;
static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason) { static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
return new BanParentTaskRequest(parentTaskId, reason); return new BanParentTaskRequest(parentTaskId, reason, waitForCompletion);
} }
static BanParentTaskRequest createRemoveBanParentTaskRequest(TaskId parentTaskId) { static BanParentTaskRequest createRemoveBanParentTaskRequest(TaskId parentTaskId) {
return new BanParentTaskRequest(parentTaskId); return new BanParentTaskRequest(parentTaskId);
} }
private BanParentTaskRequest(TaskId parentTaskId, String reason) { private BanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
this.parentTaskId = parentTaskId; this.parentTaskId = parentTaskId;
this.ban = true; this.ban = true;
this.reason = reason; this.reason = reason;
this.waitForCompletion = waitForCompletion;
} }
private BanParentTaskRequest(TaskId parentTaskId) { private BanParentTaskRequest(TaskId parentTaskId) {
this.parentTaskId = parentTaskId; this.parentTaskId = parentTaskId;
this.ban = false; this.ban = false;
this.reason = null; this.reason = null;
this.waitForCompletion = false;
} }
private BanParentTaskRequest(StreamInput in) throws IOException { private BanParentTaskRequest(StreamInput in) throws IOException {
@ -198,6 +212,11 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
parentTaskId = TaskId.readFromStream(in); parentTaskId = TaskId.readFromStream(in);
ban = in.readBoolean(); ban = in.readBoolean();
reason = ban ? in.readString() : null; reason = ban ? in.readString() : null;
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
waitForCompletion = in.readBoolean();
} else {
waitForCompletion = false;
}
} }
@Override @Override
@ -208,6 +227,9 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
if (ban) { if (ban) {
out.writeString(reason); out.writeString(reason);
} }
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeBoolean(waitForCompletion);
}
} }
} }
@ -217,13 +239,20 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
if (request.ban) { if (request.ban) {
logger.debug("Received ban for the parent [{}] on the node [{}], reason: [{}]", request.parentTaskId, logger.debug("Received ban for the parent [{}] on the node [{}], reason: [{}]", request.parentTaskId,
clusterService.localNode().getId(), request.reason); clusterService.localNode().getId(), request.reason);
taskManager.setBan(request.parentTaskId, request.reason); final List<CancellableTask> childTasks = taskManager.setBan(request.parentTaskId, request.reason);
final GroupedActionListener<Void> listener = new GroupedActionListener<>(ActionListener.map(
new ChannelActionListener<>(channel, BAN_PARENT_ACTION_NAME, request), r -> TransportResponse.Empty.INSTANCE),
childTasks.size() + 1);
for (CancellableTask childTask : childTasks) {
cancelTaskAndDescendants(childTask, request.reason, request.waitForCompletion, listener);
}
listener.onResponse(null);
} else { } else {
logger.debug("Removing ban for the parent [{}] on the node [{}]", request.parentTaskId, logger.debug("Removing ban for the parent [{}] on the node [{}]", request.parentTaskId,
clusterService.localNode().getId()); clusterService.localNode().getId());
taskManager.removeBan(request.parentTaskId); taskManager.removeBan(request.parentTaskId);
channel.sendResponse(TransportResponse.Empty.INSTANCE);
} }
channel.sendResponse(TransportResponse.Empty.INSTANCE);
} }
} }

View File

@ -333,8 +333,9 @@ public class TaskManager implements ClusterStateApplier {
* Bans all tasks with the specified parent task from execution, cancels all tasks that are currently executing. * Bans all tasks with the specified parent task from execution, cancels all tasks that are currently executing.
* <p> * <p>
* This method is called when a parent task that has children is cancelled. * This method is called when a parent task that has children is cancelled.
* @return a list of pending cancellable child tasks
*/ */
public void setBan(TaskId parentTaskId, String reason) { public List<CancellableTask> setBan(TaskId parentTaskId, String reason) {
logger.trace("setting ban for the parent task {} {}", parentTaskId, reason); logger.trace("setting ban for the parent task {} {}", parentTaskId, reason);
// Set the ban first, so the newly created tasks cannot be registered // Set the ban first, so the newly created tasks cannot be registered
@ -344,14 +345,10 @@ public class TaskManager implements ClusterStateApplier {
banedParents.put(parentTaskId, reason); banedParents.put(parentTaskId, reason);
} }
} }
return cancellableTasks.values().stream()
// Now go through already running tasks and cancel them .filter(t -> t.hasParent(parentTaskId))
for (Map.Entry<Long, CancellableTaskHolder> taskEntry : cancellableTasks.entrySet()) { .map(t -> t.task)
CancellableTaskHolder holder = taskEntry.getValue(); .collect(Collectors.toList());
if (holder.hasParent(parentTaskId)) {
holder.cancel(reason);
}
}
} }
/** /**
@ -365,11 +362,8 @@ public class TaskManager implements ClusterStateApplier {
} }
// for testing // for testing
public boolean childTasksCancelledOrBanned(TaskId parentTaskId) { public Set<TaskId> getBannedTaskIds() {
if (banedParents.containsKey(parentTaskId)) { return Collections.unmodifiableSet(banedParents.keySet());
return true;
}
return cancellableTasks.values().stream().noneMatch(task -> task.hasParent(parentTaskId));
} }
/** /**

View File

@ -41,12 +41,14 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
@ -57,97 +59,147 @@ import org.junit.Before;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.StreamSupport; import java.util.stream.StreamSupport;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.either;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.instanceOf;
public class CancellableTasksIT extends ESIntegTestCase { public class CancellableTasksIT extends ESIntegTestCase {
static final Map<ChildRequest, CountDownLatch> arrivedLatches = ConcurrentCollections.newConcurrentMap();
static final Map<ChildRequest, CountDownLatch> beforeExecuteLatches = ConcurrentCollections.newConcurrentMap(); static int idGenerator = 0;
static final Map<ChildRequest, CountDownLatch> completedLatches = ConcurrentCollections.newConcurrentMap(); static final Map<TestRequest, CountDownLatch> beforeSendLatches = ConcurrentCollections.newConcurrentMap();
static final Map<TestRequest, CountDownLatch> arrivedLatches = ConcurrentCollections.newConcurrentMap();
static final Map<TestRequest, CountDownLatch> beforeExecuteLatches = ConcurrentCollections.newConcurrentMap();
static final Map<TestRequest, CountDownLatch> completedLatches = ConcurrentCollections.newConcurrentMap();
@Before @Before
public void resetTestStates() { public void resetTestStates() {
idGenerator = 0;
beforeSendLatches.clear();
arrivedLatches.clear(); arrivedLatches.clear();
beforeExecuteLatches.clear(); beforeExecuteLatches.clear();
completedLatches.clear(); completedLatches.clear();
} }
List<ChildRequest> setupChildRequests(Set<DiscoveryNode> nodes) { static TestRequest generateTestRequest(Set<DiscoveryNode> nodes, int level, int maxLevel) {
int numRequests = randomIntBetween(1, 30); List<TestRequest> subRequests = new ArrayList<>();
List<ChildRequest> childRequests = new ArrayList<>(); int lower = level == 0 ? 1 : 0;
for (int i = 0; i < numRequests; i++) { int upper = 10 / (level + 1);
ChildRequest req = new ChildRequest(i, randomFrom(nodes)); int numOfSubRequests = randomIntBetween(lower, upper);
childRequests.add(req); for (int i = 0; i < numOfSubRequests && level <= maxLevel; i++) {
arrivedLatches.put(req, new CountDownLatch(1)); subRequests.add(generateTestRequest(nodes, level + 1, maxLevel));
beforeExecuteLatches.put(req, new CountDownLatch(1));
completedLatches.put(req, new CountDownLatch(1));
} }
return childRequests; final TestRequest request = new TestRequest(idGenerator++, randomFrom(nodes), subRequests);
beforeSendLatches.put(request, new CountDownLatch(1));
arrivedLatches.put(request, new CountDownLatch(1));
beforeExecuteLatches.put(request, new CountDownLatch(1));
completedLatches.put(request, new CountDownLatch(1));
return request;
} }
public void testBanOnlyNodesWithOutstandingChildTasks() throws Exception { static void randomDescendants(TestRequest request, Set<TestRequest> result) {
if (randomBoolean()) {
result.add(request);
for (TestRequest subReq : request.subRequests) {
randomDescendants(subReq, result);
}
}
}
/**
* Allow some parts of the request to be completed
* @return the pending requests, i.e. those that were sent but not allowed to complete
*/
static Set<TestRequest> allowPartialRequest(TestRequest request) throws Exception {
final Set<TestRequest> sentRequests = new HashSet<>();
while (sentRequests.isEmpty()) {
for (TestRequest subRequest : request.subRequests) {
randomDescendants(subRequest, sentRequests);
}
}
for (TestRequest req : sentRequests) {
beforeSendLatches.get(req).countDown();
}
for (TestRequest req : sentRequests) {
arrivedLatches.get(req).await();
}
Set<TestRequest> completedRequests = new HashSet<>();
for (TestRequest req : randomSubsetOf(sentRequests)) {
if (sentRequests.containsAll(req.descendants())) {
completedRequests.add(req);
completedRequests.addAll(req.descendants());
}
}
for (TestRequest req : completedRequests) {
beforeExecuteLatches.get(req).countDown();
}
for (TestRequest req : completedRequests) {
completedLatches.get(req).await();
}
return Sets.difference(sentRequests, completedRequests);
}
static void allowEntireRequest(TestRequest request) {
beforeSendLatches.get(request).countDown();
beforeExecuteLatches.get(request).countDown();
for (TestRequest subReq : request.subRequests) {
allowEntireRequest(subReq);
}
}
public void testBanOnlyNodesWithOutstandingDescendantTasks() throws Exception {
if (randomBoolean()) { if (randomBoolean()) {
internalCluster().startNodes(randomIntBetween(1, 3)); internalCluster().startNodes(randomIntBetween(1, 3));
} }
Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet()); Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
List<ChildRequest> childRequests = setupChildRequests(nodes); final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4));
ActionFuture<MainResponse> mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests)); ActionFuture<TestResponse> rootTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest);
List<ChildRequest> completedRequests = randomSubsetOf(between(0, childRequests.size() - 1), childRequests); Set<TestRequest> pendingRequests = allowPartialRequest(rootRequest);
for (ChildRequest req : completedRequests) { TaskId rootTaskId = getRootTaskId(rootRequest);
beforeExecuteLatches.get(req).countDown(); ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks()
completedLatches.get(req).await(); .setTaskId(rootTaskId).waitForCompletion(true).execute();
if (randomBoolean()) {
List<TaskInfo> runningTasks = client().admin().cluster().prepareListTasks()
.setActions(TransportTestAction.ACTION.name()).setDetailed(true).get().getTasks();
for (TaskInfo subTask : randomSubsetOf(runningTasks)) {
client().admin().cluster().prepareCancelTasks().setTaskId(subTask.getTaskId()).waitForCompletion(false).get();
}
} }
List<ChildRequest> outstandingRequests = childRequests.stream().
filter(r -> completedRequests.contains(r) == false)
.collect(Collectors.toList());
for (ChildRequest req : outstandingRequests) {
arrivedLatches.get(req).await();
}
TaskId taskId = getMainTaskId();
ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId)
.waitForCompletion(true).execute();
Set<DiscoveryNode> nodesWithOutstandingChildTask = outstandingRequests.stream().map(r -> r.targetNode).collect(Collectors.toSet());
assertBusy(() -> { assertBusy(() -> {
for (DiscoveryNode node : nodes) { for (DiscoveryNode node : nodes) {
TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager(); TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager();
if (nodesWithOutstandingChildTask.contains(node)) { Set<TaskId> expectedBans = new HashSet<>();
assertThat(taskManager.getBanCount(), equalTo(1)); for (TestRequest req : pendingRequests) {
} else { if (req.node.equals(node)) {
assertThat(taskManager.getBanCount(), equalTo(0)); List<Task> childTasks = taskManager.getTasks().values().stream()
.filter(t -> t.getParentTaskId() != null && t.getDescription().equals(req.taskDescription()))
.collect(Collectors.toList());
assertThat(childTasks, hasSize(1));
CancellableTask childTask = (CancellableTask) childTasks.get(0);
assertTrue(childTask.isCancelled());
expectedBans.add(childTask.getParentTaskId());
}
} }
assertThat(taskManager.getBannedTaskIds(), equalTo(expectedBans));
} }
}); });
// failed to spawn child tasks after cancelling allowEntireRequest(rootRequest);
if (randomBoolean()) {
DiscoveryNode nodeWithParentTask = nodes.stream().filter(n -> n.getId().equals(taskId.getNodeId())).findFirst().get();
TransportMainAction mainAction = internalCluster().getInstance(TransportMainAction.class, nodeWithParentTask.getName());
PlainActionFuture<ChildResponse> future = new PlainActionFuture<>();
ChildRequest req = new ChildRequest(-1, randomFrom(nodes));
completedLatches.put(req, new CountDownLatch(1));
mainAction.startChildTask(taskId, req, future);
TransportException te = expectThrows(TransportException.class, future::actionGet);
assertThat(te.getCause(), instanceOf(TaskCancelledException.class));
assertThat(te.getCause().getMessage(), equalTo("The parent task was cancelled, shouldn't start any child tasks"));
}
for (ChildRequest req : outstandingRequests) {
beforeExecuteLatches.get(req).countDown();
}
cancelFuture.actionGet(); cancelFuture.actionGet();
waitForMainTask(mainTaskFuture); waitForRootTask(rootTaskFuture);
assertBusy(() -> { assertBusy(() -> {
for (DiscoveryNode node : nodes) { for (DiscoveryNode node : nodes) {
TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager(); TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager();
@ -158,27 +210,20 @@ public class CancellableTasksIT extends ESIntegTestCase {
public void testCancelTaskMultipleTimes() throws Exception { public void testCancelTaskMultipleTimes() throws Exception {
Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet()); Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
List<ChildRequest> childRequests = setupChildRequests(nodes); TestRequest rootRequest = generateTestRequest(nodes, 0, randomIntBetween(1, 3));
ActionFuture<MainResponse> mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests)); ActionFuture<TestResponse> mainTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest);
for (ChildRequest r : randomSubsetOf(between(1, childRequests.size()), childRequests)) { TaskId taskId = getRootTaskId(rootRequest);
arrivedLatches.get(r).await(); allowPartialRequest(rootRequest);
} CancelTasksResponse resp = client().admin().cluster().prepareCancelTasks().setTaskId(taskId).waitForCompletion(false).get();
TaskId taskId = getMainTaskId(); assertThat(resp.getTaskFailures(), empty());
assertThat(resp.getNodeFailures(), empty());
ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId) ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId)
.waitForCompletion(true).execute(); .waitForCompletion(true).execute();
ensureChildTasksCancelledOrBanned(taskId);
if (randomBoolean()) {
CancelTasksResponse resp = client().admin().cluster().prepareCancelTasks().setTaskId(taskId).waitForCompletion(false).get();
assertThat(resp.getTaskFailures(), empty());
assertThat(resp.getNodeFailures(), empty());
}
assertFalse(cancelFuture.isDone()); assertFalse(cancelFuture.isDone());
for (ChildRequest r : childRequests) { allowEntireRequest(rootRequest);
beforeExecuteLatches.get(r).countDown();
}
assertThat(cancelFuture.actionGet().getTaskFailures(), empty()); assertThat(cancelFuture.actionGet().getTaskFailures(), empty());
assertThat(cancelFuture.actionGet().getTaskFailures(), empty()); assertThat(cancelFuture.actionGet().getTaskFailures(), empty());
waitForMainTask(mainTaskFuture); waitForRootTask(mainTaskFuture);
CancelTasksResponse cancelError = client().admin().cluster().prepareCancelTasks() CancelTasksResponse cancelError = client().admin().cluster().prepareCancelTasks()
.setTaskId(taskId).waitForCompletion(randomBoolean()).get(); .setTaskId(taskId).waitForCompletion(randomBoolean()).get();
assertThat(cancelError.getNodeFailures(), hasSize(1)); assertThat(cancelError.getNodeFailures(), hasSize(1));
@ -188,12 +233,12 @@ public class CancellableTasksIT extends ESIntegTestCase {
public void testDoNotWaitForCompletion() throws Exception { public void testDoNotWaitForCompletion() throws Exception {
Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet()); Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
List<ChildRequest> childRequests = setupChildRequests(nodes); TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3));
ActionFuture<MainResponse> mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests)); ActionFuture<TestResponse> mainTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest);
for (ChildRequest r : randomSubsetOf(between(1, childRequests.size()), childRequests)) { TaskId taskId = getRootTaskId(rootRequest);
arrivedLatches.get(r).await(); if (randomBoolean()) {
allowPartialRequest(rootRequest);
} }
TaskId taskId = getMainTaskId();
boolean waitForCompletion = randomBoolean(); boolean waitForCompletion = randomBoolean();
ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId) ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId)
.waitForCompletion(waitForCompletion).execute(); .waitForCompletion(waitForCompletion).execute();
@ -202,40 +247,76 @@ public class CancellableTasksIT extends ESIntegTestCase {
} else { } else {
assertBusy(() -> assertTrue(cancelFuture.isDone())); assertBusy(() -> assertTrue(cancelFuture.isDone()));
} }
for (ChildRequest r : childRequests) { allowEntireRequest(rootRequest);
beforeExecuteLatches.get(r).countDown(); waitForRootTask(mainTaskFuture);
}
waitForMainTask(mainTaskFuture);
} }
TaskId getMainTaskId() { public void testFailedToStartChildTaskAfterCancelled() {
Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3));
ActionFuture<TestResponse> rootTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest);
TaskId taskId = getRootTaskId(rootRequest);
client().admin().cluster().prepareCancelTasks().setTaskId(taskId).waitForCompletion(false).get();
DiscoveryNode nodeWithParentTask = nodes.stream().filter(n -> n.getId().equals(taskId.getNodeId())).findFirst().get();
TransportTestAction mainAction = internalCluster().getInstance(TransportTestAction.class, nodeWithParentTask.getName());
PlainActionFuture<TestResponse> future = new PlainActionFuture<>();
TestRequest subRequest = generateTestRequest(nodes, 0, between(0, 1));
beforeSendLatches.get(subRequest).countDown();
mainAction.startSubTask(taskId, subRequest, future);
TransportException te = expectThrows(TransportException.class, future::actionGet);
assertThat(te.getCause(), instanceOf(TaskCancelledException.class));
assertThat(te.getCause().getMessage(), equalTo("The parent task was cancelled, shouldn't start any child tasks"));
allowEntireRequest(rootRequest);
waitForRootTask(rootTaskFuture);
}
/**
 * Looks up the id of the running task that belongs to the given request.
 * Lists the tasks of the test action with detailed output, keeps those whose description equals
 * {@code request.taskDescription()}, asserts that exactly one remains and returns its task id.
 */
static TaskId getRootTaskId(TestRequest request) {
ListTasksResponse listTasksResponse = client().admin().cluster().prepareListTasks() ListTasksResponse listTasksResponse = client().admin().cluster().prepareListTasks()
.setActions(TransportMainAction.ACTION.name()).setDetailed(true).get(); .setActions(TransportTestAction.ACTION.name()).setDetailed(true).get();
assertThat(listTasksResponse.getTasks(), hasSize(1)); List<TaskInfo> tasks = listTasksResponse.getTasks().stream()
return listTasksResponse.getTasks().get(0).getTaskId(); .filter(t -> t.getDescription().equals(request.taskDescription()))
.collect(Collectors.toList());
// The description filter must leave exactly one matching task.
assertThat(tasks, hasSize(1));
return tasks.get(0).getTaskId();
} }
/**
 * Blocks until the given task future completes. If it fails, the failure is unwrapped to a
 * {@code TaskCancelledException} and its message must match one of the expected cancellation messages.
 */
void waitForMainTask(ActionFuture<MainResponse> mainTask) { static void waitForRootTask(ActionFuture<TestResponse> rootTask) {
try { try {
mainTask.actionGet(); rootTask.actionGet();
} catch (Exception e) { } catch (Exception e) {
// NOTE(review): unwrap presumably returns null when no TaskCancelledException is in the cause chain,
// which would make cause.getMessage() throw a NullPointerException instead of a readable assertion failure — confirm.
final Throwable cause = ExceptionsHelper.unwrap(e, TaskCancelledException.class); final Throwable cause = ExceptionsHelper.unwrap(e, TaskCancelledException.class);
assertThat(cause.getMessage(), assertThat(cause.getMessage(), anyOf(
either(equalTo("The parent task was cancelled, shouldn't start any child tasks")) equalTo("The parent task was cancelled, shouldn't start any child tasks"),
.or(containsString("Task cancelled before it started:"))); containsString("Task cancelled before it started:"),
equalTo("Task was cancelled while executing")));
} }
} }
public static class MainRequest extends ActionRequest { static class TestRequest extends ActionRequest {
final List<ChildRequest> childRequests; final int id;
final DiscoveryNode node;
final List<TestRequest> subRequests;
public MainRequest(List<ChildRequest> childRequests) { TestRequest(int id, DiscoveryNode node, List<TestRequest> subRequests) {
this.childRequests = childRequests; this.id = id;
this.node = node;
this.subRequests = subRequests;
} }
public MainRequest(StreamInput in) throws IOException { TestRequest(StreamInput in) throws IOException {
super(in); super(in);
this.childRequests = in.readList(ChildRequest::new); this.id = in.readInt();
this.node = new DiscoveryNode(in);
this.subRequests = in.readList(TestRequest::new);
}
List<TestRequest> descendants() {
List<TestRequest> descendants = new ArrayList<>();
for (TestRequest subRequest : subRequests) {
descendants.add(subRequest);
descendants.addAll(subRequest.descendants());
}
return descendants;
} }
@Override @Override
@ -243,104 +324,53 @@ public class CancellableTasksIT extends ESIntegTestCase {
return null; return null;
} }
/**
 * Serializes this request: the superclass state first, then the list of child requests.
 *
 * @param out stream to write to
 * @throws IOException if propagated by the underlying stream calls
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
// Counterpart of in.readList(ChildRequest::new) in the StreamInput constructor.
out.writeList(childRequests);
}
/**
 * Creates the task that represents this request: a {@link CancellableTask} described by
 * {@link #getDescription()} whose {@code shouldCancelChildrenOnCancellation()} always answers {@code true}.
 */
@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
    final String description = getDescription();
    final CancellableTask mainTask = new CancellableTask(id, type, action, description, parentTaskId, headers) {
        @Override
        public boolean shouldCancelChildrenOnCancellation() {
            return true;
        }
    };
    return mainTask;
}
}
/**
 * Response of the main test action. It declares no fields of its own and writes nothing to the stream.
 */
public static class MainResponse extends ActionResponse {

    /** Creates an empty response. */
    public MainResponse() {
        super();
    }

    /**
     * Reads a response from the stream; only the superclass constructor consumes input.
     *
     * @param in stream to read from
     * @throws IOException if propagated by the superclass constructor
     */
    public MainResponse(StreamInput in) throws IOException {
        super(in);
    }

    /** Writes nothing: the response carries no data. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // intentionally empty
    }
}
public static class ChildRequest extends ActionRequest {
final int id;
final DiscoveryNode targetNode;
public ChildRequest(int id, DiscoveryNode targetNode) {
this.id = id;
this.targetNode = targetNode;
}
public ChildRequest(StreamInput in) throws IOException {
super(in);
this.id = in.readInt();
this.targetNode = new DiscoveryNode(in);
}
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out); super.writeTo(out);
out.writeInt(id); out.writeInt(id);
targetNode.writeTo(out); node.writeTo(out);
} out.writeList(subRequests);
/**
 * Performs no checks on this request.
 *
 * @return always {@code null}
 */
@Override
public ActionRequestValidationException validate() {
return null;
}
@Override @Override
public String getDescription() { public String getDescription() {
return "childTask[" + id + "]"; return taskDescription();
}
String taskDescription() {
return "id=" + id;
} }
@Override @Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) { public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
if (randomBoolean()) { return new CancellableTask(id, type, action, taskDescription(), parentTaskId, headers) {
boolean shouldCancelChildrenOnCancellation = randomBoolean(); @Override
return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) { public boolean shouldCancelChildrenOnCancellation() {
@Override return true;
public boolean shouldCancelChildrenOnCancellation() { }
return shouldCancelChildrenOnCancellation; };
}
};
} else {
return super.createTask(id, type, action, parentTaskId, headers);
}
} }
@Override @Override
public boolean equals(Object o) { public boolean equals(Object o) {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
ChildRequest that = (ChildRequest) o; TestRequest that = (TestRequest) o;
return id == that.id && targetNode.equals(that.targetNode); return id == that.id;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(id, targetNode); return Objects.hash(id);
} }
} }
public static class ChildResponse extends ActionResponse { public static class TestResponse extends ActionResponse {
public ChildResponse() { public TestResponse() {
} }
public ChildResponse(StreamInput in) throws IOException { public TestResponse(StreamInput in) throws IOException {
super(in); super(in);
} }
@ -350,33 +380,44 @@ public class CancellableTasksIT extends ESIntegTestCase {
} }
} }
public static class TransportMainAction extends HandledTransportAction<MainRequest, MainResponse> { public static class TransportTestAction extends HandledTransportAction<TestRequest, TestResponse> {
public static ActionType<MainResponse> ACTION = new ActionType<>("internal::main_action", MainResponse::new); static AtomicInteger counter = new AtomicInteger();
public static ActionType<TestResponse> ACTION = new ActionType<>("internal::test_action", TestResponse::new);
private final TransportService transportService; private final TransportService transportService;
private final NodeClient client; private final NodeClient client;
@Inject @Inject
public TransportMainAction(TransportService transportService, NodeClient client, ActionFilters actionFilters) { public TransportTestAction(TransportService transportService, NodeClient client, ActionFilters actionFilters) {
super(ACTION.name(), transportService, actionFilters, MainRequest::new, ThreadPool.Names.GENERIC); super(ACTION.name(), transportService, actionFilters, TestRequest::new, ThreadPool.Names.GENERIC);
this.transportService = transportService; this.transportService = transportService;
this.client = client; this.client = client;
} }
@Override @Override
protected void doExecute(Task task, MainRequest request, ActionListener<MainResponse> listener) { protected void doExecute(Task task, TestRequest request, ActionListener<TestResponse> listener) {
GroupedActionListener<ChildResponse> groupedListener = arrivedLatches.get(request).countDown();
new GroupedActionListener<>(ActionListener.map(listener, r -> new MainResponse()), request.childRequests.size()); List<TestRequest> subRequests = request.subRequests;
for (ChildRequest childRequest : request.childRequests) { GroupedActionListener<TestResponse> groupedListener =
new GroupedActionListener<>(ActionListener.map(listener, r -> new TestResponse()), subRequests.size() + 1);
transportService.getThreadPool().generic().execute(ActionRunnable.supply(groupedListener, () -> {
beforeExecuteLatches.get(request).await();
if (((CancellableTask) task).isCancelled()) {
throw new TaskCancelledException("Task was cancelled while executing");
}
counter.incrementAndGet();
return new TestResponse();
}));
for (TestRequest subRequest : subRequests) {
TaskId parentTaskId = new TaskId(client.getLocalNodeId(), task.getId()); TaskId parentTaskId = new TaskId(client.getLocalNodeId(), task.getId());
startChildTask(parentTaskId, childRequest, groupedListener); startSubTask(parentTaskId, subRequest, groupedListener);
} }
} }
protected void startChildTask(TaskId parentTaskId, ChildRequest childRequest, ActionListener<ChildResponse> listener) { protected void startSubTask(TaskId parentTaskId, TestRequest subRequest, ActionListener<TestResponse> listener) {
childRequest.setParentTask(parentTaskId); subRequest.setParentTask(parentTaskId);
final CountDownLatch completeLatch = completedLatches.get(childRequest); CountDownLatch completeLatch = completedLatches.get(subRequest);
LatchedActionListener<ChildResponse> latchedListener = new LatchedActionListener<>(listener, completeLatch); LatchedActionListener<TestResponse> latchedListener = new LatchedActionListener<>(listener, completeLatch);
transportService.getThreadPool().generic().execute(new AbstractRunnable() { transportService.getThreadPool().generic().execute(new AbstractRunnable() {
@Override @Override
public void onFailure(Exception e) { public void onFailure(Exception e) {
@ -384,20 +425,20 @@ public class CancellableTasksIT extends ESIntegTestCase {
} }
@Override @Override
protected void doRun() { protected void doRun() throws Exception {
if (client.getLocalNodeId().equals(childRequest.targetNode.getId()) && randomBoolean()) { beforeSendLatches.get(subRequest).await();
if (client.getLocalNodeId().equals(subRequest.node.getId()) && randomBoolean()) {
try { try {
client.executeLocally(TransportChildAction.ACTION, childRequest, latchedListener); client.executeLocally(TransportTestAction.ACTION, subRequest, latchedListener);
} catch (TaskCancelledException e) { } catch (TaskCancelledException e) {
latchedListener.onFailure(new TransportException(e)); latchedListener.onFailure(new TransportException(e));
} }
} else { } else {
transportService.sendRequest(childRequest.targetNode, TransportChildAction.ACTION.name(), childRequest, transportService.sendRequest(subRequest.node, ACTION.name(), subRequest,
new TransportResponseHandler<ChildResponse>() { new TransportResponseHandler<TestResponse>() {
@Override @Override
public void handleResponse(ChildResponse response) { public void handleResponse(TestResponse response) {
latchedListener.onResponse(new ChildResponse()); latchedListener.onResponse(response);
} }
@Override @Override
@ -411,8 +452,8 @@ public class CancellableTasksIT extends ESIntegTestCase {
} }
@Override @Override
public ChildResponse read(StreamInput in) throws IOException { public TestResponse read(StreamInput in) throws IOException {
return new ChildResponse(in); return new TestResponse(in);
} }
}); });
} }
@ -421,40 +462,16 @@ public class CancellableTasksIT extends ESIntegTestCase {
} }
} }
/**
 * Transport action that handles a single {@code ChildRequest}.
 * It asserts that the request arrived on its target node, signals arrival through {@code arrivedLatches},
 * and answers only after the request's {@code beforeExecuteLatches} entry has been released.
 */
public static class TransportChildAction extends HandledTransportAction<ChildRequest, ChildResponse> {

    public static ActionType<ChildResponse> ACTION = new ActionType<>("internal:child_action", ChildResponse::new);

    private final TransportService transportService;

    @Inject
    public TransportChildAction(TransportService transportService, ActionFilters actionFilters) {
        super(ACTION.name(), transportService, actionFilters, ChildRequest::new, ThreadPool.Names.GENERIC);
        this.transportService = transportService;
    }

    @Override
    protected void doExecute(Task task, ChildRequest request, ActionListener<ChildResponse> listener) {
        // The request must be handled by the node it was addressed to.
        final DiscoveryNode localNode = transportService.getLocalNode();
        assertThat(request.targetNode, equalTo(localNode));

        // Tell the test that this child request has reached its action.
        arrivedLatches.get(request).countDown();

        // Wait on a generic-pool thread until the test releases the latch, then respond.
        final Runnable waitThenRespond = ActionRunnable.supply(listener, () -> {
            beforeExecuteLatches.get(request).await();
            return new ChildResponse();
        });
        transportService.getThreadPool().executor(ThreadPool.Names.GENERIC).execute(waitThenRespond);
    }
}
/**
 * Test plugin that registers the transport action(s) of this test with the node
 * and exposes the same action type(s) as client actions.
 */
public static class TaskPlugin extends Plugin implements ActionPlugin { public static class TaskPlugin extends Plugin implements ActionPlugin {
// Binds each action type to the transport action class that handles it.
@Override @Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() { public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return Arrays.asList( return Collections.singletonList(new ActionHandler<>(TransportTestAction.ACTION, TransportTestAction.class));
new ActionHandler<>(TransportMainAction.ACTION, TransportMainAction.class),
new ActionHandler<>(TransportChildAction.ACTION, TransportChildAction.class)
);
} }
// Returns the action types registered above, without their handler classes.
@Override @Override
public List<ActionType<? extends ActionResponse>> getClientActions() { public List<ActionType<? extends ActionResponse>> getClientActions() {
return Arrays.asList(TransportMainAction.ACTION, TransportChildAction.ACTION); return Collections.singletonList(TransportTestAction.ACTION);
} }
} }
@ -471,16 +488,4 @@ public class CancellableTasksIT extends ESIntegTestCase {
plugins.add(TaskPlugin.class); plugins.add(TaskPlugin.class);
return plugins; return plugins;
} }
/**
 * Uses {@code assertBusy} to check, for every node name of the internal test cluster, that the node's
 * task manager reports the child tasks of the given parent task as cancelled or banned.
 *
 * @param taskId id of the parent task whose child tasks are checked
 * @throws Exception if propagated by {@code assertBusy}
 */
protected static void ensureChildTasksCancelledOrBanned(TaskId taskId) throws Exception {
    assertBusy(() -> {
        for (String node : internalCluster().getNodeNames()) {
            final TransportService nodeTransport = internalCluster().getInstance(TransportService.class, node);
            assertTrue(nodeTransport.getTaskManager().childTasksCancelledOrBanned(taskId));
        }
    });
}
} }