Broadcast cancellation only to nodes that have outstanding child tasks (#54312)

Today when canceling a task we broadcast ban/unban requests to all nodes
in the cluster. This strategy does not scale well for hierarchical
cancellation. With this change, we will track outstanding child requests
and broadcast the cancellation to only nodes that have outstanding child
tasks. This change also prevents a parent task from sending child
requests once it has been canceled.

Relates #50990
Supersedes #51157

Co-authored-by: Igor Motov <igor@motovs.org>
Co-authored-by: Yannick Welsch <yannick@welsch.lu>
This commit is contained in:
Nhat Nguyen 2020-04-01 11:22:13 -04:00
parent 7dc1ba4273
commit 2fdbed7797
17 changed files with 942 additions and 223 deletions

View File

@ -39,6 +39,9 @@ final class TasksRequestConverters {
params
.withNodes(req.getNodes())
.withActions(req.getActions());
if (req.getWaitForCompletion() != null) {
params.withWaitForCompletion(req.getWaitForCompletion());
}
request.addParameters(params.asMap());
return request;
}

View File

@ -33,6 +33,7 @@ public class CancelTasksRequest implements Validatable {
private Optional<TimeValue> timeout = Optional.empty();
private Optional<TaskId> parentTaskId = Optional.empty();
private Optional<TaskId> taskId = Optional.empty();
private Boolean waitForCompletion;
CancelTasksRequest(){}
@ -76,6 +77,14 @@ public class CancelTasksRequest implements Validatable {
return taskId;
}
/**
 * Returns the {@code wait_for_completion} flag of this request.
 *
 * @return the configured value, or {@code null} when the flag was never set
 */
public Boolean getWaitForCompletion() {
    return this.waitForCompletion;
}
/**
 * Sets the {@code wait_for_completion} flag of this request.
 *
 * @param waitForCompletion the value to store; it is boxed into the nullable field
 */
public void setWaitForCompletion(boolean waitForCompletion) {
    this.waitForCompletion = Boolean.valueOf(waitForCompletion);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
@ -85,12 +94,13 @@ public class CancelTasksRequest implements Validatable {
Objects.equals(getActions(), that.getActions()) &&
Objects.equals(getTimeout(), that.getTimeout()) &&
Objects.equals(getParentTaskId(), that.getParentTaskId()) &&
Objects.equals(getTaskId(), that.getTaskId()) ;
Objects.equals(getTaskId(), that.getTaskId()) &&
Objects.equals(waitForCompletion, that.waitForCompletion);
}
@Override
public int hashCode() {
return Objects.hash(getNodes(), getActions(), getTimeout(), getParentTaskId(), getTaskId());
return Objects.hash(getNodes(), getActions(), getTimeout(), getParentTaskId(), getTaskId(), waitForCompletion);
}
@Override
@ -101,6 +111,7 @@ public class CancelTasksRequest implements Validatable {
", timeout=" + timeout +
", parentTaskId=" + parentTaskId +
", taskId=" + taskId +
", waitForCompletion=" + waitForCompletion +
'}';
}
@ -110,6 +121,7 @@ public class CancelTasksRequest implements Validatable {
private Optional<TaskId> parentTaskId = Optional.empty();
private List<String> actionsFilter = new ArrayList<>();
private List<String> nodesFilter = new ArrayList<>();
private Boolean waitForCompletion;
public Builder withTimeout(TimeValue timeout){
this.timeout = Optional.of(timeout);
@ -138,6 +150,11 @@ public class CancelTasksRequest implements Validatable {
return this;
}
/**
 * Records the {@code wait_for_completion} flag to apply to the request being built.
 *
 * @param waitForCompletion the value to store in this builder
 * @return this builder, to allow call chaining
 */
public Builder withWaitForCompletion(boolean waitForCompletion) {
    this.waitForCompletion = Boolean.valueOf(waitForCompletion);
    return this;
}
public CancelTasksRequest build() {
CancelTasksRequest request = new CancelTasksRequest();
timeout.ifPresent(request::setTimeout);
@ -145,6 +162,9 @@ public class CancelTasksRequest implements Validatable {
parentTaskId.ifPresent(request::setParentTaskId);
request.setNodes(nodesFilter);
request.setActions(actionsFilter);
if (waitForCompletion != null) {
request.setWaitForCompletion(waitForCompletion);
}
return request;
}
}

View File

@ -22,6 +22,7 @@ package org.elasticsearch.client;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksRequest;
import org.elasticsearch.client.tasks.CancelTasksRequest;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
@ -40,14 +41,15 @@ public class TasksRequestConvertersTests extends ESTestCase {
new org.elasticsearch.client.tasks.TaskId(randomAlphaOfLength(5), randomNonNegativeLong());
org.elasticsearch.client.tasks.TaskId parentTaskId =
new org.elasticsearch.client.tasks.TaskId(randomAlphaOfLength(5), randomNonNegativeLong());
org.elasticsearch.client.tasks.CancelTasksRequest request =
new org.elasticsearch.client.tasks.CancelTasksRequest.Builder()
.withTaskId(taskId)
.withParentTaskId(parentTaskId)
.build();
CancelTasksRequest.Builder builder = new CancelTasksRequest.Builder().withTaskId(taskId).withParentTaskId(parentTaskId);
expectedParams.put("task_id", taskId.toString());
expectedParams.put("parent_task_id", parentTaskId.toString());
Request httpRequest = TasksRequestConverters.cancelTasks(request);
if (randomBoolean()) {
boolean waitForCompletion = randomBoolean();
builder.withWaitForCompletion(waitForCompletion);
expectedParams.put("wait_for_completion", Boolean.toString(waitForCompletion));
}
Request httpRequest = TasksRequestConverters.cancelTasks(builder.build());
assertThat(httpRequest, notNullValue());
assertThat(httpRequest.getMethod(), equalTo(HttpPost.METHOD_NAME));
assertThat(httpRequest.getEntity(), nullValue());

View File

@ -166,7 +166,8 @@ public class TasksClientDocumentationIT extends ESRestHighLevelClientTestCase {
// tag::cancel-tasks-request-filter
CancelTasksRequest byTaskIdRequest = new org.elasticsearch.client.tasks.CancelTasksRequest.Builder() // <1>
.withTaskId(new org.elasticsearch.client.tasks.TaskId("myNode",44L)) // <2>
.build(); // <3>
.withWaitForCompletion(true) // <3>
.build(); // <4>
// end::cancel-tasks-request-filter
}

View File

@ -22,7 +22,9 @@ include-tagged::{doc-tests}/TasksClientDocumentationIT.java[cancel-tasks-request
--------------------------------------------------
<1> Cancel a task
<2> Cancel only cluster-related tasks
<3> Cancel all tasks running on nodes nodeId1 and nodeId2
<3> Should the request block until the cancellation of the task and its child tasks is completed.
Otherwise, the request can return soon after the cancellation is started. Defaults to `false`.
<4> Cancel all tasks running on nodes nodeId1 and nodeId2
==== Synchronous Execution

View File

@ -232,6 +232,11 @@ list tasks command, so multiple tasks can be cancelled at the same time. For
example, the following command will cancel all reindex tasks running on the
nodes `nodeId1` and `nodeId2`.
`wait_for_completion`::
(Optional, boolean) If `true`, the request blocks until the cancellation of the
task and its child tasks is completed. Otherwise, the request can return soon
after the cancellation is started. Defaults to `false`.
[source,console]
--------------------------------------------------
POST _tasks/_cancel?nodes=nodeId1,nodeId2&actions=*reindex

View File

@ -39,6 +39,10 @@
"parent_task_id":{
"type":"string",
"description":"Cancel tasks with specified parent task id (node_id:task_number). Set to -1 to cancel all."
},
"wait_for_completion": {
"type":"boolean",
"description":"Should the request block until the cancellation of the task and its child tasks is completed. Defaults to false"
}
}
}

View File

@ -19,6 +19,7 @@
package org.elasticsearch.action.admin.cluster.node.tasks.cancel;
import org.elasticsearch.Version;
import org.elasticsearch.action.support.tasks.BaseTasksRequest;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@ -33,20 +34,28 @@ import java.io.IOException;
public class CancelTasksRequest extends BaseTasksRequest<CancelTasksRequest> {
public static final String DEFAULT_REASON = "by user request";
public static final boolean DEFAULT_WAIT_FOR_COMPLETION = false;
private String reason = DEFAULT_REASON;
private boolean waitForCompletion = DEFAULT_WAIT_FOR_COMPLETION;
public CancelTasksRequest() {}
public CancelTasksRequest(StreamInput in) throws IOException {
super(in);
this.reason = in.readString();
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
waitForCompletion = in.readBoolean();
}
}
/**
 * Writes this request to the given stream.
 * <p>
 * The {@code waitForCompletion} flag is only written when the stream's version is 7.8.0 or later,
 * mirroring the version check in the stream constructor.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
    super.writeTo(out);
    out.writeString(this.reason);
    final boolean streamCarriesWaitFlag = out.getVersion().onOrAfter(Version.V_7_8_0);
    if (streamCarriesWaitFlag) {
        out.writeBoolean(this.waitForCompletion);
    }
}
@Override
@ -68,4 +77,16 @@ public class CancelTasksRequest extends BaseTasksRequest<CancelTasksRequest> {
/**
 * Returns the reason attached to this cancellation request.
 * The field is initialised to {@link #DEFAULT_REASON}.
 */
public String getReason() {
    return this.reason;
}
/**
 * Controls whether the request waits for the cancellation to finish.
 *
 * @param waitForCompletion if {@code true}, the request blocks until the cancellation of the task and
 *                          its child tasks is completed; otherwise the request can return soon after
 *                          the cancellation is started. Defaults to {@code false}.
 */
public void setWaitForCompletion(boolean waitForCompletion) {
    final boolean requested = waitForCompletion;
    this.waitForCompletion = requested;
}
/**
 * Returns whether this request should wait for the cancellation to complete.
 * The field is initialised to {@link #DEFAULT_WAIT_FOR_COMPLETION}.
 */
public boolean waitForCompletion() {
    return this.waitForCompletion;
}
}

View File

@ -31,4 +31,8 @@ public class CancelTasksRequestBuilder extends TasksRequestBuilder<CancelTasksRe
super(client, action, new CancelTasksRequest());
}
/**
 * Sets the {@code waitForCompletion} flag on the request held by this builder.
 *
 * @param waitForCompletion the value passed through to {@code setWaitForCompletion} on the request
 * @return this builder, to allow call chaining
 */
public CancelTasksRequestBuilder waitForCompletion(boolean waitForCompletion) {
request.setWaitForCompletion(waitForCompletion);
return this;
}
}

View File

@ -19,15 +19,15 @@
package org.elasticsearch.action.admin.cluster.node.tasks.cancel;
import com.carrotsearch.hppc.cursors.ObjectObjectCursor;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.StepListener;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.StreamInput;
@ -46,9 +46,8 @@ import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportService;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
/**
@ -90,7 +89,7 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
// The task exists, but doesn't support cancellation
throw new IllegalArgumentException("task [" + request.getTaskId() + "] doesn't support cancellation");
} else {
throw new ResourceNotFoundException("task [{}] doesn't support cancellation", request.getTaskId());
throw new ResourceNotFoundException("task [{}] is not found", request.getTaskId());
}
}
} else {
@ -103,129 +102,71 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
}
@Override
protected synchronized void taskOperation(CancelTasksRequest request, CancellableTask cancellableTask,
ActionListener<TaskInfo> listener) {
protected void taskOperation(CancelTasksRequest request, CancellableTask cancellableTask, ActionListener<TaskInfo> listener) {
String nodeId = clusterService.localNode().getId();
final boolean canceled;
if (cancellableTask.shouldCancelChildrenOnCancellation()) {
DiscoveryNodes childNodes = clusterService.state().nodes();
final BanLock banLock = new BanLock(childNodes.getSize(), () -> removeBanOnNodes(cancellableTask, childNodes));
canceled = taskManager.cancel(cancellableTask, request.getReason(), banLock::onTaskFinished);
if (canceled) {
// /In case the task has some child tasks, we need to wait for until ban is set on all nodes
logger.trace("cancelling task {} on child nodes", cancellableTask.getId());
AtomicInteger responses = new AtomicInteger(childNodes.getSize());
List<Exception> failures = new ArrayList<>();
setBanOnNodes(request.getReason(), cancellableTask, childNodes, new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
processResponse();
}
StepListener<Void> completedListener = new StepListener<>();
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(ActionListener.map(completedListener, r -> null), 3);
Collection<DiscoveryNode> childrenNodes =
taskManager.startBanOnChildrenNodes(cancellableTask.getId(), () -> groupedListener.onResponse(null));
taskManager.cancel(cancellableTask, request.getReason(), () -> groupedListener.onResponse(null));
@Override
public void onFailure(Exception e) {
synchronized (failures) {
failures.add(e);
}
processResponse();
}
private void processResponse() {
banLock.onBanSet();
if (responses.decrementAndGet() == 0) {
if (failures.isEmpty() == false) {
IllegalStateException exception = new IllegalStateException("failed to cancel children of the task [" +
cancellableTask.getId() + "]");
failures.forEach(exception::addSuppressed);
listener.onFailure(exception);
} else {
listener.onResponse(cancellableTask.taskInfo(nodeId, false));
}
}
}
});
StepListener<Void> banOnNodesListener = new StepListener<>();
setBanOnNodes(request.getReason(), cancellableTask, childrenNodes, banOnNodesListener);
banOnNodesListener.whenComplete(groupedListener::onResponse, groupedListener::onFailure);
// We remove bans after all child tasks are completed although in theory we can do it on a per-node basis.
completedListener.whenComplete(
r -> removeBanOnNodes(cancellableTask, childrenNodes),
e -> removeBanOnNodes(cancellableTask, childrenNodes));
// if wait_for_completion is true, then only return when (1) bans are placed on child nodes, (2) child tasks are
// completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child nodes.
if (request.waitForCompletion()) {
completedListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure);
} else {
banOnNodesListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure);
}
} else {
canceled = taskManager.cancel(cancellableTask, request.getReason(),
() -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)));
if (canceled) {
logger.trace("task {} doesn't have any children that should be cancelled", cancellableTask.getId());
}
}
if (canceled == false) {
logger.trace("task {} is already cancelled", cancellableTask.getId());
throw new IllegalStateException("task with id " + cancellableTask.getId() + " is already cancelled");
} else {
logger.trace("task {} doesn't have any children that should be cancelled", cancellableTask.getId());
taskManager.cancel(cancellableTask, request.getReason(), () -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)));
}
}
private void setBanOnNodes(String reason, CancellableTask task, DiscoveryNodes nodes, ActionListener<Void> listener) {
sendSetBanRequest(nodes,
BanParentTaskRequest.createSetBanParentTaskRequest(new TaskId(clusterService.localNode().getId(), task.getId()), reason),
listener);
}
private void removeBanOnNodes(CancellableTask task, DiscoveryNodes nodes) {
sendRemoveBanRequest(nodes,
BanParentTaskRequest.createRemoveBanParentTaskRequest(new TaskId(clusterService.localNode().getId(), task.getId())));
}
private void sendSetBanRequest(DiscoveryNodes nodes, BanParentTaskRequest request, ActionListener<Void> listener) {
for (ObjectObjectCursor<String, DiscoveryNode> node : nodes.getNodes()) {
logger.trace("Sending ban for tasks with the parent [{}] to the node [{}], ban [{}]", request.parentTaskId, node.key,
request.ban);
transportService.sendRequest(node.value, BAN_PARENT_ACTION_NAME, request,
private void setBanOnNodes(String reason, CancellableTask task, Collection<DiscoveryNode> childNodes, ActionListener<Void> listener) {
if (childNodes.isEmpty()) {
listener.onResponse(null);
return;
}
logger.trace("cancelling task {} on child nodes {}", task.getId(), childNodes);
GroupedActionListener<Void> groupedListener =
new GroupedActionListener<>(ActionListener.map(listener, r -> null), childNodes.size());
final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(
new TaskId(clusterService.localNode().getId(), task.getId()), reason);
for (DiscoveryNode node : childNodes) {
transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, banRequest,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override
public void handleResponse(TransportResponse.Empty response) {
listener.onResponse(null);
groupedListener.onResponse(null);
}
@Override
public void handleException(TransportException exp) {
logger.warn("Cannot send ban for tasks with the parent [{}] to the node [{}]", request.parentTaskId, node.key);
listener.onFailure(exp);
logger.warn("Cannot send ban for tasks with the parent [{}] to the node [{}]", banRequest.parentTaskId, node);
groupedListener.onFailure(exp);
}
});
}
}
private void sendRemoveBanRequest(DiscoveryNodes nodes, BanParentTaskRequest request) {
for (ObjectObjectCursor<String, DiscoveryNode> node : nodes.getNodes()) {
logger.debug("Sending remove ban for tasks with the parent [{}] to the node [{}]", request.parentTaskId, node.key);
transportService.sendRequest(node.value, BAN_PARENT_ACTION_NAME, request, EmptyTransportResponseHandler
.INSTANCE_SAME);
private void removeBanOnNodes(CancellableTask task, Collection<DiscoveryNode> childNodes) {
final BanParentTaskRequest request =
BanParentTaskRequest.createRemoveBanParentTaskRequest(new TaskId(clusterService.localNode().getId(), task.getId()));
for (DiscoveryNode node : childNodes) {
logger.trace("Sending remove ban for tasks with the parent [{}] to the node [{}]", request.parentTaskId, node);
transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, request, EmptyTransportResponseHandler.INSTANCE_SAME);
}
}
private static class BanLock {
private final Runnable finish;
private final AtomicInteger counter;
private final int nodesSize;
BanLock(int nodesSize, Runnable finish) {
counter = new AtomicInteger(0);
this.finish = finish;
this.nodesSize = nodesSize;
}
public void onBanSet() {
if (counter.decrementAndGet() == 0) {
finish();
}
}
public void onTaskFinished() {
if (counter.addAndGet(nodesSize) == 0) {
finish();
}
}
public void finish() {
finish.run();
}
}
private static class BanParentTaskRequest extends TransportRequest {
private final TaskId parentTaskId;

View File

@ -24,7 +24,10 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskListener;
import org.elasticsearch.tasks.TaskManager;
@ -47,6 +50,14 @@ public abstract class TransportAction<Request extends ActionRequest, Response ex
this.taskManager = taskManager;
}
/**
 * Registers the local node with the task manager as a node running a child of {@code parentTask}.
 *
 * @param parentTask the parent task id of the request being executed; may be unset
 * @return the {@link Releasable} handed back by the task manager, or a no-op releasable when
 *         no parent task is set
 */
private Releasable registerChildNode(TaskId parentTask) {
    if (parentTask.isSet() == false) {
        // Not a child request: nothing to track and nothing to release later.
        return () -> {};
    }
    final long parentId = parentTask.getId();
    return taskManager.registerChildNode(parentId, taskManager.localNode());
}
/**
* Use this method when the transport action call should result in creation of a new task associated with the call.
*
@ -60,12 +71,14 @@ public abstract class TransportAction<Request extends ActionRequest, Response ex
* task. That just seems like too many objects. Thus the two versions of
* this method.
*/
final Releasable unregisterChildNode = registerChildNode(request.getParentTask());
Task task = taskManager.register("transport", actionName, request);
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
taskManager.unregister(task);
unregisterChildNode.close();
} finally {
listener.onResponse(response);
}
@ -74,7 +87,7 @@ public abstract class TransportAction<Request extends ActionRequest, Response ex
@Override
public void onFailure(Exception e) {
try {
taskManager.unregister(task);
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(e);
}
@ -88,12 +101,14 @@ public abstract class TransportAction<Request extends ActionRequest, Response ex
* {@link TaskListener} which listens for the completion of the action.
*/
public final Task execute(Request request, TaskListener<Response> listener) {
final Releasable unregisterChildNode = registerChildNode(request.getParentTask());
Task task = taskManager.register("transport", actionName, request);
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
taskManager.unregister(task);
unregisterChildNode.close();
} finally {
listener.onResponse(task, response);
}
@ -102,7 +117,7 @@ public abstract class TransportAction<Request extends ActionRequest, Response ex
@Override
public void onFailure(Exception e) {
try {
taskManager.unregister(task);
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(task, e);
}

View File

@ -68,6 +68,7 @@ public class RestCancelTasksAction extends BaseRestHandler {
cancelTasksRequest.setNodes(nodesIds);
cancelTasksRequest.setActions(actions);
cancelTasksRequest.setParentTaskId(parentTaskId);
cancelTasksRequest.setWaitForCompletion(request.paramAsBoolean("wait_for_completion", cancelTasksRequest.waitForCompletion()));
return channel ->
client.admin().cluster().cancelTasks(cancelTasksRequest, listTasksResponseListener(nodesInCluster, groupBy, channel));
}

View File

@ -22,14 +22,15 @@ package org.elasticsearch.tasks;
import org.elasticsearch.common.Nullable;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* A task that can be canceled
*/
public abstract class CancellableTask extends Task {
private final AtomicReference<String> reason = new AtomicReference<>();
private volatile String reason;
private final AtomicBoolean cancelled = new AtomicBoolean(false);
public CancellableTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers) {
super(id, type, action, description, parentTaskId, headers);
@ -40,8 +41,10 @@ public abstract class CancellableTask extends Task {
*/
final void cancel(String reason) {
assert reason != null;
this.reason.compareAndSet(null, reason);
onCancelled();
if (cancelled.compareAndSet(false, true)) {
this.reason = reason;
onCancelled();
}
}
/**
@ -58,15 +61,15 @@ public abstract class CancellableTask extends Task {
public abstract boolean shouldCancelChildrenOnCancellation();
public boolean isCancelled() {
return reason.get() != null;
return cancelled.get();
}
/**
* The reason the task was cancelled or null if it hasn't been cancelled.
*/
@Nullable
public String getReasonCancelled() {
return reason.get();
public final String getReasonCancelled() {
return reason;
}
/**

View File

@ -19,6 +19,8 @@
package org.elasticsearch.tasks;
import com.carrotsearch.hppc.ObjectIntHashMap;
import com.carrotsearch.hppc.ObjectIntMap;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
@ -31,6 +33,8 @@ import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterStateApplier;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
@ -41,6 +45,7 @@ import org.elasticsearch.threadpool.ThreadPool;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
@ -50,6 +55,8 @@ import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import static org.elasticsearch.common.unit.TimeValue.timeValueMillis;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEADER_SIZE;
@ -78,7 +85,7 @@ public class TaskManager implements ClusterStateApplier {
private TaskResultsService taskResultsService;
private DiscoveryNodes lastDiscoveryNodes = DiscoveryNodes.EMPTY_NODES;
private volatile DiscoveryNodes lastDiscoveryNodes = DiscoveryNodes.EMPTY_NODES;
private final ByteSizeValue maxHeaderSize;
@ -132,13 +139,14 @@ public class TaskManager implements ClusterStateApplier {
CancellableTaskHolder holder = new CancellableTaskHolder(cancellableTask);
CancellableTaskHolder oldHolder = cancellableTasks.put(task.getId(), holder);
assert oldHolder == null;
// Check if this task was banned before we start it
// Check if this task was banned before we start it. The empty check is used to avoid
// computing the hash code of the parent taskId as most of the time banedParents is empty.
if (task.getParentTaskId().isSet() && banedParents.isEmpty() == false) {
String reason = banedParents.get(task.getParentTaskId());
if (reason != null) {
try {
holder.cancel(reason);
throw new IllegalStateException("Task cancelled before it started: " + reason);
throw new TaskCancelledException("Task cancelled before it started: " + reason);
} finally {
// let's clean up the registration
unregister(task);
@ -150,18 +158,18 @@ public class TaskManager implements ClusterStateApplier {
/**
* Cancels a task
* <p>
* Returns true if cancellation was started successful, null otherwise.
*
* After starting cancellation on the parent task, the task manager tries to cancel all children tasks
* of the current task. Once cancellation of the children tasks is done, the listener is triggered.
* If the task is completed or unregistered from TaskManager, then the listener is called immediately.
*/
public boolean cancel(CancellableTask task, String reason, Runnable listener) {
public void cancel(CancellableTask task, String reason, Runnable listener) {
CancellableTaskHolder holder = cancellableTasks.get(task.getId());
if (holder != null) {
logger.trace("cancelling task with id {}", task.getId());
return holder.cancel(reason, listener);
holder.cancel(reason, listener);
} else {
listener.run();
}
return false;
}
/**
@ -182,6 +190,23 @@ public class TaskManager implements ClusterStateApplier {
}
}
/**
 * Registers a node on which a child task of the given task is going to execute.
 * <p>
 * The returned {@link Releasable} must be closed once the child task has completed or failed, so that
 * the child node is unregistered again. The releasable is wrapped with
 * {@link Releasables#releaseOnce}. If no cancellable task is registered under {@code taskId},
 * nothing is tracked and a no-op releasable is returned.
 *
 * @param taskId the id of the parent task
 * @param node   the node on which the child task will run
 * @return a releasable that unregisters the child node
 */
public Releasable registerChildNode(long taskId, DiscoveryNode node) {
    final CancellableTaskHolder parentHolder = cancellableTasks.get(taskId);
    if (parentHolder == null) {
        return () -> {};
    }
    parentHolder.registerChildNode(node);
    return Releasables.releaseOnce(() -> parentHolder.unregisterChildNode(node));
}
/**
 * Returns the local node according to the most recently stored {@link DiscoveryNodes}.
 */
public DiscoveryNode localNode() {
    final DiscoveryNodes currentNodes = lastDiscoveryNodes;
    return currentNodes.getLocalNode();
}
/**
* Stores the task failure
*/
@ -339,6 +364,31 @@ public class TaskManager implements ClusterStateApplier {
banedParents.remove(parentTaskId);
}
/**
 * For testing: returns {@code true} if the given parent task id is currently banned, or if no
 * registered cancellable task has that id as its parent.
 */
public boolean childTasksCancelledOrBanned(TaskId parentTaskId) {
    return banedParents.containsKey(parentTaskId)
        || cancellableTasks.values().stream().noneMatch(holder -> holder.hasParent(parentTaskId));
}
/**
 * Starts rejecting new child requests of a parent task that was cancelled.
 * <p>
 * If no cancellable task is registered under {@code taskId}, the callback is run immediately and an
 * empty set is returned.
 *
 * @param taskId                the parent task id
 * @param onChildTasksCompleted called when all child tasks are completed or failed
 * @return the set of current nodes that have outstanding child tasks
 */
public Collection<DiscoveryNode> startBanOnChildrenNodes(long taskId, Runnable onChildTasksCompleted) {
    final CancellableTaskHolder parentHolder = cancellableTasks.get(taskId);
    if (parentHolder == null) {
        onChildTasksCompleted.run();
        return Collections.emptySet();
    }
    return parentHolder.startBan(onChildTasksCompleted);
}
@Override
public void applyClusterState(ClusterChangedEvent event) {
lastDiscoveryNodes = event.state().getNodes();
@ -388,74 +438,76 @@ public class TaskManager implements ClusterStateApplier {
}
private static class CancellableTaskHolder {
private static final String TASK_FINISHED_MARKER = "task finished";
private final CancellableTask task;
private volatile String cancellationReason = null;
private volatile Runnable cancellationListener = null;
private boolean finished = false;
private List<Runnable> cancellationListeners = null;
private ObjectIntMap<DiscoveryNode> childTasksPerNode = null;
private boolean banChildren = false;
private List<Runnable> childTaskCompletedListeners = null;
CancellableTaskHolder(CancellableTask task) {
this.task = task;
}
/**
* Marks task as cancelled.
* <p>
* Returns true if cancellation was successful, false otherwise.
*/
public boolean cancel(String reason, Runnable listener) {
final boolean cancelled;
void cancel(String reason, Runnable listener) {
final Runnable toRun;
synchronized (this) {
assert reason != null;
if (cancellationReason == null) {
cancellationReason = reason;
cancellationListener = listener;
cancelled = true;
if (finished) {
assert cancellationListeners == null;
toRun = listener;
} else {
// Already cancelled by somebody else
cancelled = false;
toRun = () -> {};
if (listener != null) {
if (cancellationListeners == null) {
cancellationListeners = new ArrayList<>();
}
cancellationListeners.add(listener);
}
}
}
if (cancelled) {
try {
task.cancel(reason);
} finally {
if (toRun != null) {
toRun.run();
}
}
return cancelled;
}
/**
* Marks task as cancelled.
* <p>
* Returns true if cancellation was successful, false otherwise.
*/
public boolean cancel(String reason) {
return cancel(reason, null);
void cancel(String reason) {
task.cancel(reason);
}
/**
* Marks task as finished.
*/
public void finish() {
Runnable listener = null;
final List<Runnable> listeners;
synchronized (this) {
if (cancellationReason != null) {
// The task was cancelled, we need to notify the listener
if (cancellationListener != null) {
listener = cancellationListener;
cancellationListener = null;
}
this.finished = true;
if (cancellationListeners != null) {
listeners = cancellationListeners;
cancellationListeners = null;
} else {
cancellationReason = TASK_FINISHED_MARKER;
listeners = Collections.emptyList();
}
}
// We need to call the listener outside of the synchronised section to avoid potential bottle necks
// in the listener synchronization
if (listener != null) {
listener.run();
}
notifyListeners(listeners);
}
/**
 * Runs every listener in order, even when some of them fail.
 * <p>
 * Runtime exceptions thrown by listeners are collected through {@code ExceptionsHelper.useOrSuppress}
 * and handed to {@code ExceptionsHelper.reThrowIfNotNull} after the last listener has run.
 * Must not be called while holding this holder's monitor.
 *
 * @param listeners the callbacks to run
 */
private void notifyListeners(List<Runnable> listeners) {
    assert Thread.holdsLock(this) == false;
    Exception collected = null;
    for (final Runnable callback : listeners) {
        try {
            callback.run();
        } catch (RuntimeException failure) {
            // Keep going so that one failing listener cannot starve the remaining ones.
            collected = ExceptionsHelper.useOrSuppress(collected, failure);
        }
    }
    ExceptionsHelper.reThrowIfNotNull(collected);
}
public boolean hasParent(TaskId parentTaskId) {
@ -465,6 +517,58 @@ public class TaskManager implements ClusterStateApplier {
public CancellableTask getTask() {
return task;
}
/**
 * Counts one more outstanding child task on the given node.
 * The per-node counter map is created lazily on first use.
 *
 * @param node the node on which a child task is about to run
 * @throws TaskCancelledException if {@code banChildren} is already set for this task
 */
synchronized void registerChildNode(DiscoveryNode node) {
    if (this.banChildren) {
        throw new TaskCancelledException("The parent task was cancelled, shouldn't start any child tasks");
    }
    if (this.childTasksPerNode == null) {
        this.childTasksPerNode = new ObjectIntHashMap<>();
    }
    this.childTasksPerNode.addTo(node, 1);
}
/**
 * Records that one child request for the given node is no longer outstanding.
 * <p>
 * The node's counter is decremented and its entry is removed once the counter reaches zero. When no
 * entries remain and child-completion listeners are registered, those listeners are detached under the
 * lock and run after the lock has been released.
 *
 * @param node the node whose outstanding child count should be decremented
 */
void unregisterChildNode(DiscoveryNode node) {
    List<Runnable> toNotify = Collections.emptyList();
    synchronized (this) {
        final int remaining = childTasksPerNode.addTo(node, -1);
        if (remaining == 0) {
            childTasksPerNode.remove(node);
        }
        if (childTaskCompletedListeners != null && childTasksPerNode.isEmpty()) {
            toNotify = childTaskCompletedListeners;
            childTaskCompletedListeners = null;
        }
    }
    // listeners are invoked outside of the synchronized section
    notifyListeners(toNotify);
}
/**
 * Sets {@code banChildren} and returns the nodes that currently have outstanding child requests.
 * <p>
 * If there are no such nodes, {@code onChildTasksCompleted} is run right away (after the lock has been
 * released); otherwise it is queued in {@code childTaskCompletedListeners}.
 *
 * @param onChildTasksCompleted callback to run once no child requests are outstanding
 * @return the set of nodes with outstanding child requests at the time of the call, possibly empty
 */
Set<DiscoveryNode> startBan(Runnable onChildTasksCompleted) {
    final Set<DiscoveryNode> nodesWithChildren;
    final boolean nothingOutstanding;
    synchronized (this) {
        banChildren = true;
        nodesWithChildren = childTasksPerNode == null
            ? Collections.<DiscoveryNode>emptySet()
            : StreamSupport.stream(childTasksPerNode.spliterator(), false).map(e -> e.key).collect(Collectors.toSet());
        nothingOutstanding = nodesWithChildren.isEmpty();
        if (nothingOutstanding) {
            assert childTaskCompletedListeners == null;
        } else {
            if (childTaskCompletedListeners == null) {
                childTaskCompletedListeners = new ArrayList<>();
            }
            childTaskCompletedListeners.add(onChildTasksCompleted);
        }
    }
    if (nothingOutstanding) {
        // run the callback only after leaving the synchronized section
        onChildTasksCompleted.run();
    }
    return nodesWithChildren;
}
}
}

View File

@ -36,6 +36,7 @@ import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.regex.Regex;
import org.elasticsearch.common.settings.ClusterSettings;
@ -618,6 +619,34 @@ public class TransportService extends AbstractLifecycleComponent implements Tran
final TransportRequestOptions options,
TransportResponseHandler<T> handler) {
try {
if (request.getParentTask().isSet()) {
// TODO: capture the connection instead so that we can cancel child tasks on the remote connections.
final Releasable unregisterChildNode = taskManager.registerChildNode(request.getParentTask().getId(), connection.getNode());
final TransportResponseHandler<T> delegate = handler;
handler = new TransportResponseHandler<T>() {
@Override
public void handleResponse(T response) {
unregisterChildNode.close();
delegate.handleResponse(response);
}
@Override
public void handleException(TransportException exp) {
unregisterChildNode.close();
delegate.handleException(exp);
}
@Override
public String executor() {
return delegate.executor();
}
@Override
public T read(StreamInput in) throws IOException {
return delegate.read(in);
}
};
}
asyncSender.sendRequest(connection, action, request, options, handler);
} catch (final Exception ex) {
// the caller might not handle this so we invoke the handler

View File

@ -0,0 +1,486 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.action.admin.cluster.node.tasks;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.either;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
public class CancellableTasksIT extends ESIntegTestCase {
// Per-request latches shared between the test methods and the test transport actions declared in this class.
// Counted down by TransportChildAction#doExecute as soon as it is invoked for the request.
static final Map<ChildRequest, CountDownLatch> arrivedLatches = ConcurrentCollections.newConcurrentMap();
// Awaited by TransportChildAction before it produces its response; tests count it down to let a child finish.
static final Map<ChildRequest, CountDownLatch> beforeExecuteLatches = ConcurrentCollections.newConcurrentMap();
// Handed to the LatchedActionListener that TransportMainAction#startChildTask wraps around the child's listener.
static final Map<ChildRequest, CountDownLatch> completedLatches = ConcurrentCollections.newConcurrentMap();
/**
 * Empties the shared latch maps before each test.
 */
@Before
public void resetTestStates() {
    completedLatches.clear();
    beforeExecuteLatches.clear();
    arrivedLatches.clear();
}
/**
 * Creates between 1 and 30 child requests, each targeting a randomly picked node, and registers a fresh
 * single-count latch for every request in {@code arrivedLatches}, {@code beforeExecuteLatches} and
 * {@code completedLatches}.
 *
 * @param nodes the nodes to pick target nodes from
 * @return the created child requests, with ids {@code 0..n-1} in list order
 */
List<ChildRequest> setupChildRequests(Set<DiscoveryNode> nodes) {
    final int total = randomIntBetween(1, 30);
    final List<ChildRequest> requests = new ArrayList<>();
    int nextId = 0;
    while (nextId < total) {
        final ChildRequest request = new ChildRequest(nextId++, randomFrom(nodes));
        completedLatches.put(request, new CountDownLatch(1));
        beforeExecuteLatches.put(request, new CountDownLatch(1));
        arrivedLatches.put(request, new CountDownLatch(1));
        requests.add(request);
    }
    return requests;
}
/**
 * Checks that, while a cancellation of the main task is in flight, a ban is reported only by the nodes that
 * still have an outstanding child request, that a new child task cannot be started for the cancelled parent,
 * and that no node reports a ban once the cancellation and the main task have finished.
 */
public void testBanOnlyNodesWithOutstandingChildTasks() throws Exception {
// Randomly start one to three additional nodes.
if (randomBoolean()) {
internalCluster().startNodes(randomIntBetween(1, 3));
}
Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
List<ChildRequest> childRequests = setupChildRequests(nodes);
ActionFuture<MainResponse> mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests));
// Let a random subset of the children (never all of them) run to completion.
List<ChildRequest> completedRequests = randomSubsetOf(between(0, childRequests.size() - 1), childRequests);
for (ChildRequest req : completedRequests) {
beforeExecuteLatches.get(req).countDown();
completedLatches.get(req).await();
}
List<ChildRequest> outstandingRequests = childRequests.stream().
filter(r -> completedRequests.contains(r) == false)
.collect(Collectors.toList());
// Wait until the child action has been invoked for every remaining request; these stay blocked on beforeExecuteLatches.
for (ChildRequest req : outstandingRequests) {
arrivedLatches.get(req).await();
}
TaskId taskId = getMainTaskId();
ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId)
.waitForCompletion(true).execute();
Set<DiscoveryNode> nodesWithOutstandingChildTask = outstandingRequests.stream().map(r -> r.targetNode).collect(Collectors.toSet());
// Only the nodes targeted by an outstanding child request are expected to report a ban.
assertBusy(() -> {
for (DiscoveryNode node : nodes) {
TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager();
if (nodesWithOutstandingChildTask.contains(node)) {
assertThat(taskManager.getBanCount(), equalTo(1));
} else {
assertThat(taskManager.getBanCount(), equalTo(0));
}
}
});
// Starting a new child task for the cancelled parent must fail with a TaskCancelledException
if (randomBoolean()) {
DiscoveryNode nodeWithParentTask = nodes.stream().filter(n -> n.getId().equals(taskId.getNodeId())).findFirst().get();
TransportMainAction mainAction = internalCluster().getInstance(TransportMainAction.class, nodeWithParentTask.getName());
PlainActionFuture<ChildResponse> future = new PlainActionFuture<>();
ChildRequest req = new ChildRequest(-1, randomFrom(nodes));
// startChildTask looks up the completed latch for the request, so register one for this extra request.
completedLatches.put(req, new CountDownLatch(1));
mainAction.startChildTask(taskId, req, future);
TransportException te = expectThrows(TransportException.class, future::actionGet);
assertThat(te.getCause(), instanceOf(TaskCancelledException.class));
assertThat(te.getCause().getMessage(), equalTo("The parent task was cancelled, shouldn't start any child tasks"));
}
// Unblock the outstanding children, then wait for the cancel request (sent with waitForCompletion=true) and the main task.
for (ChildRequest req : outstandingRequests) {
beforeExecuteLatches.get(req).countDown();
}
cancelFuture.actionGet();
waitForMainTask(mainTaskFuture);
// After everything has finished no node should report a ban any more.
assertBusy(() -> {
for (DiscoveryNode node : nodes) {
TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager();
assertThat(taskManager.getBanCount(), equalTo(0));
}
});
}
/**
 * Cancels the same main task several times: once with {@code waitForCompletion=true} while children are still
 * blocked, optionally a second time without waiting, and finally once more after the main task has finished,
 * at which point the cancel request is expected to report a single "task not found" node failure.
 */
public void testCancelTaskMultipleTimes() throws Exception {
    Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
    List<ChildRequest> childRequests = setupChildRequests(nodes);
    ActionFuture<MainResponse> mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests));
    for (ChildRequest r : randomSubsetOf(between(1, childRequests.size()), childRequests)) {
        arrivedLatches.get(r).await();
    }
    TaskId taskId = getMainTaskId();
    ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId)
        .waitForCompletion(true).execute();
    ensureChildTasksCancelledOrBanned(taskId);
    if (randomBoolean()) {
        // a second, non-waiting cancellation of the same task must succeed without failures
        CancelTasksResponse resp = client().admin().cluster().prepareCancelTasks().setTaskId(taskId).waitForCompletion(false).get();
        assertThat(resp.getTaskFailures(), empty());
        assertThat(resp.getNodeFailures(), empty());
    }
    // the waiting cancellation cannot complete while children are still blocked
    assertFalse(cancelFuture.isDone());
    for (ChildRequest r : childRequests) {
        beforeExecuteLatches.get(r).countDown();
    }
    // check both failure lists of the waiting cancellation (previously the task failures were asserted twice)
    final CancelTasksResponse cancelResponse = cancelFuture.actionGet();
    assertThat(cancelResponse.getTaskFailures(), empty());
    assertThat(cancelResponse.getNodeFailures(), empty());
    waitForMainTask(mainTaskFuture);
    CancelTasksResponse cancelError = client().admin().cluster().prepareCancelTasks()
        .setTaskId(taskId).waitForCompletion(randomBoolean()).get();
    assertThat(cancelError.getNodeFailures(), hasSize(1));
    final Throwable notFound = ExceptionsHelper.unwrap(cancelError.getNodeFailures().get(0), ResourceNotFoundException.class);
    // fail with a readable message instead of a NullPointerException when the failure has a different cause
    assertNotNull("expected a ResourceNotFoundException but got: " + cancelError.getNodeFailures().get(0), notFound);
    assertThat(notFound.getMessage(), equalTo("task [" + taskId + "] is not found"));
}
/**
 * Cancels the main task with a randomly chosen {@code waitForCompletion} flag while children are still blocked.
 * When waiting, the cancel future must not be done yet; when not waiting, it must become done even though the
 * children have not been released.
 */
public void testDoNotWaitForCompletion() throws Exception {
    final Set<DiscoveryNode> clusterNodes =
        StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
    final List<ChildRequest> children = setupChildRequests(clusterNodes);
    final ActionFuture<MainResponse> mainFuture = client().execute(TransportMainAction.ACTION, new MainRequest(children));
    final List<ChildRequest> mustArrive = randomSubsetOf(between(1, children.size()), children);
    for (ChildRequest child : mustArrive) {
        arrivedLatches.get(child).await();
    }
    final TaskId mainTaskId = getMainTaskId();
    final boolean shouldWait = randomBoolean();
    final ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster()
        .prepareCancelTasks()
        .setTaskId(mainTaskId)
        .waitForCompletion(shouldWait)
        .execute();
    if (shouldWait == false) {
        assertBusy(() -> assertTrue(cancelFuture.isDone()));
    } else {
        assertFalse(cancelFuture.isDone());
    }
    // release every child so that the main task can finish
    children.forEach(child -> beforeExecuteLatches.get(child).countDown());
    waitForMainTask(mainFuture);
}
/**
 * Lists the tasks of the main test action, asserts that there is exactly one, and returns its id.
 */
TaskId getMainTaskId() {
    final ListTasksResponse mainTasks = client().admin().cluster()
        .prepareListTasks()
        .setActions(TransportMainAction.ACTION.name())
        .setDetailed(true)
        .get();
    assertThat(mainTasks.getTasks(), hasSize(1));
    return mainTasks.getTasks().get(0).getTaskId();
}
/**
 * Waits for the main task to finish. The task may complete normally or fail because of the cancellation;
 * in the latter case the failure must contain a {@link TaskCancelledException} with one of the expected messages.
 *
 * @param mainTask the future of the main action
 * @throws AssertionError if the task failed for a reason other than cancellation (the original exception is
 *                        attached as the cause)
 */
void waitForMainTask(ActionFuture<MainResponse> mainTask) {
    try {
        mainTask.actionGet();
    } catch (Exception e) {
        final Throwable cause = ExceptionsHelper.unwrap(e, TaskCancelledException.class);
        if (cause == null) {
            // unwrap found no TaskCancelledException: surface the real failure instead of a NullPointerException
            throw new AssertionError("main task failed with an unexpected exception", e);
        }
        assertThat(cause.getMessage(),
            either(equalTo("The parent task was cancelled, shouldn't start any child tasks"))
                .or(containsString("Task cancelled before it started:")));
    }
}
/**
 * Request of the main test action. It carries the child requests that the main action fans out and creates a
 * {@link CancellableTask} whose {@code shouldCancelChildrenOnCancellation()} returns {@code true}.
 */
public static class MainRequest extends ActionRequest {
    final List<ChildRequest> childRequests;

    public MainRequest(List<ChildRequest> childRequests) {
        this.childRequests = childRequests;
    }

    /** Reads the request, including its list of child requests, from the stream. */
    public MainRequest(StreamInput in) throws IOException {
        super(in);
        this.childRequests = in.readList(ChildRequest::new);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeList(childRequests);
    }

    /** This request is always considered valid. */
    @Override
    public ActionRequestValidationException validate() {
        return null;
    }

    @Override
    public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
        final String description = getDescription();
        return new CancellableTask(id, type, action, description, parentTaskId, headers) {
            @Override
            public boolean shouldCancelChildrenOnCancellation() {
                return true;
            }
        };
    }
}
/**
 * Response of the main test action; it has no fields and writes nothing to the stream.
 */
public static class MainResponse extends ActionResponse {

    public MainResponse() {
    }

    public MainResponse(StreamInput in) throws IOException {
        super(in);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // no payload to serialize
    }
}
/**
 * Request of the child test action, identified by an integer id and the node it is meant to run on.
 * Equality and hash code are based on both values, which allows requests to be used as keys of the latch maps.
 */
public static class ChildRequest extends ActionRequest {
    final int id;
    final DiscoveryNode targetNode;

    public ChildRequest(int id, DiscoveryNode targetNode) {
        this.id = id;
        this.targetNode = targetNode;
    }

    /** Reads the id followed by the target node from the stream. */
    public ChildRequest(StreamInput in) throws IOException {
        super(in);
        this.id = in.readInt();
        this.targetNode = new DiscoveryNode(in);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeInt(id);
        targetNode.writeTo(out);
    }

    /** This request is always considered valid. */
    @Override
    public ActionRequestValidationException validate() {
        return null;
    }

    @Override
    public String getDescription() {
        return "childTask[" + id + "]";
    }

    /**
     * Randomly creates either the default task of the super class or a {@link CancellableTask} with a randomly
     * chosen {@code shouldCancelChildrenOnCancellation()} value.
     */
    @Override
    public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
        if (randomBoolean() == false) {
            return super.createTask(id, type, action, parentTaskId, headers);
        }
        final boolean cancelChildren = randomBoolean();
        return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) {
            @Override
            public boolean shouldCancelChildrenOnCancellation() {
                return cancelChildren;
            }
        };
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        final ChildRequest other = (ChildRequest) o;
        return other.id == id && targetNode.equals(other.targetNode);
    }

    @Override
    public int hashCode() {
        return Objects.hash(id, targetNode);
    }
}
/**
 * Response of the child test action; it has no fields and writes nothing to the stream.
 */
public static class ChildResponse extends ActionResponse {

    public ChildResponse() {
    }

    public ChildResponse(StreamInput in) throws IOException {
        super(in);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // no payload to serialize
    }
}
/**
 * Main test action: starts one child task per {@link ChildRequest} of the incoming {@link MainRequest} and
 * responds with a {@link MainResponse} once the grouped listener has collected all child responses.
 */
public static class TransportMainAction extends HandledTransportAction<MainRequest, MainResponse> {
    public static ActionType<MainResponse> ACTION = new ActionType<>("internal::main_action", MainResponse::new);
    private final TransportService transportService;
    private final NodeClient client;

    @Inject
    public TransportMainAction(TransportService transportService, NodeClient client, ActionFilters actionFilters) {
        super(ACTION.name(), transportService, actionFilters, MainRequest::new, ThreadPool.Names.GENERIC);
        this.transportService = transportService;
        this.client = client;
    }

    @Override
    protected void doExecute(Task task, MainRequest request, ActionListener<MainResponse> listener) {
        GroupedActionListener<ChildResponse> groupedListener =
            new GroupedActionListener<>(ActionListener.map(listener, r -> new MainResponse()), request.childRequests.size());
        // the parent task id is the same for every child, so build it once instead of once per iteration
        final TaskId parentTaskId = new TaskId(client.getLocalNodeId(), task.getId());
        for (ChildRequest childRequest : request.childRequests) {
            startChildTask(parentTaskId, childRequest, groupedListener);
        }
    }

    /**
     * Sends the child request on a generic thread pool thread, either through the local node client (only when the
     * target is the local node, and then at random) or through the transport service.
     * The listener is wrapped in a {@link LatchedActionListener} using the request's entry in {@code completedLatches}.
     *
     * @param parentTaskId the id of the parent task to set on the child request
     * @param childRequest the child request to send
     * @param listener     notified with the child's response or failure
     */
    protected void startChildTask(TaskId parentTaskId, ChildRequest childRequest, ActionListener<ChildResponse> listener) {
        childRequest.setParentTask(parentTaskId);
        final CountDownLatch completeLatch = completedLatches.get(childRequest);
        LatchedActionListener<ChildResponse> latchedListener = new LatchedActionListener<>(listener, completeLatch);
        transportService.getThreadPool().generic().execute(new AbstractRunnable() {
            @Override
            public void onFailure(Exception e) {
                // go through the latched listener, like every other completion path, so that the
                // request's completed latch is also released when the runnable itself fails
                latchedListener.onFailure(e);
            }

            @Override
            protected void doRun() {
                if (client.getLocalNodeId().equals(childRequest.targetNode.getId()) && randomBoolean()) {
                    try {
                        client.executeLocally(TransportChildAction.ACTION, childRequest, latchedListener);
                    } catch (TaskCancelledException e) {
                        // report the failure the same way the transport path does
                        latchedListener.onFailure(new TransportException(e));
                    }
                } else {
                    transportService.sendRequest(childRequest.targetNode, TransportChildAction.ACTION.name(), childRequest,
                        new TransportResponseHandler<ChildResponse>() {
                            @Override
                            public void handleResponse(ChildResponse response) {
                                latchedListener.onResponse(new ChildResponse());
                            }

                            @Override
                            public void handleException(TransportException exp) {
                                latchedListener.onFailure(exp);
                            }

                            @Override
                            public String executor() {
                                return ThreadPool.Names.SAME;
                            }

                            @Override
                            public ChildResponse read(StreamInput in) throws IOException {
                                return new ChildResponse(in);
                            }
                        });
                }
            }
        });
    }
}
/**
 * Child test action. It signals the request's arrival latch when invoked and then, on a generic thread pool
 * thread, waits for the request's before-execute latch before answering with a {@link ChildResponse}.
 */
public static class TransportChildAction extends HandledTransportAction<ChildRequest, ChildResponse> {
    public static ActionType<ChildResponse> ACTION = new ActionType<>("internal:child_action", ChildResponse::new);
    private final TransportService transportService;

    @Inject
    public TransportChildAction(TransportService transportService, ActionFilters actionFilters) {
        super(ACTION.name(), transportService, actionFilters, ChildRequest::new, ThreadPool.Names.GENERIC);
        this.transportService = transportService;
    }

    @Override
    protected void doExecute(Task task, ChildRequest request, ActionListener<ChildResponse> listener) {
        // the request must have been routed to the node it names as its target
        assertThat(request.targetNode, equalTo(transportService.getLocalNode()));
        arrivedLatches.get(request).countDown();
        final Runnable respondOnceReleased = ActionRunnable.supply(listener, () -> {
            beforeExecuteLatches.get(request).await();
            return new ChildResponse();
        });
        transportService.getThreadPool().executor(ThreadPool.Names.GENERIC).execute(respondOnceReleased);
    }
}
/**
 * Plugin that registers the main and child test actions and exposes them as client actions.
 */
public static class TaskPlugin extends Plugin implements ActionPlugin {

    @Override
    public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
        final ActionHandler<MainRequest, MainResponse> mainHandler =
            new ActionHandler<>(TransportMainAction.ACTION, TransportMainAction.class);
        final ActionHandler<ChildRequest, ChildResponse> childHandler =
            new ActionHandler<>(TransportChildAction.ACTION, TransportChildAction.class);
        return Arrays.asList(mainHandler, childHandler);
    }

    @Override
    public List<ActionType<? extends ActionResponse>> getClientActions() {
        return Arrays.asList(TransportMainAction.ACTION, TransportChildAction.ACTION);
    }
}
/**
 * Adds {@link TaskPlugin} to the node plugins provided by the super class.
 */
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
    final List<Class<? extends Plugin>> withTaskPlugin = new ArrayList<>(super.nodePlugins());
    withTaskPlugin.add(TaskPlugin.class);
    return withTaskPlugin;
}
/**
 * Adds {@link TaskPlugin} to the transport client plugins provided by the super class.
 */
@Override
protected Collection<Class<? extends Plugin>> transportClientPlugins() {
    final List<Class<? extends Plugin>> withTaskPlugin = new ArrayList<>(super.transportClientPlugins());
    withTaskPlugin.add(TaskPlugin.class);
    return withTaskPlugin;
}
/**
 * Waits until the task manager of every node in the internal cluster reports, through
 * {@code childTasksCancelledOrBanned}, that the children of the given parent task are cancelled or banned.
 *
 * @param taskId the id of the parent task
 */
protected static void ensureChildTasksCancelledOrBanned(TaskId taskId) throws Exception {
    assertBusy(() -> {
        for (String nodeName : internalCluster().getNodeNames()) {
            final TransportService nodeTransport = internalCluster().getInstance(TransportService.class, nodeName);
            final TaskManager nodeTaskManager = nodeTransport.getTaskManager();
            assertTrue(nodeTaskManager.childTasksCancelledOrBanned(taskId));
        }
    });
}
}

View File

@ -21,6 +21,7 @@ package org.elasticsearch.action.admin.cluster.node.tasks;
import com.carrotsearch.randomizedtesting.RandomizedContext;
import com.carrotsearch.randomizedtesting.generators.RandomNumbers;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksAction;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksRequest;
@ -39,16 +40,21 @@ import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Phaser;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicReference;
import static org.elasticsearch.test.ClusterServiceUtils.setState;
@ -189,19 +195,24 @@ public class CancellableTasksTests extends TaskManagerTestCase {
}
}
private Task startCancellableTestNodesAction(boolean waitForActionToStart, int blockedNodesCount, ActionListener<NodesResponse>
listener) throws InterruptedException {
return startCancellableTestNodesAction(waitForActionToStart, randomSubsetOf(blockedNodesCount, testNodes), new
CancellableNodesRequest("Test Request"), listener);
private Task startCancellableTestNodesAction(boolean waitForActionToStart, int runNodesCount, int blockedNodesCount,
ActionListener<NodesResponse> listener) throws InterruptedException {
List<TestNode> runOnNodes = randomSubsetOf(runNodesCount, testNodes);
return startCancellableTestNodesAction(waitForActionToStart, runOnNodes, randomSubsetOf(blockedNodesCount, runOnNodes), new
CancellableNodesRequest("Test Request",runOnNodes.stream().map(TestNode::getNodeId).toArray(String[]::new)), listener);
}
private Task startCancellableTestNodesAction(boolean waitForActionToStart, Collection<TestNode> blockOnNodes, CancellableNodesRequest
request, ActionListener<NodesResponse> listener) throws InterruptedException {
CountDownLatch actionLatch = waitForActionToStart ? new CountDownLatch(nodesCount) : null;
private Task startCancellableTestNodesAction(boolean waitForActionToStart, List<TestNode> runOnNodes,
Collection<TestNode> blockOnNodes, CancellableNodesRequest
request, ActionListener<NodesResponse> listener) throws InterruptedException {
CountDownLatch actionLatch = waitForActionToStart ? new CountDownLatch(runOnNodes.size()) : null;
CancellableTestNodesAction[] actions = new CancellableTestNodesAction[nodesCount];
for (int i = 0; i < testNodes.length; i++) {
boolean shouldBlock = blockOnNodes.contains(testNodes[i]);
logger.info("The action in the node [{}] should block: [{}]", testNodes[i].getNodeId(), shouldBlock);
boolean shouldRun = runOnNodes.contains(testNodes[i]);
logger.info("The action on the node [{}] should run: [{}] should block: [{}]", testNodes[i].getNodeId(), shouldRun,
shouldBlock);
actions[i] = new CancellableTestNodesAction("internal:testAction", threadPool, testNodes[i]
.clusterService, testNodes[i].transportService, shouldBlock, actionLatch);
}
@ -222,20 +233,22 @@ public class CancellableTasksTests extends TaskManagerTestCase {
logger.info("waitForActionToStart is set to {}", waitForActionToStart);
final AtomicReference<NodesResponse> responseReference = new AtomicReference<>();
final AtomicReference<Throwable> throwableReference = new AtomicReference<>();
int blockedNodesCount = randomIntBetween(0, nodesCount);
Task mainTask = startCancellableTestNodesAction(waitForActionToStart, blockedNodesCount, new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
responseLatch.countDown();
}
int runNodesCount = randomIntBetween(1, nodesCount);
int blockedNodesCount = randomIntBetween(0, runNodesCount);
Task mainTask = startCancellableTestNodesAction(waitForActionToStart, runNodesCount, blockedNodesCount,
new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
responseLatch.countDown();
}
@Override
public void onFailure(Exception e) {
throwableReference.set(e);
responseLatch.countDown();
}
});
@Override
public void onFailure(Exception e) {
throwableReference.set(e);
responseLatch.countDown();
}
});
// Cancel main task
CancelTasksRequest request = new CancelTasksRequest();
@ -255,12 +268,12 @@ public class CancellableTasksTests extends TaskManagerTestCase {
// Make sure that the request was successful
assertNull(throwableReference.get());
assertNotNull(responseReference.get());
assertEquals(nodesCount, responseReference.get().getNodes().size());
assertEquals(runNodesCount, responseReference.get().getNodes().size());
assertEquals(0, responseReference.get().failureCount());
} else {
// We canceled the request, in this case it should have failed, but we should get a partial response
assertNull(throwableReference.get());
assertEquals(nodesCount, responseReference.get().failureCount() + responseReference.get().getNodes().size());
assertEquals(runNodesCount, responseReference.get().failureCount() + responseReference.get().getNodes().size());
// and we should have at least as many failures as the number of blocked operations
// (we might have cancelled some non-blocked operations before they even started and that's ok)
assertThat(responseReference.get().failureCount(), greaterThanOrEqualTo(blockedNodesCount));
@ -295,19 +308,23 @@ public class CancellableTasksTests extends TaskManagerTestCase {
CountDownLatch responseLatch = new CountDownLatch(1);
final AtomicReference<NodesResponse> responseReference = new AtomicReference<>();
final AtomicReference<Throwable> throwableReference = new AtomicReference<>();
Task mainTask = startCancellableTestNodesAction(true, nodesCount, new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
responseLatch.countDown();
}
int runNodesCount = randomIntBetween(1, nodesCount);
int blockedNodesCount = randomIntBetween(0, runNodesCount);
Task mainTask = startCancellableTestNodesAction(true, runNodesCount, blockedNodesCount,
new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
responseLatch.countDown();
}
@Override
public void onFailure(Exception e) {
throwableReference.set(e);
responseLatch.countDown();
@Override
public void onFailure(Exception e) {
throwableReference.set(e);
responseLatch.countDown();
}
}
});
);
// Cancel all child tasks without cancelling the main task, which should quit on its own
CancelTasksRequest request = new CancelTasksRequest();
@ -320,8 +337,10 @@ public class CancellableTasksTests extends TaskManagerTestCase {
// Awaiting for the main task to finish
responseLatch.await();
// Should have cancelled tasks on all nodes
assertThat(response.getTasks().size(), equalTo(testNodes.length));
// Should have cancelled tasks at least on all nodes where it was blocked
assertThat(response.getTasks().size(), lessThanOrEqualTo(runNodesCount));
// but may also encounter some nodes where it was still running
assertThat(response.getTasks().size(), greaterThanOrEqualTo(blockedNodesCount));
assertBusy(() -> {
// Make sure that main task is no longer running
@ -343,20 +362,22 @@ public class CancellableTasksTests extends TaskManagerTestCase {
// We shouldn't block on the first node since it's leaving the cluster anyway so it doesn't matter
List<TestNode> blockOnNodes = randomSubsetOf(blockedNodesCount, Arrays.copyOfRange(testNodes, 1, nodesCount));
Task mainTask = startCancellableTestNodesAction(true, blockOnNodes, new CancellableNodesRequest("Test Request"), new
ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
responseLatch.countDown();
}
Task mainTask = startCancellableTestNodesAction(true, Arrays.asList(testNodes), blockOnNodes,
new CancellableNodesRequest("Test Request"), new
ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
responseLatch.countDown();
}
@Override
public void onFailure(Exception e) {
throwableReference.set(e);
responseLatch.countDown();
}
});
@Override
public void onFailure(Exception e) {
throwableReference.set(e);
responseLatch.countDown();
}
}
);
String mainNode = testNodes[0].getNodeId();
@ -415,6 +436,63 @@ public class CancellableTasksTests extends TaskManagerTestCase {
}
/**
 * Sends a cancel request whose action filter matches no task and checks that nothing is reported as cancelled
 * and that no cancel-tasks action is left in the task list afterwards.
 */
public void testNonExistingTaskCancellation() throws Exception {
    setupTestNodes(Settings.EMPTY);
    connectNodes(testNodes);

    // Build a cancel request that cannot match any task
    final CancelTasksRequest cancelRequest = new CancelTasksRequest();
    cancelRequest.setReason("Testing Cancellation");
    cancelRequest.setActions("do-not-match-anything");
    final String[] targetNodeIds = randomSubsetOf(randomIntBetween(1, testNodes.length - 1), testNodes)
        .stream()
        .map(TestNode::getNodeId)
        .toArray(String[]::new);
    cancelRequest.setNodes(targetNodeIds);

    // Execute the request on a randomly picked node (index 1 or above)
    final TestNode coordinator = testNodes[randomIntBetween(1, testNodes.length - 1)];
    final CancelTasksResponse response = ActionTestUtils.executeBlocking(coordinator.transportCancelTasksAction, cancelRequest);

    // Nothing should have been cancelled
    assertThat(response.getTasks().size(), equalTo(0));

    assertBusy(() -> {
        // No task of the cancel-tasks action should still be listed on a randomly picked node
        final TestNode listingNode = testNodes[randomIntBetween(0, testNodes.length - 1)];
        final ListTasksResponse listTasksResponse = ActionTestUtils.executeBlocking(
            listingNode.transportListTasksAction, new ListTasksRequest().setActions(CancelTasksAction.NAME + "*"));
        assertEquals(0, listTasksResponse.getTasks().size());
    });
}
/**
 * Cancels one registered task from several threads at once while the main thread unregisters it, and checks
 * that the callback of every cancel call ran exactly once. A further cancel call made after the task has been
 * unregistered must also run its callback before returning.
 */
public void testCancelConcurrently() throws Exception {
    setupTestNodes(Settings.EMPTY);
    final TaskManager taskManager = testNodes[0].transportService.getTaskManager();

    final int taskCount = randomIntBetween(1, 10);
    final List<CancellableTask> registered = new ArrayList<>(taskCount);
    for (int t = 0; t < taskCount; t++) {
        final Task task = taskManager.register("type-" + t, "action-" + t, new CancellableNodeRequest());
        registered.add((CancellableTask) task);
    }

    final int cancellerCount = randomIntBetween(1, 8);
    final Thread[] cancellers = new Thread[cancellerCount];
    final AtomicIntegerArray notified = new AtomicIntegerArray(cancellerCount);
    // one party per cancelling thread plus the main thread
    final Phaser startBarrier = new Phaser(cancellerCount + 1);
    final CancellableTask target = randomFrom(registered);
    for (int c = 0; c < cancellerCount; c++) {
        final int slot = c;
        final Thread canceller = new Thread(() -> {
            startBarrier.arriveAndAwaitAdvance();
            taskManager.cancel(target, "test", () -> assertTrue(notified.compareAndSet(slot, 0, 1)));
        });
        cancellers[c] = canceller;
        canceller.start();
    }

    startBarrier.arriveAndAwaitAdvance();
    taskManager.unregister(target);
    for (int c = 0; c < cancellerCount; c++) {
        cancellers[c].join();
        assertThat(notified.get(c), equalTo(1));
    }

    // cancelling after the task has been unregistered still has to invoke the callback
    final AtomicBoolean lateCallbackRan = new AtomicBoolean();
    taskManager.cancel(target, "test", () -> assertTrue(lateCallbackRan.compareAndSet(false, true)));
    assertTrue(lateCallbackRan.get());
}
private static void debugDelay(String name) {
// Introduce additional pseudo-random, repeatable race conditions
String delayName = RandomizedContext.current().getRunnerSeedAsString() + ":" + name;