Support hierarchical task cancellation (#54757)
With this change, when a task is canceled, the task manager will cancel not only its direct child tasks but all also its descendant tasks. Closes #50990
This commit is contained in:
parent
51c6f69e02
commit
96bb1164f0
|
@ -234,7 +234,7 @@ nodes `nodeId1` and `nodeId2`.
|
|||
|
||||
`wait_for_completion`::
|
||||
(Optional, boolean) If `true`, the request blocks until the cancellation of the
|
||||
task and its child tasks is completed. Otherwise, the request can return soon
|
||||
task and its descendant tasks is completed. Otherwise, the request can return soon
|
||||
after the cancellation is started. Defaults to `false`.
|
||||
|
||||
[source,console]
|
||||
|
|
|
@ -42,7 +42,7 @@
|
|||
},
|
||||
"wait_for_completion": {
|
||||
"type":"boolean",
|
||||
"description":"Should the request block until the cancellation of the task and its child tasks is completed. Defaults to false"
|
||||
"description":"Should the request block until the cancellation of the task and its descendant tasks is completed. Defaults to false"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -79,7 +79,7 @@ public class CancelTasksRequest extends BaseTasksRequest<CancelTasksRequest> {
|
|||
}
|
||||
|
||||
/**
|
||||
* If {@code true}, the request blocks until the cancellation of the task and its child tasks is completed.
|
||||
* If {@code true}, the request blocks until the cancellation of the task and its descendant tasks is completed.
|
||||
* Otherwise, the request can return soon after the cancellation is started. Defaults to {@code false}.
|
||||
*/
|
||||
public void setWaitForCompletion(boolean waitForCompletion) {
|
||||
|
|
|
@ -20,11 +20,13 @@
|
|||
package org.elasticsearch.action.admin.cluster.node.tasks.cancel;
|
||||
|
||||
import org.elasticsearch.ResourceNotFoundException;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.FailedNodeException;
|
||||
import org.elasticsearch.action.StepListener;
|
||||
import org.elasticsearch.action.TaskOperationFailure;
|
||||
import org.elasticsearch.action.support.ActionFilters;
|
||||
import org.elasticsearch.action.support.ChannelActionListener;
|
||||
import org.elasticsearch.action.support.GroupedActionListener;
|
||||
import org.elasticsearch.action.support.tasks.TransportTasksAction;
|
||||
import org.elasticsearch.cluster.node.DiscoveryNode;
|
||||
|
@ -104,34 +106,43 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
|
|||
@Override
|
||||
protected void taskOperation(CancelTasksRequest request, CancellableTask cancellableTask, ActionListener<TaskInfo> listener) {
|
||||
String nodeId = clusterService.localNode().getId();
|
||||
if (cancellableTask.shouldCancelChildrenOnCancellation()) {
|
||||
cancelTaskAndDescendants(cancellableTask, request.getReason(), request.waitForCompletion(),
|
||||
ActionListener.map(listener, r -> cancellableTask.taskInfo(nodeId, false)));
|
||||
}
|
||||
|
||||
void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
|
||||
if (task.shouldCancelChildrenOnCancellation()) {
|
||||
StepListener<Void> completedListener = new StepListener<>();
|
||||
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(ActionListener.map(completedListener, r -> null), 3);
|
||||
Collection<DiscoveryNode> childrenNodes =
|
||||
taskManager.startBanOnChildrenNodes(cancellableTask.getId(), () -> groupedListener.onResponse(null));
|
||||
taskManager.cancel(cancellableTask, request.getReason(), () -> groupedListener.onResponse(null));
|
||||
taskManager.startBanOnChildrenNodes(task.getId(), () -> groupedListener.onResponse(null));
|
||||
taskManager.cancel(task, reason, () -> groupedListener.onResponse(null));
|
||||
|
||||
StepListener<Void> banOnNodesListener = new StepListener<>();
|
||||
setBanOnNodes(request.getReason(), cancellableTask, childrenNodes, banOnNodesListener);
|
||||
setBanOnNodes(reason, waitForCompletion, task, childrenNodes, banOnNodesListener);
|
||||
banOnNodesListener.whenComplete(groupedListener::onResponse, groupedListener::onFailure);
|
||||
// We remove bans after all child tasks are completed although in theory we can do it on a per-node basis.
|
||||
completedListener.whenComplete(
|
||||
r -> removeBanOnNodes(cancellableTask, childrenNodes),
|
||||
e -> removeBanOnNodes(cancellableTask, childrenNodes));
|
||||
// if wait_for_child_tasks is true, then only return when (1) bans are placed on child nodes, (2) child tasks are
|
||||
completedListener.whenComplete(r -> removeBanOnNodes(task, childrenNodes), e -> removeBanOnNodes(task, childrenNodes));
|
||||
// if wait_for_completion is true, then only return when (1) bans are placed on child nodes, (2) child tasks are
|
||||
// completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child nodes.
|
||||
if (request.waitForCompletion()) {
|
||||
completedListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure);
|
||||
if (waitForCompletion) {
|
||||
completedListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
|
||||
} else {
|
||||
banOnNodesListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure);
|
||||
banOnNodesListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
|
||||
}
|
||||
} else {
|
||||
logger.trace("task {} doesn't have any children that should be cancelled", cancellableTask.getId());
|
||||
taskManager.cancel(cancellableTask, request.getReason(), () -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)));
|
||||
logger.trace("task {} doesn't have any children that should be cancelled", task.getId());
|
||||
if (waitForCompletion) {
|
||||
taskManager.cancel(task, reason, () -> listener.onResponse(null));
|
||||
} else {
|
||||
taskManager.cancel(task, reason, () -> {});
|
||||
listener.onResponse(null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void setBanOnNodes(String reason, CancellableTask task, Collection<DiscoveryNode> childNodes, ActionListener<Void> listener) {
|
||||
private void setBanOnNodes(String reason, boolean waitForCompletion, CancellableTask task,
|
||||
Collection<DiscoveryNode> childNodes, ActionListener<Void> listener) {
|
||||
if (childNodes.isEmpty()) {
|
||||
listener.onResponse(null);
|
||||
return;
|
||||
|
@ -140,7 +151,7 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
|
|||
GroupedActionListener<Void> groupedListener =
|
||||
new GroupedActionListener<>(ActionListener.map(listener, r -> null), childNodes.size());
|
||||
final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(
|
||||
new TaskId(clusterService.localNode().getId(), task.getId()), reason);
|
||||
new TaskId(clusterService.localNode().getId(), task.getId()), reason, waitForCompletion);
|
||||
for (DiscoveryNode node : childNodes) {
|
||||
transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, banRequest,
|
||||
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
|
||||
|
@ -171,26 +182,29 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
|
|||
|
||||
private final TaskId parentTaskId;
|
||||
private final boolean ban;
|
||||
private final boolean waitForCompletion;
|
||||
private final String reason;
|
||||
|
||||
static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason) {
|
||||
return new BanParentTaskRequest(parentTaskId, reason);
|
||||
static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
|
||||
return new BanParentTaskRequest(parentTaskId, reason, waitForCompletion);
|
||||
}
|
||||
|
||||
static BanParentTaskRequest createRemoveBanParentTaskRequest(TaskId parentTaskId) {
|
||||
return new BanParentTaskRequest(parentTaskId);
|
||||
}
|
||||
|
||||
private BanParentTaskRequest(TaskId parentTaskId, String reason) {
|
||||
private BanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
|
||||
this.parentTaskId = parentTaskId;
|
||||
this.ban = true;
|
||||
this.reason = reason;
|
||||
this.waitForCompletion = waitForCompletion;
|
||||
}
|
||||
|
||||
private BanParentTaskRequest(TaskId parentTaskId) {
|
||||
this.parentTaskId = parentTaskId;
|
||||
this.ban = false;
|
||||
this.reason = null;
|
||||
this.waitForCompletion = false;
|
||||
}
|
||||
|
||||
private BanParentTaskRequest(StreamInput in) throws IOException {
|
||||
|
@ -198,6 +212,11 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
|
|||
parentTaskId = TaskId.readFromStream(in);
|
||||
ban = in.readBoolean();
|
||||
reason = ban ? in.readString() : null;
|
||||
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
|
||||
waitForCompletion = in.readBoolean();
|
||||
} else {
|
||||
waitForCompletion = false;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -208,6 +227,9 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
|
|||
if (ban) {
|
||||
out.writeString(reason);
|
||||
}
|
||||
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
|
||||
out.writeBoolean(waitForCompletion);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -217,13 +239,20 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
|
|||
if (request.ban) {
|
||||
logger.debug("Received ban for the parent [{}] on the node [{}], reason: [{}]", request.parentTaskId,
|
||||
clusterService.localNode().getId(), request.reason);
|
||||
taskManager.setBan(request.parentTaskId, request.reason);
|
||||
final List<CancellableTask> childTasks = taskManager.setBan(request.parentTaskId, request.reason);
|
||||
final GroupedActionListener<Void> listener = new GroupedActionListener<>(ActionListener.map(
|
||||
new ChannelActionListener<>(channel, BAN_PARENT_ACTION_NAME, request), r -> TransportResponse.Empty.INSTANCE),
|
||||
childTasks.size() + 1);
|
||||
for (CancellableTask childTask : childTasks) {
|
||||
cancelTaskAndDescendants(childTask, request.reason, request.waitForCompletion, listener);
|
||||
}
|
||||
listener.onResponse(null);
|
||||
} else {
|
||||
logger.debug("Removing ban for the parent [{}] on the node [{}]", request.parentTaskId,
|
||||
clusterService.localNode().getId());
|
||||
taskManager.removeBan(request.parentTaskId);
|
||||
channel.sendResponse(TransportResponse.Empty.INSTANCE);
|
||||
}
|
||||
channel.sendResponse(TransportResponse.Empty.INSTANCE);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -333,8 +333,9 @@ public class TaskManager implements ClusterStateApplier {
|
|||
* Bans all tasks with the specified parent task from execution, cancels all tasks that are currently executing.
|
||||
* <p>
|
||||
* This method is called when a parent task that has children is cancelled.
|
||||
* @return a list of pending cancellable child tasks
|
||||
*/
|
||||
public void setBan(TaskId parentTaskId, String reason) {
|
||||
public List<CancellableTask> setBan(TaskId parentTaskId, String reason) {
|
||||
logger.trace("setting ban for the parent task {} {}", parentTaskId, reason);
|
||||
|
||||
// Set the ban first, so the newly created tasks cannot be registered
|
||||
|
@ -344,14 +345,10 @@ public class TaskManager implements ClusterStateApplier {
|
|||
banedParents.put(parentTaskId, reason);
|
||||
}
|
||||
}
|
||||
|
||||
// Now go through already running tasks and cancel them
|
||||
for (Map.Entry<Long, CancellableTaskHolder> taskEntry : cancellableTasks.entrySet()) {
|
||||
CancellableTaskHolder holder = taskEntry.getValue();
|
||||
if (holder.hasParent(parentTaskId)) {
|
||||
holder.cancel(reason);
|
||||
}
|
||||
}
|
||||
return cancellableTasks.values().stream()
|
||||
.filter(t -> t.hasParent(parentTaskId))
|
||||
.map(t -> t.task)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -365,11 +362,8 @@ public class TaskManager implements ClusterStateApplier {
|
|||
}
|
||||
|
||||
// for testing
|
||||
public boolean childTasksCancelledOrBanned(TaskId parentTaskId) {
|
||||
if (banedParents.containsKey(parentTaskId)) {
|
||||
return true;
|
||||
}
|
||||
return cancellableTasks.values().stream().noneMatch(task -> task.hasParent(parentTaskId));
|
||||
public Set<TaskId> getBannedTaskIds() {
|
||||
return Collections.unmodifiableSet(banedParents.keySet());
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -41,12 +41,14 @@ import org.elasticsearch.common.io.stream.StreamInput;
|
|||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
|
||||
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.plugins.ActionPlugin;
|
||||
import org.elasticsearch.plugins.Plugin;
|
||||
import org.elasticsearch.tasks.CancellableTask;
|
||||
import org.elasticsearch.tasks.Task;
|
||||
import org.elasticsearch.tasks.TaskCancelledException;
|
||||
import org.elasticsearch.tasks.TaskId;
|
||||
import org.elasticsearch.tasks.TaskInfo;
|
||||
import org.elasticsearch.tasks.TaskManager;
|
||||
import org.elasticsearch.test.ESIntegTestCase;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
|
@ -57,97 +59,147 @@ import org.junit.Before;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.StreamSupport;
|
||||
|
||||
import static org.hamcrest.Matchers.anyOf;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.either;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
|
||||
public class CancellableTasksIT extends ESIntegTestCase {
|
||||
static final Map<ChildRequest, CountDownLatch> arrivedLatches = ConcurrentCollections.newConcurrentMap();
|
||||
static final Map<ChildRequest, CountDownLatch> beforeExecuteLatches = ConcurrentCollections.newConcurrentMap();
|
||||
static final Map<ChildRequest, CountDownLatch> completedLatches = ConcurrentCollections.newConcurrentMap();
|
||||
|
||||
static int idGenerator = 0;
|
||||
static final Map<TestRequest, CountDownLatch> beforeSendLatches = ConcurrentCollections.newConcurrentMap();
|
||||
static final Map<TestRequest, CountDownLatch> arrivedLatches = ConcurrentCollections.newConcurrentMap();
|
||||
static final Map<TestRequest, CountDownLatch> beforeExecuteLatches = ConcurrentCollections.newConcurrentMap();
|
||||
static final Map<TestRequest, CountDownLatch> completedLatches = ConcurrentCollections.newConcurrentMap();
|
||||
|
||||
@Before
|
||||
public void resetTestStates() {
|
||||
idGenerator = 0;
|
||||
beforeSendLatches.clear();
|
||||
arrivedLatches.clear();
|
||||
beforeExecuteLatches.clear();
|
||||
completedLatches.clear();
|
||||
}
|
||||
|
||||
List<ChildRequest> setupChildRequests(Set<DiscoveryNode> nodes) {
|
||||
int numRequests = randomIntBetween(1, 30);
|
||||
List<ChildRequest> childRequests = new ArrayList<>();
|
||||
for (int i = 0; i < numRequests; i++) {
|
||||
ChildRequest req = new ChildRequest(i, randomFrom(nodes));
|
||||
childRequests.add(req);
|
||||
arrivedLatches.put(req, new CountDownLatch(1));
|
||||
beforeExecuteLatches.put(req, new CountDownLatch(1));
|
||||
completedLatches.put(req, new CountDownLatch(1));
|
||||
static TestRequest generateTestRequest(Set<DiscoveryNode> nodes, int level, int maxLevel) {
|
||||
List<TestRequest> subRequests = new ArrayList<>();
|
||||
int lower = level == 0 ? 1 : 0;
|
||||
int upper = 10 / (level + 1);
|
||||
int numOfSubRequests = randomIntBetween(lower, upper);
|
||||
for (int i = 0; i < numOfSubRequests && level <= maxLevel; i++) {
|
||||
subRequests.add(generateTestRequest(nodes, level + 1, maxLevel));
|
||||
}
|
||||
return childRequests;
|
||||
final TestRequest request = new TestRequest(idGenerator++, randomFrom(nodes), subRequests);
|
||||
beforeSendLatches.put(request, new CountDownLatch(1));
|
||||
arrivedLatches.put(request, new CountDownLatch(1));
|
||||
beforeExecuteLatches.put(request, new CountDownLatch(1));
|
||||
completedLatches.put(request, new CountDownLatch(1));
|
||||
return request;
|
||||
}
|
||||
|
||||
public void testBanOnlyNodesWithOutstandingChildTasks() throws Exception {
|
||||
static void randomDescendants(TestRequest request, Set<TestRequest> result) {
|
||||
if (randomBoolean()) {
|
||||
result.add(request);
|
||||
for (TestRequest subReq : request.subRequests) {
|
||||
randomDescendants(subReq, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Allow some parts of the request to be completed
|
||||
* @return a pending child requests
|
||||
*/
|
||||
static Set<TestRequest> allowPartialRequest(TestRequest request) throws Exception {
|
||||
final Set<TestRequest> sentRequests = new HashSet<>();
|
||||
while (sentRequests.isEmpty()) {
|
||||
for (TestRequest subRequest : request.subRequests) {
|
||||
randomDescendants(subRequest, sentRequests);
|
||||
}
|
||||
}
|
||||
for (TestRequest req : sentRequests) {
|
||||
beforeSendLatches.get(req).countDown();
|
||||
}
|
||||
for (TestRequest req : sentRequests) {
|
||||
arrivedLatches.get(req).await();
|
||||
}
|
||||
Set<TestRequest> completedRequests = new HashSet<>();
|
||||
for (TestRequest req : randomSubsetOf(sentRequests)) {
|
||||
if (sentRequests.containsAll(req.descendants())) {
|
||||
completedRequests.add(req);
|
||||
completedRequests.addAll(req.descendants());
|
||||
}
|
||||
}
|
||||
for (TestRequest req : completedRequests) {
|
||||
beforeExecuteLatches.get(req).countDown();
|
||||
}
|
||||
for (TestRequest req : completedRequests) {
|
||||
completedLatches.get(req).await();
|
||||
}
|
||||
return Sets.difference(sentRequests, completedRequests);
|
||||
}
|
||||
|
||||
static void allowEntireRequest(TestRequest request) {
|
||||
beforeSendLatches.get(request).countDown();
|
||||
beforeExecuteLatches.get(request).countDown();
|
||||
for (TestRequest subReq : request.subRequests) {
|
||||
allowEntireRequest(subReq);
|
||||
}
|
||||
}
|
||||
|
||||
public void testBanOnlyNodesWithOutstandingDescendantTasks() throws Exception {
|
||||
if (randomBoolean()) {
|
||||
internalCluster().startNodes(randomIntBetween(1, 3));
|
||||
}
|
||||
Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
|
||||
List<ChildRequest> childRequests = setupChildRequests(nodes);
|
||||
ActionFuture<MainResponse> mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests));
|
||||
List<ChildRequest> completedRequests = randomSubsetOf(between(0, childRequests.size() - 1), childRequests);
|
||||
for (ChildRequest req : completedRequests) {
|
||||
beforeExecuteLatches.get(req).countDown();
|
||||
completedLatches.get(req).await();
|
||||
final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4));
|
||||
ActionFuture<TestResponse> rootTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest);
|
||||
Set<TestRequest> pendingRequests = allowPartialRequest(rootRequest);
|
||||
TaskId rootTaskId = getRootTaskId(rootRequest);
|
||||
ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks()
|
||||
.setTaskId(rootTaskId).waitForCompletion(true).execute();
|
||||
if (randomBoolean()) {
|
||||
List<TaskInfo> runningTasks = client().admin().cluster().prepareListTasks()
|
||||
.setActions(TransportTestAction.ACTION.name()).setDetailed(true).get().getTasks();
|
||||
for (TaskInfo subTask : randomSubsetOf(runningTasks)) {
|
||||
client().admin().cluster().prepareCancelTasks().setTaskId(subTask.getTaskId()).waitForCompletion(false).get();
|
||||
}
|
||||
}
|
||||
List<ChildRequest> outstandingRequests = childRequests.stream().
|
||||
filter(r -> completedRequests.contains(r) == false)
|
||||
.collect(Collectors.toList());
|
||||
for (ChildRequest req : outstandingRequests) {
|
||||
arrivedLatches.get(req).await();
|
||||
}
|
||||
TaskId taskId = getMainTaskId();
|
||||
ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId)
|
||||
.waitForCompletion(true).execute();
|
||||
Set<DiscoveryNode> nodesWithOutstandingChildTask = outstandingRequests.stream().map(r -> r.targetNode).collect(Collectors.toSet());
|
||||
assertBusy(() -> {
|
||||
for (DiscoveryNode node : nodes) {
|
||||
TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager();
|
||||
if (nodesWithOutstandingChildTask.contains(node)) {
|
||||
assertThat(taskManager.getBanCount(), equalTo(1));
|
||||
} else {
|
||||
assertThat(taskManager.getBanCount(), equalTo(0));
|
||||
Set<TaskId> expectedBans = new HashSet<>();
|
||||
for (TestRequest req : pendingRequests) {
|
||||
if (req.node.equals(node)) {
|
||||
List<Task> childTasks = taskManager.getTasks().values().stream()
|
||||
.filter(t -> t.getParentTaskId() != null && t.getDescription().equals(req.taskDescription()))
|
||||
.collect(Collectors.toList());
|
||||
assertThat(childTasks, hasSize(1));
|
||||
CancellableTask childTask = (CancellableTask) childTasks.get(0);
|
||||
assertTrue(childTask.isCancelled());
|
||||
expectedBans.add(childTask.getParentTaskId());
|
||||
}
|
||||
}
|
||||
assertThat(taskManager.getBannedTaskIds(), equalTo(expectedBans));
|
||||
}
|
||||
});
|
||||
// failed to spawn child tasks after cancelling
|
||||
if (randomBoolean()) {
|
||||
DiscoveryNode nodeWithParentTask = nodes.stream().filter(n -> n.getId().equals(taskId.getNodeId())).findFirst().get();
|
||||
TransportMainAction mainAction = internalCluster().getInstance(TransportMainAction.class, nodeWithParentTask.getName());
|
||||
PlainActionFuture<ChildResponse> future = new PlainActionFuture<>();
|
||||
ChildRequest req = new ChildRequest(-1, randomFrom(nodes));
|
||||
completedLatches.put(req, new CountDownLatch(1));
|
||||
mainAction.startChildTask(taskId, req, future);
|
||||
TransportException te = expectThrows(TransportException.class, future::actionGet);
|
||||
assertThat(te.getCause(), instanceOf(TaskCancelledException.class));
|
||||
assertThat(te.getCause().getMessage(), equalTo("The parent task was cancelled, shouldn't start any child tasks"));
|
||||
}
|
||||
for (ChildRequest req : outstandingRequests) {
|
||||
beforeExecuteLatches.get(req).countDown();
|
||||
}
|
||||
allowEntireRequest(rootRequest);
|
||||
cancelFuture.actionGet();
|
||||
waitForMainTask(mainTaskFuture);
|
||||
waitForRootTask(rootTaskFuture);
|
||||
assertBusy(() -> {
|
||||
for (DiscoveryNode node : nodes) {
|
||||
TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager();
|
||||
|
@ -158,27 +210,20 @@ public class CancellableTasksIT extends ESIntegTestCase {
|
|||
|
||||
public void testCancelTaskMultipleTimes() throws Exception {
|
||||
Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
|
||||
List<ChildRequest> childRequests = setupChildRequests(nodes);
|
||||
ActionFuture<MainResponse> mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests));
|
||||
for (ChildRequest r : randomSubsetOf(between(1, childRequests.size()), childRequests)) {
|
||||
arrivedLatches.get(r).await();
|
||||
}
|
||||
TaskId taskId = getMainTaskId();
|
||||
TestRequest rootRequest = generateTestRequest(nodes, 0, randomIntBetween(1, 3));
|
||||
ActionFuture<TestResponse> mainTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest);
|
||||
TaskId taskId = getRootTaskId(rootRequest);
|
||||
allowPartialRequest(rootRequest);
|
||||
CancelTasksResponse resp = client().admin().cluster().prepareCancelTasks().setTaskId(taskId).waitForCompletion(false).get();
|
||||
assertThat(resp.getTaskFailures(), empty());
|
||||
assertThat(resp.getNodeFailures(), empty());
|
||||
ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId)
|
||||
.waitForCompletion(true).execute();
|
||||
ensureChildTasksCancelledOrBanned(taskId);
|
||||
if (randomBoolean()) {
|
||||
CancelTasksResponse resp = client().admin().cluster().prepareCancelTasks().setTaskId(taskId).waitForCompletion(false).get();
|
||||
assertThat(resp.getTaskFailures(), empty());
|
||||
assertThat(resp.getNodeFailures(), empty());
|
||||
}
|
||||
assertFalse(cancelFuture.isDone());
|
||||
for (ChildRequest r : childRequests) {
|
||||
beforeExecuteLatches.get(r).countDown();
|
||||
}
|
||||
allowEntireRequest(rootRequest);
|
||||
assertThat(cancelFuture.actionGet().getTaskFailures(), empty());
|
||||
assertThat(cancelFuture.actionGet().getTaskFailures(), empty());
|
||||
waitForMainTask(mainTaskFuture);
|
||||
waitForRootTask(mainTaskFuture);
|
||||
CancelTasksResponse cancelError = client().admin().cluster().prepareCancelTasks()
|
||||
.setTaskId(taskId).waitForCompletion(randomBoolean()).get();
|
||||
assertThat(cancelError.getNodeFailures(), hasSize(1));
|
||||
|
@ -188,12 +233,12 @@ public class CancellableTasksIT extends ESIntegTestCase {
|
|||
|
||||
public void testDoNotWaitForCompletion() throws Exception {
|
||||
Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
|
||||
List<ChildRequest> childRequests = setupChildRequests(nodes);
|
||||
ActionFuture<MainResponse> mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests));
|
||||
for (ChildRequest r : randomSubsetOf(between(1, childRequests.size()), childRequests)) {
|
||||
arrivedLatches.get(r).await();
|
||||
TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3));
|
||||
ActionFuture<TestResponse> mainTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest);
|
||||
TaskId taskId = getRootTaskId(rootRequest);
|
||||
if (randomBoolean()) {
|
||||
allowPartialRequest(rootRequest);
|
||||
}
|
||||
TaskId taskId = getMainTaskId();
|
||||
boolean waitForCompletion = randomBoolean();
|
||||
ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId)
|
||||
.waitForCompletion(waitForCompletion).execute();
|
||||
|
@ -202,40 +247,76 @@ public class CancellableTasksIT extends ESIntegTestCase {
|
|||
} else {
|
||||
assertBusy(() -> assertTrue(cancelFuture.isDone()));
|
||||
}
|
||||
for (ChildRequest r : childRequests) {
|
||||
beforeExecuteLatches.get(r).countDown();
|
||||
}
|
||||
waitForMainTask(mainTaskFuture);
|
||||
allowEntireRequest(rootRequest);
|
||||
waitForRootTask(mainTaskFuture);
|
||||
}
|
||||
|
||||
TaskId getMainTaskId() {
|
||||
public void testFailedToStartChildTaskAfterCancelled() {
|
||||
Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
|
||||
TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3));
|
||||
ActionFuture<TestResponse> rootTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest);
|
||||
TaskId taskId = getRootTaskId(rootRequest);
|
||||
client().admin().cluster().prepareCancelTasks().setTaskId(taskId).waitForCompletion(false).get();
|
||||
DiscoveryNode nodeWithParentTask = nodes.stream().filter(n -> n.getId().equals(taskId.getNodeId())).findFirst().get();
|
||||
TransportTestAction mainAction = internalCluster().getInstance(TransportTestAction.class, nodeWithParentTask.getName());
|
||||
PlainActionFuture<TestResponse> future = new PlainActionFuture<>();
|
||||
TestRequest subRequest = generateTestRequest(nodes, 0, between(0, 1));
|
||||
beforeSendLatches.get(subRequest).countDown();
|
||||
mainAction.startSubTask(taskId, subRequest, future);
|
||||
TransportException te = expectThrows(TransportException.class, future::actionGet);
|
||||
assertThat(te.getCause(), instanceOf(TaskCancelledException.class));
|
||||
assertThat(te.getCause().getMessage(), equalTo("The parent task was cancelled, shouldn't start any child tasks"));
|
||||
allowEntireRequest(rootRequest);
|
||||
waitForRootTask(rootTaskFuture);
|
||||
}
|
||||
|
||||
static TaskId getRootTaskId(TestRequest request) {
|
||||
ListTasksResponse listTasksResponse = client().admin().cluster().prepareListTasks()
|
||||
.setActions(TransportMainAction.ACTION.name()).setDetailed(true).get();
|
||||
assertThat(listTasksResponse.getTasks(), hasSize(1));
|
||||
return listTasksResponse.getTasks().get(0).getTaskId();
|
||||
.setActions(TransportTestAction.ACTION.name()).setDetailed(true).get();
|
||||
List<TaskInfo> tasks = listTasksResponse.getTasks().stream()
|
||||
.filter(t -> t.getDescription().equals(request.taskDescription()))
|
||||
.collect(Collectors.toList());
|
||||
assertThat(tasks, hasSize(1));
|
||||
return tasks.get(0).getTaskId();
|
||||
}
|
||||
|
||||
void waitForMainTask(ActionFuture<MainResponse> mainTask) {
|
||||
static void waitForRootTask(ActionFuture<TestResponse> rootTask) {
|
||||
try {
|
||||
mainTask.actionGet();
|
||||
rootTask.actionGet();
|
||||
} catch (Exception e) {
|
||||
final Throwable cause = ExceptionsHelper.unwrap(e, TaskCancelledException.class);
|
||||
assertThat(cause.getMessage(),
|
||||
either(equalTo("The parent task was cancelled, shouldn't start any child tasks"))
|
||||
.or(containsString("Task cancelled before it started:")));
|
||||
assertThat(cause.getMessage(), anyOf(
|
||||
equalTo("The parent task was cancelled, shouldn't start any child tasks"),
|
||||
containsString("Task cancelled before it started:"),
|
||||
equalTo("Task was cancelled while executing")));
|
||||
}
|
||||
}
|
||||
|
||||
public static class MainRequest extends ActionRequest {
|
||||
final List<ChildRequest> childRequests;
|
||||
static class TestRequest extends ActionRequest {
|
||||
final int id;
|
||||
final DiscoveryNode node;
|
||||
final List<TestRequest> subRequests;
|
||||
|
||||
public MainRequest(List<ChildRequest> childRequests) {
|
||||
this.childRequests = childRequests;
|
||||
TestRequest(int id, DiscoveryNode node, List<TestRequest> subRequests) {
|
||||
this.id = id;
|
||||
this.node = node;
|
||||
this.subRequests = subRequests;
|
||||
}
|
||||
|
||||
public MainRequest(StreamInput in) throws IOException {
|
||||
TestRequest(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
this.childRequests = in.readList(ChildRequest::new);
|
||||
this.id = in.readInt();
|
||||
this.node = new DiscoveryNode(in);
|
||||
this.subRequests = in.readList(TestRequest::new);
|
||||
}
|
||||
|
||||
List<TestRequest> descendants() {
|
||||
List<TestRequest> descendants = new ArrayList<>();
|
||||
for (TestRequest subRequest : subRequests) {
|
||||
descendants.add(subRequest);
|
||||
descendants.addAll(subRequest.descendants());
|
||||
}
|
||||
return descendants;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -243,104 +324,53 @@ public class CancellableTasksIT extends ESIntegTestCase {
|
|||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
out.writeList(childRequests);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
|
||||
return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) {
|
||||
@Override
|
||||
public boolean shouldCancelChildrenOnCancellation() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
public static class MainResponse extends ActionResponse {
|
||||
public MainResponse() {
|
||||
}
|
||||
|
||||
public MainResponse(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
public static class ChildRequest extends ActionRequest {
|
||||
final int id;
|
||||
final DiscoveryNode targetNode;
|
||||
|
||||
public ChildRequest(int id, DiscoveryNode targetNode) {
|
||||
this.id = id;
|
||||
this.targetNode = targetNode;
|
||||
}
|
||||
|
||||
|
||||
public ChildRequest(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
this.id = in.readInt();
|
||||
this.targetNode = new DiscoveryNode(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
out.writeInt(id);
|
||||
targetNode.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ActionRequestValidationException validate() {
|
||||
return null;
|
||||
node.writeTo(out);
|
||||
out.writeList(subRequests);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getDescription() {
|
||||
return "childTask[" + id + "]";
|
||||
return taskDescription();
|
||||
}
|
||||
|
||||
String taskDescription() {
|
||||
return "id=" + id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
|
||||
if (randomBoolean()) {
|
||||
boolean shouldCancelChildrenOnCancellation = randomBoolean();
|
||||
return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) {
|
||||
@Override
|
||||
public boolean shouldCancelChildrenOnCancellation() {
|
||||
return shouldCancelChildrenOnCancellation;
|
||||
}
|
||||
};
|
||||
} else {
|
||||
return super.createTask(id, type, action, parentTaskId, headers);
|
||||
}
|
||||
return new CancellableTask(id, type, action, taskDescription(), parentTaskId, headers) {
|
||||
@Override
|
||||
public boolean shouldCancelChildrenOnCancellation() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ChildRequest that = (ChildRequest) o;
|
||||
return id == that.id && targetNode.equals(that.targetNode);
|
||||
TestRequest that = (TestRequest) o;
|
||||
return id == that.id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(id, targetNode);
|
||||
return Objects.hash(id);
|
||||
}
|
||||
}
|
||||
|
||||
public static class ChildResponse extends ActionResponse {
|
||||
public ChildResponse() {
|
||||
public static class TestResponse extends ActionResponse {
|
||||
public TestResponse() {
|
||||
|
||||
}
|
||||
|
||||
public ChildResponse(StreamInput in) throws IOException {
|
||||
public TestResponse(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
}
|
||||
|
||||
|
@ -350,33 +380,44 @@ public class CancellableTasksIT extends ESIntegTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public static class TransportMainAction extends HandledTransportAction<MainRequest, MainResponse> {
|
||||
public static class TransportTestAction extends HandledTransportAction<TestRequest, TestResponse> {
|
||||
|
||||
public static ActionType<MainResponse> ACTION = new ActionType<>("internal::main_action", MainResponse::new);
|
||||
static AtomicInteger counter = new AtomicInteger();
|
||||
public static ActionType<TestResponse> ACTION = new ActionType<>("internal::test_action", TestResponse::new);
|
||||
private final TransportService transportService;
|
||||
private final NodeClient client;
|
||||
|
||||
@Inject
|
||||
public TransportMainAction(TransportService transportService, NodeClient client, ActionFilters actionFilters) {
|
||||
super(ACTION.name(), transportService, actionFilters, MainRequest::new, ThreadPool.Names.GENERIC);
|
||||
public TransportTestAction(TransportService transportService, NodeClient client, ActionFilters actionFilters) {
|
||||
super(ACTION.name(), transportService, actionFilters, TestRequest::new, ThreadPool.Names.GENERIC);
|
||||
this.transportService = transportService;
|
||||
this.client = client;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doExecute(Task task, MainRequest request, ActionListener<MainResponse> listener) {
|
||||
GroupedActionListener<ChildResponse> groupedListener =
|
||||
new GroupedActionListener<>(ActionListener.map(listener, r -> new MainResponse()), request.childRequests.size());
|
||||
for (ChildRequest childRequest : request.childRequests) {
|
||||
protected void doExecute(Task task, TestRequest request, ActionListener<TestResponse> listener) {
|
||||
arrivedLatches.get(request).countDown();
|
||||
List<TestRequest> subRequests = request.subRequests;
|
||||
GroupedActionListener<TestResponse> groupedListener =
|
||||
new GroupedActionListener<>(ActionListener.map(listener, r -> new TestResponse()), subRequests.size() + 1);
|
||||
transportService.getThreadPool().generic().execute(ActionRunnable.supply(groupedListener, () -> {
|
||||
beforeExecuteLatches.get(request).await();
|
||||
if (((CancellableTask) task).isCancelled()) {
|
||||
throw new TaskCancelledException("Task was cancelled while executing");
|
||||
}
|
||||
counter.incrementAndGet();
|
||||
return new TestResponse();
|
||||
}));
|
||||
for (TestRequest subRequest : subRequests) {
|
||||
TaskId parentTaskId = new TaskId(client.getLocalNodeId(), task.getId());
|
||||
startChildTask(parentTaskId, childRequest, groupedListener);
|
||||
startSubTask(parentTaskId, subRequest, groupedListener);
|
||||
}
|
||||
}
|
||||
|
||||
protected void startChildTask(TaskId parentTaskId, ChildRequest childRequest, ActionListener<ChildResponse> listener) {
|
||||
childRequest.setParentTask(parentTaskId);
|
||||
final CountDownLatch completeLatch = completedLatches.get(childRequest);
|
||||
LatchedActionListener<ChildResponse> latchedListener = new LatchedActionListener<>(listener, completeLatch);
|
||||
protected void startSubTask(TaskId parentTaskId, TestRequest subRequest, ActionListener<TestResponse> listener) {
|
||||
subRequest.setParentTask(parentTaskId);
|
||||
CountDownLatch completeLatch = completedLatches.get(subRequest);
|
||||
LatchedActionListener<TestResponse> latchedListener = new LatchedActionListener<>(listener, completeLatch);
|
||||
transportService.getThreadPool().generic().execute(new AbstractRunnable() {
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
|
@ -384,20 +425,20 @@ public class CancellableTasksIT extends ESIntegTestCase {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected void doRun() {
|
||||
if (client.getLocalNodeId().equals(childRequest.targetNode.getId()) && randomBoolean()) {
|
||||
protected void doRun() throws Exception {
|
||||
beforeSendLatches.get(subRequest).await();
|
||||
if (client.getLocalNodeId().equals(subRequest.node.getId()) && randomBoolean()) {
|
||||
try {
|
||||
client.executeLocally(TransportChildAction.ACTION, childRequest, latchedListener);
|
||||
client.executeLocally(TransportTestAction.ACTION, subRequest, latchedListener);
|
||||
} catch (TaskCancelledException e) {
|
||||
latchedListener.onFailure(new TransportException(e));
|
||||
}
|
||||
} else {
|
||||
transportService.sendRequest(childRequest.targetNode, TransportChildAction.ACTION.name(), childRequest,
|
||||
new TransportResponseHandler<ChildResponse>() {
|
||||
|
||||
transportService.sendRequest(subRequest.node, ACTION.name(), subRequest,
|
||||
new TransportResponseHandler<TestResponse>() {
|
||||
@Override
|
||||
public void handleResponse(ChildResponse response) {
|
||||
latchedListener.onResponse(new ChildResponse());
|
||||
public void handleResponse(TestResponse response) {
|
||||
latchedListener.onResponse(response);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -411,8 +452,8 @@ public class CancellableTasksIT extends ESIntegTestCase {
|
|||
}
|
||||
|
||||
@Override
|
||||
public ChildResponse read(StreamInput in) throws IOException {
|
||||
return new ChildResponse(in);
|
||||
public TestResponse read(StreamInput in) throws IOException {
|
||||
return new TestResponse(in);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -421,40 +462,16 @@ public class CancellableTasksIT extends ESIntegTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public static class TransportChildAction extends HandledTransportAction<ChildRequest, ChildResponse> {
|
||||
public static ActionType<ChildResponse> ACTION = new ActionType<>("internal:child_action", ChildResponse::new);
|
||||
private final TransportService transportService;
|
||||
|
||||
|
||||
@Inject
|
||||
public TransportChildAction(TransportService transportService, ActionFilters actionFilters) {
|
||||
super(ACTION.name(), transportService, actionFilters, ChildRequest::new, ThreadPool.Names.GENERIC);
|
||||
this.transportService = transportService;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doExecute(Task task, ChildRequest request, ActionListener<ChildResponse> listener) {
|
||||
assertThat(request.targetNode, equalTo(transportService.getLocalNode()));
|
||||
arrivedLatches.get(request).countDown();
|
||||
transportService.getThreadPool().executor(ThreadPool.Names.GENERIC).execute(ActionRunnable.supply(listener, () -> {
|
||||
beforeExecuteLatches.get(request).await();
|
||||
return new ChildResponse();
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
public static class TaskPlugin extends Plugin implements ActionPlugin {
|
||||
@Override
|
||||
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
|
||||
return Arrays.asList(
|
||||
new ActionHandler<>(TransportMainAction.ACTION, TransportMainAction.class),
|
||||
new ActionHandler<>(TransportChildAction.ACTION, TransportChildAction.class)
|
||||
);
|
||||
return Collections.singletonList(new ActionHandler<>(TransportTestAction.ACTION, TransportTestAction.class));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ActionType<? extends ActionResponse>> getClientActions() {
|
||||
return Arrays.asList(TransportMainAction.ACTION, TransportChildAction.ACTION);
|
||||
return Collections.singletonList(TransportTestAction.ACTION);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -471,16 +488,4 @@ public class CancellableTasksIT extends ESIntegTestCase {
|
|||
plugins.add(TaskPlugin.class);
|
||||
return plugins;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures that all outstanding child tasks of the given parent task are banned or being cancelled.
|
||||
*/
|
||||
protected static void ensureChildTasksCancelledOrBanned(TaskId taskId) throws Exception {
|
||||
assertBusy(() -> {
|
||||
for (String nodeName : internalCluster().getNodeNames()) {
|
||||
final TaskManager taskManager = internalCluster().getInstance(TransportService.class, nodeName).getTaskManager();
|
||||
assertTrue(taskManager.childTasksCancelledOrBanned(taskId));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue