Support hierarchical task cancellation (#54757)

With this change, when a task is canceled, the task manager will cancel
not only its direct child tasks but also all of its descendant tasks.

Closes #50990
This commit is contained in:
Nhat Nguyen 2020-04-06 12:00:02 -04:00
parent 51c6f69e02
commit 96bb1164f0
6 changed files with 289 additions and 261 deletions

View File

@ -234,7 +234,7 @@ nodes `nodeId1` and `nodeId2`.
`wait_for_completion`::
(Optional, boolean) If `true`, the request blocks until the cancellation of the
task and its child tasks is completed. Otherwise, the request can return soon
task and its descendant tasks is completed. Otherwise, the request can return soon
after the cancellation is started. Defaults to `false`.
[source,console]

View File

@ -42,7 +42,7 @@
},
"wait_for_completion": {
"type":"boolean",
"description":"Should the request block until the cancellation of the task and its child tasks is completed. Defaults to false"
"description":"Should the request block until the cancellation of the task and its descendant tasks is completed. Defaults to false"
}
}
}

View File

@ -79,7 +79,7 @@ public class CancelTasksRequest extends BaseTasksRequest<CancelTasksRequest> {
}
/**
* If {@code true}, the request blocks until the cancellation of the task and its child tasks is completed.
* If {@code true}, the request blocks until the cancellation of the task and its descendant tasks is completed.
* Otherwise, the request can return soon after the cancellation is started. Defaults to {@code false}.
*/
public void setWaitForCompletion(boolean waitForCompletion) {

View File

@ -20,11 +20,13 @@
package org.elasticsearch.action.admin.cluster.node.tasks.cancel;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.StepListener;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.node.DiscoveryNode;
@ -104,34 +106,43 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
@Override
protected void taskOperation(CancelTasksRequest request, CancellableTask cancellableTask, ActionListener<TaskInfo> listener) {
String nodeId = clusterService.localNode().getId();
if (cancellableTask.shouldCancelChildrenOnCancellation()) {
cancelTaskAndDescendants(cancellableTask, request.getReason(), request.waitForCompletion(),
ActionListener.map(listener, r -> cancellableTask.taskInfo(nodeId, false)));
}
void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
if (task.shouldCancelChildrenOnCancellation()) {
StepListener<Void> completedListener = new StepListener<>();
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(ActionListener.map(completedListener, r -> null), 3);
Collection<DiscoveryNode> childrenNodes =
taskManager.startBanOnChildrenNodes(cancellableTask.getId(), () -> groupedListener.onResponse(null));
taskManager.cancel(cancellableTask, request.getReason(), () -> groupedListener.onResponse(null));
taskManager.startBanOnChildrenNodes(task.getId(), () -> groupedListener.onResponse(null));
taskManager.cancel(task, reason, () -> groupedListener.onResponse(null));
StepListener<Void> banOnNodesListener = new StepListener<>();
setBanOnNodes(request.getReason(), cancellableTask, childrenNodes, banOnNodesListener);
setBanOnNodes(reason, waitForCompletion, task, childrenNodes, banOnNodesListener);
banOnNodesListener.whenComplete(groupedListener::onResponse, groupedListener::onFailure);
// We remove bans after all child tasks are completed although in theory we can do it on a per-node basis.
completedListener.whenComplete(
r -> removeBanOnNodes(cancellableTask, childrenNodes),
e -> removeBanOnNodes(cancellableTask, childrenNodes));
// if wait_for_child_tasks is true, then only return when (1) bans are placed on child nodes, (2) child tasks are
completedListener.whenComplete(r -> removeBanOnNodes(task, childrenNodes), e -> removeBanOnNodes(task, childrenNodes));
// if wait_for_completion is true, then only return when (1) bans are placed on child nodes, (2) child tasks are
// completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child nodes.
if (request.waitForCompletion()) {
completedListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure);
if (waitForCompletion) {
completedListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
} else {
banOnNodesListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure);
banOnNodesListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
}
} else {
logger.trace("task {} doesn't have any children that should be cancelled", cancellableTask.getId());
taskManager.cancel(cancellableTask, request.getReason(), () -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)));
logger.trace("task {} doesn't have any children that should be cancelled", task.getId());
if (waitForCompletion) {
taskManager.cancel(task, reason, () -> listener.onResponse(null));
} else {
taskManager.cancel(task, reason, () -> {});
listener.onResponse(null);
}
}
}
private void setBanOnNodes(String reason, CancellableTask task, Collection<DiscoveryNode> childNodes, ActionListener<Void> listener) {
private void setBanOnNodes(String reason, boolean waitForCompletion, CancellableTask task,
Collection<DiscoveryNode> childNodes, ActionListener<Void> listener) {
if (childNodes.isEmpty()) {
listener.onResponse(null);
return;
@ -140,7 +151,7 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
GroupedActionListener<Void> groupedListener =
new GroupedActionListener<>(ActionListener.map(listener, r -> null), childNodes.size());
final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(
new TaskId(clusterService.localNode().getId(), task.getId()), reason);
new TaskId(clusterService.localNode().getId(), task.getId()), reason, waitForCompletion);
for (DiscoveryNode node : childNodes) {
transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, banRequest,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@ -171,26 +182,29 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
private final TaskId parentTaskId;
private final boolean ban;
private final boolean waitForCompletion;
private final String reason;
static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason) {
return new BanParentTaskRequest(parentTaskId, reason);
static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
return new BanParentTaskRequest(parentTaskId, reason, waitForCompletion);
}
static BanParentTaskRequest createRemoveBanParentTaskRequest(TaskId parentTaskId) {
return new BanParentTaskRequest(parentTaskId);
}
private BanParentTaskRequest(TaskId parentTaskId, String reason) {
private BanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
this.parentTaskId = parentTaskId;
this.ban = true;
this.reason = reason;
this.waitForCompletion = waitForCompletion;
}
private BanParentTaskRequest(TaskId parentTaskId) {
this.parentTaskId = parentTaskId;
this.ban = false;
this.reason = null;
this.waitForCompletion = false;
}
private BanParentTaskRequest(StreamInput in) throws IOException {
@ -198,6 +212,11 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
parentTaskId = TaskId.readFromStream(in);
ban = in.readBoolean();
reason = ban ? in.readString() : null;
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
waitForCompletion = in.readBoolean();
} else {
waitForCompletion = false;
}
}
@Override
@ -208,6 +227,9 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
if (ban) {
out.writeString(reason);
}
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeBoolean(waitForCompletion);
}
}
}
@ -217,14 +239,21 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
if (request.ban) {
logger.debug("Received ban for the parent [{}] on the node [{}], reason: [{}]", request.parentTaskId,
clusterService.localNode().getId(), request.reason);
taskManager.setBan(request.parentTaskId, request.reason);
final List<CancellableTask> childTasks = taskManager.setBan(request.parentTaskId, request.reason);
final GroupedActionListener<Void> listener = new GroupedActionListener<>(ActionListener.map(
new ChannelActionListener<>(channel, BAN_PARENT_ACTION_NAME, request), r -> TransportResponse.Empty.INSTANCE),
childTasks.size() + 1);
for (CancellableTask childTask : childTasks) {
cancelTaskAndDescendants(childTask, request.reason, request.waitForCompletion, listener);
}
listener.onResponse(null);
} else {
logger.debug("Removing ban for the parent [{}] on the node [{}]", request.parentTaskId,
clusterService.localNode().getId());
taskManager.removeBan(request.parentTaskId);
}
channel.sendResponse(TransportResponse.Empty.INSTANCE);
}
}
}
}

