Task cancellation command should wait for all child nodes to receive cancellation request before returning

Currently the task cancellation command returns as soon as the top-level parent child is marked as cancelled. This create race conditions in tests where child tasks on other nodes may continue to run for some time after the main task is cancelled. This commit fixes this situation making task cancellation command to wait until it got propagated to all nodes that have child tasks.

Closes #21126
This commit is contained in:
Igor Motov 2016-11-07 15:01:29 -10:00
parent 06a50fa31e
commit df965fc9b3
2 changed files with 44 additions and 11 deletions

View File

@ -33,6 +33,7 @@ import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskInfo; import org.elasticsearch.tasks.TaskInfo;
@ -46,6 +47,7 @@ import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@ -118,12 +120,44 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
Set<String> childNodes = taskManager.cancel(cancellableTask, request.getReason(), banLock::onTaskFinished); Set<String> childNodes = taskManager.cancel(cancellableTask, request.getReason(), banLock::onTaskFinished);
if (childNodes != null) { if (childNodes != null) {
if (childNodes.isEmpty()) { if (childNodes.isEmpty()) {
// The task has no child tasks, so we can return immediately
logger.trace("cancelling task {} with no children", cancellableTask.getId()); logger.trace("cancelling task {} with no children", cancellableTask.getId());
listener.onResponse(cancellableTask.taskInfo(clusterService.localNode().getId(), false)); listener.onResponse(cancellableTask.taskInfo(clusterService.localNode().getId(), false));
} else { } else {
// The task has some child tasks, we need to wait for until ban is set on all nodes
logger.trace("cancelling task {} with children on nodes [{}]", cancellableTask.getId(), childNodes); logger.trace("cancelling task {} with children on nodes [{}]", cancellableTask.getId(), childNodes);
setBanOnNodes(request.getReason(), cancellableTask, childNodes, banLock); String nodeId = clusterService.localNode().getId();
listener.onResponse(cancellableTask.taskInfo(clusterService.localNode().getId(), false)); AtomicInteger responses = new AtomicInteger(childNodes.size());
List<Exception> failures = new ArrayList<>();
setBanOnNodes(request.getReason(), cancellableTask, childNodes, new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
processResponse();
}
@Override
public void onFailure(Exception e) {
synchronized (failures) {
failures.add(e);
}
processResponse();
}
private void processResponse() {
banLock.onBanSet();
if (responses.decrementAndGet() == 0) {
if (failures.isEmpty() == false) {
IllegalStateException exception = new IllegalStateException("failed to cancel children of the task [" +
cancellableTask.getId() + "]");
failures.forEach(exception::addSuppressed);
listener.onFailure(exception);
} else {
listener.onResponse(cancellableTask.taskInfo(nodeId, false));
}
}
}
});
} }
} else { } else {
logger.trace("task {} is already cancelled", cancellableTask.getId()); logger.trace("task {} is already cancelled", cancellableTask.getId());
@ -136,10 +170,10 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
return true; return true;
} }
private void setBanOnNodes(String reason, CancellableTask task, Set<String> nodes, BanLock banLock) { private void setBanOnNodes(String reason, CancellableTask task, Set<String> nodes, ActionListener<Void> listener) {
sendSetBanRequest(nodes, sendSetBanRequest(nodes,
BanParentTaskRequest.createSetBanParentTaskRequest(new TaskId(clusterService.localNode().getId(), task.getId()), reason), BanParentTaskRequest.createSetBanParentTaskRequest(new TaskId(clusterService.localNode().getId(), task.getId()), reason),
banLock); listener);
} }
private void removeBanOnNodes(CancellableTask task, Set<String> nodes) { private void removeBanOnNodes(CancellableTask task, Set<String> nodes) {
@ -147,28 +181,29 @@ public class TransportCancelTasksAction extends TransportTasksAction<Cancellable
BanParentTaskRequest.createRemoveBanParentTaskRequest(new TaskId(clusterService.localNode().getId(), task.getId()))); BanParentTaskRequest.createRemoveBanParentTaskRequest(new TaskId(clusterService.localNode().getId(), task.getId())));
} }
private void sendSetBanRequest(Set<String> nodes, BanParentTaskRequest request, BanLock banLock) { private void sendSetBanRequest(Set<String> nodes, BanParentTaskRequest request, ActionListener<Void> listener) {
ClusterState clusterState = clusterService.state(); ClusterState clusterState = clusterService.state();
for (String node : nodes) { for (String node : nodes) {
DiscoveryNode discoveryNode = clusterState.getNodes().get(node); DiscoveryNode discoveryNode = clusterState.getNodes().get(node);
if (discoveryNode != null) { if (discoveryNode != null) {
// Check if node still in the cluster // Check if node still in the cluster
logger.debug("Sending ban for tasks with the parent [{}] to the node [{}], ban [{}]", request.parentTaskId, node, logger.trace("Sending ban for tasks with the parent [{}] to the node [{}], ban [{}]", request.parentTaskId, node,
request.ban); request.ban);
transportService.sendRequest(discoveryNode, BAN_PARENT_ACTION_NAME, request, transportService.sendRequest(discoveryNode, BAN_PARENT_ACTION_NAME, request,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) { new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override @Override
public void handleResponse(TransportResponse.Empty response) { public void handleResponse(TransportResponse.Empty response) {
banLock.onBanSet(); listener.onResponse(null);
} }
@Override @Override
public void handleException(TransportException exp) { public void handleException(TransportException exp) {
banLock.onBanSet(); logger.warn("Cannot send ban for tasks with the parent [{}] to the node [{}]", request.parentTaskId, node);
listener.onFailure(exp);
} }
}); });
} else { } else {
banLock.onBanSet(); listener.onResponse(null);
logger.debug("Cannot send ban for tasks with the parent [{}] to the node [{}] - the node no longer in the cluster", logger.debug("Cannot send ban for tasks with the parent [{}] to the node [{}] - the node no longer in the cluster",
request.parentTaskId, node); request.parentTaskId, node);
} }

View File

@ -176,7 +176,6 @@ public class SearchCancellationIT extends ESIntegTestCase {
ensureSearchWasCancelled(searchResponse); ensureSearchWasCancelled(searchResponse);
} }
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/21126")
public void testCancellationOfScrollSearches() throws Exception { public void testCancellationOfScrollSearches() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory(); List<ScriptedBlockPlugin> plugins = initBlockFactory();
@ -198,7 +197,6 @@ public class SearchCancellationIT extends ESIntegTestCase {
} }
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/21126")
public void testCancellationOfScrollSearchesOnFollowupRequests() throws Exception { public void testCancellationOfScrollSearchesOnFollowupRequests() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory(); List<ScriptedBlockPlugin> plugins = initBlockFactory();