View File

@ -333,8 +333,9 @@ public class TaskManager implements ClusterStateApplier {
* Bans all tasks with the specified parent task from execution, cancels all tasks that are currently executing.
* <p>
* This method is called when a parent task that has children is cancelled.
* @return a list of pending cancellable child tasks
*/
public void setBan(TaskId parentTaskId, String reason) {
public List<CancellableTask> setBan(TaskId parentTaskId, String reason) {
logger.trace("setting ban for the parent task {} {}", parentTaskId, reason);
// Set the ban first, so the newly created tasks cannot be registered
@ -344,14 +345,10 @@ public class TaskManager implements ClusterStateApplier {
banedParents.put(parentTaskId, reason);
}
}
// Now go through already running tasks and cancel them
for (Map.Entry<Long, CancellableTaskHolder> taskEntry : cancellableTasks.entrySet()) {
CancellableTaskHolder holder = taskEntry.getValue();
if (holder.hasParent(parentTaskId)) {
holder.cancel(reason);
}
}
return cancellableTasks.values().stream()
.filter(t -> t.hasParent(parentTaskId))
.map(t -> t.task)
.collect(Collectors.toList());
}
/**
@ -365,11 +362,8 @@ public class TaskManager implements ClusterStateApplier {
}
// for testing
public boolean childTasksCancelledOrBanned(TaskId parentTaskId) {
if (banedParents.containsKey(parentTaskId)) {
return true;
}
return cancellableTasks.values().stream().noneMatch(task -> task.hasParent(parentTaskId));
public Set<TaskId> getBannedTaskIds() {
return Collections.unmodifiableSet(banedParents.keySet());
}
/**

View File

@ -41,12 +41,14 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.threadpool.ThreadPool;
@ -57,97 +59,147 @@ import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.either;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
public class CancellableTasksIT extends ESIntegTestCase {
static final Map<ChildRequest, CountDownLatch> arrivedLatches = ConcurrentCollections.newConcurrentMap();
static final Map<ChildRequest, CountDownLatch> beforeExecuteLatches = ConcurrentCollections.newConcurrentMap();
static final Map<ChildRequest, CountDownLatch> completedLatches = ConcurrentCollections.newConcurrentMap();
static int idGenerator = 0;
static final Map<TestRequest, CountDownLatch> beforeSendLatches = ConcurrentCollections.newConcurrentMap();
static final Map<TestRequest, CountDownLatch> arrivedLatches = ConcurrentCollections.newConcurrentMap();
static final Map<TestRequest, CountDownLatch> beforeExecuteLatches = ConcurrentCollections.newConcurrentMap();
static final Map<TestRequest, CountDownLatch> completedLatches = ConcurrentCollections.newConcurrentMap();
@Before
public void resetTestStates() {
idGenerator = 0;
beforeSendLatches.clear();
arrivedLatches.clear();
beforeExecuteLatches.clear();
completedLatches.clear();
}
List<ChildRequest> setupChildRequests(Set<DiscoveryNode> nodes) {
int numRequests = randomIntBetween(1, 30);
List<ChildRequest> childRequests = new ArrayList<>();
for (int i = 0; i < numRequests; i++) {
ChildRequest req = new ChildRequest(i, randomFrom(nodes));
childRequests.add(req);
arrivedLatches.put(req, new CountDownLatch(1));
beforeExecuteLatches.put(req, new CountDownLatch(1));
completedLatches.put(req, new CountDownLatch(1));
static TestRequest generateTestRequest(Set<DiscoveryNode> nodes, int level, int maxLevel) {
List<TestRequest> subRequests = new ArrayList<>();
int lower = level == 0 ? 1 : 0;
int upper = 10 / (level + 1);
int numOfSubRequests = randomIntBetween(lower, upper);
for (int i = 0; i < numOfSubRequests && level <= maxLevel; i++) {
subRequests.add(generateTestRequest(nodes, level + 1, maxLevel));
}
return childRequests;
final TestRequest request = new TestRequest(idGenerator++, randomFrom(nodes), subRequests);
beforeSendLatches.put(request, new CountDownLatch(1));
arrivedLatches.put(request, new CountDownLatch(1));
beforeExecuteLatches.put(request, new CountDownLatch(1));
completedLatches.put(request, new CountDownLatch(1));
return request;
}
public void testBanOnlyNodesWithOutstandingChildTasks() throws Exception {
/**
 * Randomly selects part of the request tree rooted at {@code request} and adds it to {@code result}.
 * A coin flip decides whether {@code request} is selected; only when it is selected are its sub-requests
 * considered (each with its own coin flip), so an unselected request excludes everything below it.
 *
 * @param request the root of the subtree to sample from
 * @param result  the set receiving the selected requests
 */
static void randomDescendants(TestRequest request, Set<TestRequest> result) {
    if (randomBoolean() == false) {
        // leaving out a request leaves out its whole subtree
        return;
    }
    result.add(request);
    for (TestRequest child : request.subRequests) {
        randomDescendants(child, result);
    }
}
/**
 * Lets a random, non-empty selection of the requests below {@code request} be sent, waits until they have
 * arrived, and then lets a random part of that selection run to completion.
 *
 * @param request the root request whose descendants are partially released
 * @return the requests that were sent and have arrived but were not allowed to complete
 * @throws Exception if waiting on one of the latches fails
 */
static Set<TestRequest> allowPartialRequest(TestRequest request) throws Exception {
final Set<TestRequest> sentRequests = new HashSet<>();
// Keep drawing until the random selection contains at least one descendant.
while (sentRequests.isEmpty()) {
for (TestRequest subRequest : request.subRequests) {
randomDescendants(subRequest, sentRequests);
}
}
// Release all selected requests first, then wait for each of them to arrive.
for (TestRequest req : sentRequests) {
beforeSendLatches.get(req).countDown();
}
for (TestRequest req : sentRequests) {
arrivedLatches.get(req).await();
}
Set<TestRequest> completedRequests = new HashSet<>();
for (TestRequest req : randomSubsetOf(sentRequests)) {
// A request is only picked for completion when every one of its descendants was sent as well; the
// descendants are then completed together with it (presumably because a request cannot finish
// while one of its sub-requests is still outstanding - verify against the transport action).
if (sentRequests.containsAll(req.descendants())) {
completedRequests.add(req);
completedRequests.addAll(req.descendants());
}
}
// Unblock execution of the chosen requests, then wait until each of them has completed.
for (TestRequest req : completedRequests) {
beforeExecuteLatches.get(req).countDown();
}
for (TestRequest req : completedRequests) {
completedLatches.get(req).await();
}
// Whatever was sent but not completed is still pending.
return Sets.difference(sentRequests, completedRequests);
}
/**
 * Counts down the before-send and before-execute latches of {@code request} and then, recursively,
 * of each of its sub-requests.
 *
 * @param request the root of the request tree to release
 */
static void allowEntireRequest(TestRequest request) {
    final CountDownLatch sendGate = beforeSendLatches.get(request);
    final CountDownLatch executeGate = beforeExecuteLatches.get(request);
    sendGate.countDown();
    executeGate.countDown();
    request.subRequests.forEach(child -> allowEntireRequest(child));
}
public void testBanOnlyNodesWithOutstandingDescendantTasks() throws Exception {
if (randomBoolean()) {
internalCluster().startNodes(randomIntBetween(1, 3));
}
Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
List<ChildRequest> childRequests = setupChildRequests(nodes);
ActionFuture<MainResponse> mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests));
List<ChildRequest> completedRequests = randomSubsetOf(between(0, childRequests.size() - 1), childRequests);
for (ChildRequest req : completedRequests) {
beforeExecuteLatches.get(req).countDown();
completedLatches.get(req).await();
final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4));
ActionFuture<TestResponse> rootTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest);
Set<TestRequest> pendingRequests = allowPartialRequest(rootRequest);
TaskId rootTaskId = getRootTaskId(rootRequest);
ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks()
.setTaskId(rootTaskId).waitForCompletion(true).execute();
if (randomBoolean()) {
List<TaskInfo> runningTasks = client().admin().cluster().prepareListTasks()
.setActions(TransportTestAction.ACTION.name()).setDetailed(true).get().getTasks();
for (TaskInfo subTask : randomSubsetOf(runningTasks)) {
client().admin().cluster().prepareCancelTasks().setTaskId(subTask.getTaskId()).waitForCompletion(false).get();
}
List<ChildRequest> outstandingRequests = childRequests.stream().
filter(r -> completedRequests.contains(r) == false)
.collect(Collectors.toList());
for (ChildRequest req : outstandingRequests) {
arrivedLatches.get(req).await();
}
TaskId taskId = getMainTaskId();
ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId)
.waitForCompletion(true).execute();
Set<DiscoveryNode> nodesWithOutstandingChildTask = outstandingRequests.stream().map(r -> r.targetNode).collect(Collectors.toSet());
assertBusy(() -> {
for (DiscoveryNode node : nodes) {
TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager();
if (nodesWithOutstandingChildTask.contains(node)) {
assertThat(taskManager.getBanCount(), equalTo(1));
} else {
assertThat(taskManager.getBanCount(), equalTo(0));
Set<TaskId> expectedBans = new HashSet<>();
for (TestRequest req : pendingRequests) {
if (req.node.equals(node)) {
List<Task> childTasks = taskManager.getTasks().values().stream()
.filter(t -> t.getParentTaskId() != null && t.getDescription().equals(req.taskDescription()))
.collect(Collectors.toList());
assertThat(childTasks, hasSize(1));
CancellableTask childTask = (CancellableTask) childTasks.get(0);
assertTrue(childTask.isCancelled());
expectedBans.add(childTask.getParentTaskId());
}
}
assertThat(taskManager.getBannedTaskIds(), equalTo(expectedBans));
}
});
// failed to spawn child tasks after cancelling
if (randomBoolean()) {
DiscoveryNode nodeWithParentTask = nodes.stream().filter(n -> n.getId().equals(taskId.getNodeId())).findFirst().get();
TransportMainAction mainAction = internalCluster().getInstance(TransportMainAction.class, nodeWithParentTask.getName());
PlainActionFuture<ChildResponse> future = new PlainActionFuture<>();
ChildRequest req = new ChildRequest(-1, randomFrom(nodes));
completedLatches.put(req, new CountDownLatch(1));
mainAction.startChildTask(taskId, req, future);
TransportException te = expectThrows(TransportException.class, future::actionGet);
assertThat(te.getCause(), instanceOf(TaskCancelledException.class));
assertThat(te.getCause().getMessage(), equalTo("The parent task was cancelled, shouldn't start any child tasks"));
}
for (ChildRequest req : outstandingRequests) {
beforeExecuteLatches.get(req).countDown();
}
allowEntireRequest(rootRequest);
cancelFuture.actionGet();
waitForMainTask(mainTaskFuture);
waitForRootTask(rootTaskFuture);
assertBusy(() -> {
for (DiscoveryNode node : nodes) {
TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager();
@ -158,27 +210,20 @@ public class CancellableTasksIT extends ESIntegTestCase {
public void testCancelTaskMultipleTimes() throws Exception {
Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
List<ChildRequest> childRequests = setupChildRequests(nodes);
ActionFuture<MainResponse> mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests));
for (ChildRequest r : randomSubsetOf(between(1, childRequests.size()), childRequests)) {
arrivedLatches.get(r).await();
}
TaskId taskId = getMainTaskId();
ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId)
.waitForCompletion(true).execute();
ensureChildTasksCancelledOrBanned(taskId);
if (randomBoolean()) {
TestRequest rootRequest = generateTestRequest(nodes, 0, randomIntBetween(1, 3));
ActionFuture<TestResponse> mainTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest);
TaskId taskId = getRootTaskId(rootRequest);
allowPartialRequest(rootRequest);
CancelTasksResponse resp = client().admin().cluster().prepareCancelTasks().setTaskId(taskId).waitForCompletion(false).get();
assertThat(resp.getTaskFailures(), empty());
assertThat(resp.getNodeFailures(), empty());
}
ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId)
.waitForCompletion(true).execute();
assertFalse(cancelFuture.isDone());
for (ChildRequest r : childRequests) {
beforeExecuteLatches.get(r).countDown();
}
allowEntireRequest(rootRequest);
assertThat(cancelFuture.actionGet().getTaskFailures(), empty());
assertThat(cancelFuture.actionGet().getTaskFailures(), empty());
waitForMainTask(mainTaskFuture);
waitForRootTask(mainTaskFuture);
CancelTasksResponse cancelError = client().admin().cluster().prepareCancelTasks()
.setTaskId(taskId).waitForCompletion(randomBoolean()).get();
assertThat(cancelError.getNodeFailures(), hasSize(1));
@ -188,12 +233,12 @@ public class CancellableTasksIT extends ESIntegTestCase {
public void testDoNotWaitForCompletion() throws Exception {
Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
List<ChildRequest> childRequests = setupChildRequests(nodes);
ActionFuture<MainResponse> mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests));
for (ChildRequest r : randomSubsetOf(between(1, childRequests.size()), childRequests)) {
arrivedLatches.get(r).await();
TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3));
ActionFuture<TestResponse> mainTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest);
TaskId taskId = getRootTaskId(rootRequest);
if (randomBoolean()) {
allowPartialRequest(rootRequest);
}
TaskId taskId = getMainTaskId();
boolean waitForCompletion = randomBoolean();
ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId)
.waitForCompletion(waitForCompletion).execute();
@ -202,40 +247,76 @@ public class CancellableTasksIT extends ESIntegTestCase {
} else {
assertBusy(() -> assertTrue(cancelFuture.isDone()));
}
for (ChildRequest r : childRequests) {
beforeExecuteLatches.get(r).countDown();
}
waitForMainTask(mainTaskFuture);
allowEntireRequest(rootRequest);
waitForRootTask(mainTaskFuture);
}
TaskId getMainTaskId() {
/**
 * Checks that a sub-task started under a root task that has already been cancelled fails with a
 * {@link TransportException} caused by a {@link TaskCancelledException}.
 */
public void testFailedToStartChildTaskAfterCancelled() {
Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3));
ActionFuture<TestResponse> rootTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest);
TaskId taskId = getRootTaskId(rootRequest);
// Cancel the root task without waiting for the cancellation to complete.
client().admin().cluster().prepareCancelTasks().setTaskId(taskId).waitForCompletion(false).get();
// Look up the action instance on the node that owns the (now cancelled) root task.
DiscoveryNode nodeWithParentTask = nodes.stream().filter(n -> n.getId().equals(taskId.getNodeId())).findFirst().get();
TransportTestAction mainAction = internalCluster().getInstance(TransportTestAction.class, nodeWithParentTask.getName());
PlainActionFuture<TestResponse> future = new PlainActionFuture<>();
TestRequest subRequest = generateTestRequest(nodes, 0, between(0, 1));
// Open the send latch up front so that starting the sub-task is not held back by it.
beforeSendLatches.get(subRequest).countDown();
mainAction.startSubTask(taskId, subRequest, future);
TransportException te = expectThrows(TransportException.class, future::actionGet);
assertThat(te.getCause(), instanceOf(TaskCancelledException.class));
assertThat(te.getCause().getMessage(), equalTo("The parent task was cancelled, shouldn't start any child tasks"));
// Release every latch of the original request tree and wait for the root task to finish.
allowEntireRequest(rootRequest);
waitForRootTask(rootTaskFuture);
}
static TaskId getRootTaskId(TestRequest request) {
ListTasksResponse listTasksResponse = client().admin().cluster().prepareListTasks()
.setActions(TransportMainAction.ACTION.name()).setDetailed(true).get();
assertThat(listTasksResponse.getTasks(), hasSize(1));
return listTasksResponse.getTasks().get(0).getTaskId();
.setActions(TransportTestAction.ACTION.name()).setDetailed(true).get();
List<TaskInfo> tasks = listTasksResponse.getTasks().stream()
.filter(t -> t.getDescription().equals(request.taskDescription()))
.collect(Collectors.toList());
assertThat(tasks, hasSize(1));
return tasks.get(0).getTaskId();
}
void waitForMainTask(ActionFuture<MainResponse> mainTask) {
static void waitForRootTask(ActionFuture<TestResponse> rootTask) {
try {
mainTask.actionGet();
rootTask.actionGet();
} catch (Exception e) {
final Throwable cause = ExceptionsHelper.unwrap(e, TaskCancelledException.class);
assertThat(cause.getMessage(),
either(equalTo("The parent task was cancelled, shouldn't start any child tasks"))
.or(containsString("Task cancelled before it started:")));
assertThat(cause.getMessage(), anyOf(
equalTo("The parent task was cancelled, shouldn't start any child tasks"),
containsString("Task cancelled before it started:"),
equalTo("Task was cancelled while executing")));
}
}
public static class MainRequest extends ActionRequest {
final List<ChildRequest> childRequests;
static class TestRequest extends ActionRequest {
final int id;
final DiscoveryNode node;
final List<TestRequest> subRequests;
public MainRequest(List<ChildRequest> childRequests) {
this.childRequests = childRequests;
TestRequest(int id, DiscoveryNode node, List<TestRequest> subRequests) {
this.id = id;
this.node = node;
this.subRequests = subRequests;
}
public MainRequest(StreamInput in) throws IOException {
TestRequest(StreamInput in) throws IOException {
super(in);
this.childRequests = in.readList(ChildRequest::new);
this.id = in.readInt();
this.node = new DiscoveryNode(in);
this.subRequests = in.readList(TestRequest::new);
}
/**
 * Returns every request below this one in depth-first order: each sub-request is immediately
 * followed by its own descendants. The request itself is not included.
 */
List<TestRequest> descendants() {
    final List<TestRequest> collected = new ArrayList<>();
    appendDescendants(collected);
    return collected;
}

/** Appends each sub-request, followed by that sub-request's descendants, to {@code sink}. */
private void appendDescendants(List<TestRequest> sink) {
    for (TestRequest child : subRequests) {
        sink.add(child);
        child.appendDescendants(sink);
    }
}
@Override
@ -243,104 +324,53 @@ public class CancellableTasksIT extends ESIntegTestCase {
return null;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeList(childRequests);
}
@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) {
@Override
public boolean shouldCancelChildrenOnCancellation() {
return true;
}
};
}
}
public static class MainResponse extends ActionResponse {
public MainResponse() {
}
public MainResponse(StreamInput in) throws IOException {
super(in);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
}
}
public static class ChildRequest extends ActionRequest {
final int id;
final DiscoveryNode targetNode;
public ChildRequest(int id, DiscoveryNode targetNode) {
this.id = id;
this.targetNode = targetNode;
}
public ChildRequest(StreamInput in) throws IOException {
super(in);
this.id = in.readInt();
this.targetNode = new DiscoveryNode(in);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeInt(id);
targetNode.writeTo(out);
}
@Override
public ActionRequestValidationException validate() {
return null;
node.writeTo(out);
out.writeList(subRequests);
}
@Override
public String getDescription() {
return "childTask[" + id + "]";
return taskDescription();
}
/**
 * Returns the description string for this request, built from its {@code id} as {@code id=<id>}.
 */
String taskDescription() {
    final String idText = Integer.toString(id);
    return "id=".concat(idText);
}
@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
if (randomBoolean()) {
boolean shouldCancelChildrenOnCancellation = randomBoolean();
return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) {
return new CancellableTask(id, type, action, taskDescription(), parentTaskId, headers) {
@Override
public boolean shouldCancelChildrenOnCancellation() {
return shouldCancelChildrenOnCancellation;
return true;
}
};
} else {
return super.createTask(id, type, action, parentTaskId, headers);
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ChildRequest that = (ChildRequest) o;
return id == that.id && targetNode.equals(that.targetNode);
TestRequest that = (TestRequest) o;
return id == that.id;
}
@Override
public int hashCode() {
return Objects.hash(id, targetNode);
return Objects.hash(id);
}
}
public static class ChildResponse extends ActionResponse {
public ChildResponse() {
public static class TestResponse extends ActionResponse {
public TestResponse() {
}
public ChildResponse(StreamInput in) throws IOException {
public TestResponse(StreamInput in) throws IOException {
super(in);
}
@ -350,33 +380,44 @@ public class CancellableTasksIT extends ESIntegTestCase {
}
}
/**
 * Transport action driving this test. An invocation runs one unit of work of its own on the generic
 * thread pool and starts one sub-task per entry of the request's sub-request list; the caller's
 * listener is completed through a {@link GroupedActionListener}.
 * NOTE(review): lines of two revisions are interleaved in this span (TransportMainAction/MainRequest/
 * ChildRequest next to TransportTestAction/TestRequest), and the hunk markers below hide some interior
 * lines - read the comments as describing the visible statements only.
 */
public static class TransportMainAction extends HandledTransportAction<MainRequest, MainResponse> {
public static class TransportTestAction extends HandledTransportAction<TestRequest, TestResponse> {
public static ActionType<MainResponse> ACTION = new ActionType<>("internal::main_action", MainResponse::new);
// Incremented once per invocation whose own work gets past the cancellation check in doExecute.
static AtomicInteger counter = new AtomicInteger();
public static ActionType<TestResponse> ACTION = new ActionType<>("internal::test_action", TestResponse::new);
private final TransportService transportService;
private final NodeClient client;
@Inject
public TransportMainAction(TransportService transportService, NodeClient client, ActionFilters actionFilters) {
super(ACTION.name(), transportService, actionFilters, MainRequest::new, ThreadPool.Names.GENERIC);
public TransportTestAction(TransportService transportService, NodeClient client, ActionFilters actionFilters) {
super(ACTION.name(), transportService, actionFilters, TestRequest::new, ThreadPool.Names.GENERIC);
this.transportService = transportService;
this.client = client;
}
@Override
protected void doExecute(Task task, MainRequest request, ActionListener<MainResponse> listener) {
GroupedActionListener<ChildResponse> groupedListener =
new GroupedActionListener<>(ActionListener.map(listener, r -> new MainResponse()), request.childRequests.size());
for (ChildRequest childRequest : request.childRequests) {
protected void doExecute(Task task, TestRequest request, ActionListener<TestResponse> listener) {
// Count down the latch registered for this request; presumably this lets the test observe that the
// request has reached the action - confirm against the test body.
arrivedLatches.get(request).countDown();
List<TestRequest> subRequests = request.subRequests;
// Group size is subRequests.size() + 1: one result for the work submitted just below plus one per sub-request.
GroupedActionListener<TestResponse> groupedListener =
new GroupedActionListener<>(ActionListener.map(listener, r -> new TestResponse()), subRequests.size() + 1);
transportService.getThreadPool().generic().execute(ActionRunnable.supply(groupedListener, () -> {
// Block until the latch registered for this request in beforeExecuteLatches is released.
beforeExecuteLatches.get(request).await();
// The task is cast to CancellableTask unconditionally; a cancelled task fails here instead of
// incrementing the counter and producing a response.
if (((CancellableTask) task).isCancelled()) {
throw new TaskCancelledException("Task was cancelled while executing");
}
counter.incrementAndGet();
return new TestResponse();
}));
for (TestRequest subRequest : subRequests) {
// The parent id handed to each sub-task is this task's id on the local node.
TaskId parentTaskId = new TaskId(client.getLocalNodeId(), task.getId());
startChildTask(parentTaskId, childRequest, groupedListener);
startSubTask(parentTaskId, subRequest, groupedListener);
}
}
// Marks the request as a child of parentTaskId and dispatches it from a generic-pool thread. The
// listener is wrapped in a LatchedActionListener built with the request's entry in completedLatches.
protected void startChildTask(TaskId parentTaskId, ChildRequest childRequest, ActionListener<ChildResponse> listener) {
childRequest.setParentTask(parentTaskId);
final CountDownLatch completeLatch = completedLatches.get(childRequest);
LatchedActionListener<ChildResponse> latchedListener = new LatchedActionListener<>(listener, completeLatch);
protected void startSubTask(TaskId parentTaskId, TestRequest subRequest, ActionListener<TestResponse> listener) {
subRequest.setParentTask(parentTaskId);
CountDownLatch completeLatch = completedLatches.get(subRequest);
LatchedActionListener<TestResponse> latchedListener = new LatchedActionListener<>(listener, completeLatch);
transportService.getThreadPool().generic().execute(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
@ -384,20 +425,20 @@ public class CancellableTasksIT extends ESIntegTestCase {
}
@Override
protected void doRun() {
if (client.getLocalNodeId().equals(childRequest.targetNode.getId()) && randomBoolean()) {
protected void doRun() throws Exception {
// Hold the sub-request back until its entry in beforeSendLatches is released.
beforeSendLatches.get(subRequest).await();
// When the target is the local node, randomly pick between direct local execution and the
// transport layer; requests for any other node always go through the transport layer.
if (client.getLocalNodeId().equals(subRequest.node.getId()) && randomBoolean()) {
try {
client.executeLocally(TransportChildAction.ACTION, childRequest, latchedListener);
client.executeLocally(TransportTestAction.ACTION, subRequest, latchedListener);
} catch (TaskCancelledException e) {
// A TaskCancelledException thrown synchronously by executeLocally is reported to the
// listener wrapped in a TransportException.
latchedListener.onFailure(new TransportException(e));
}
} else {
transportService.sendRequest(childRequest.targetNode, TransportChildAction.ACTION.name(), childRequest,
new TransportResponseHandler<ChildResponse>() {
transportService.sendRequest(subRequest.node, ACTION.name(), subRequest,
new TransportResponseHandler<TestResponse>() {
@Override
public void handleResponse(ChildResponse response) {
latchedListener.onResponse(new ChildResponse());
public void handleResponse(TestResponse response) {
latchedListener.onResponse(response);
}
@Override
@ -411,8 +452,8 @@ public class CancellableTasksIT extends ESIntegTestCase {
}
@Override
public ChildResponse read(StreamInput in) throws IOException {
return new ChildResponse(in);
public TestResponse read(StreamInput in) throws IOException {
return new TestResponse(in);
}
});
}
@ -421,40 +462,16 @@ public class CancellableTasksIT extends ESIntegTestCase {
}
}
/**
 * Transport action that handles a single {@link ChildRequest}.
 * <p>
 * On execution it checks that the request was delivered to the node named in the request, counts down
 * the request's entry in {@code arrivedLatches}, and then, on the generic thread pool, waits for the
 * request's entry in {@code beforeExecuteLatches} before answering with an empty {@link ChildResponse}.
 */
public static class TransportChildAction extends HandledTransportAction<ChildRequest, ChildResponse> {
    public static ActionType<ChildResponse> ACTION = new ActionType<>("internal:child_action", ChildResponse::new);

    private final TransportService transportService;

    @Inject
    public TransportChildAction(TransportService transportService, ActionFilters actionFilters) {
        super(ACTION.name(), transportService, actionFilters, ChildRequest::new, ThreadPool.Names.GENERIC);
        this.transportService = transportService;
    }

    @Override
    protected void doExecute(Task task, ChildRequest request, ActionListener<ChildResponse> listener) {
        // The request must have been routed to the node it names as its target.
        assertThat(request.targetNode, equalTo(this.transportService.getLocalNode()));
        arrivedLatches.get(request).countDown();

        // Any exception raised while waiting is delivered to the listener by the ActionRunnable.
        final ActionRunnable<ChildResponse> respondWhenReleased =
            ActionRunnable.supply(listener, () -> awaitReleaseAndRespond(request));
        this.transportService.getThreadPool().executor(ThreadPool.Names.GENERIC).execute(respondWhenReleased);
    }

    /**
     * Blocks until the latch registered for {@code request} in {@code beforeExecuteLatches} is released,
     * then produces the response. The latch is looked up only when this runs, not when it is scheduled.
     */
    private static ChildResponse awaitReleaseAndRespond(ChildRequest request) throws Exception {
        beforeExecuteLatches.get(request).await();
        return new ChildResponse();
    }
}
/**
 * Plugin that exposes the test transport actions through {@link ActionPlugin}: {@code getActions}
 * pairs each action type with its transport action class, and {@code getClientActions} lists the
 * action types.
 * NOTE(review): each method body below holds two {@code return} statements - an {@code Arrays.asList}
 * form naming TransportMainAction and TransportChildAction, followed by a
 * {@code Collections.singletonList} form naming TransportTestAction. As written the second statement
 * is unreachable; this looks like two revisions of the same method side by side - confirm which one
 * is intended before relying on this class.
 */
public static class TaskPlugin extends Plugin implements ActionPlugin {
@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return Arrays.asList(
new ActionHandler<>(TransportMainAction.ACTION, TransportMainAction.class),
new ActionHandler<>(TransportChildAction.ACTION, TransportChildAction.class)
);
return Collections.singletonList(new ActionHandler<>(TransportTestAction.ACTION, TransportTestAction.class));
}
@Override
public List<ActionType<? extends ActionResponse>> getClientActions() {
return Arrays.asList(TransportMainAction.ACTION, TransportChildAction.ACTION);
return Collections.singletonList(TransportTestAction.ACTION);
}
}
@ -471,16 +488,4 @@ public class CancellableTasksIT extends ESIntegTestCase {
plugins.add(TaskPlugin.class);
return plugins;
}
/**
 * Waits, by retrying inside {@code assertBusy}, until the task manager of every node in the internal
 * cluster reports that the child tasks of the given parent task are cancelled or banned.
 *
 * @param taskId id of the parent task whose children are checked
 * @throws Exception propagated from {@code assertBusy}
 */
protected static void ensureChildTasksCancelledOrBanned(TaskId taskId) throws Exception {
    assertBusy(() -> {
        for (String node : internalCluster().getNodeNames()) {
            TransportService nodeTransport = internalCluster().getInstance(TransportService.class, node);
            boolean settled = nodeTransport.getTaskManager().childTasksCancelledOrBanned(taskId);
            assertTrue(settled);
        }
    });
}
}