Expose the logic to cancel task when the rest channel is closed (#51423)

This commit moves the logic that cancels search requests when the rest channel is closed
to a generic client that can be used by other APIs. This will be useful for any rest action
that wants to cancel the execution of a task if the underlying rest channel is closed by the
client before completion.

Relates #49931
Relates #50990
Relates #50990
This commit is contained in:
Jim Ferenczi 2020-01-28 22:03:04 +01:00 committed by jimczi
parent aae93a7578
commit 77f4aafaa2
4 changed files with 105 additions and 77 deletions

View File

@ -17,54 +17,84 @@
* under the License. * under the License.
*/ */
package org.elasticsearch.rest.action.search; package org.elasticsearch.rest.action;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType; import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.client.FilterClient;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpChannel; import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskId;
import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import static org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN;
/** /**
* This class executes a request and associates the corresponding {@link Task} with the {@link HttpChannel} that it was originated from, * A {@linkplain Client} that cancels tasks executed locally when the provided {@link HttpChannel}
* so that the tasks associated with a certain channel get cancelled when the underlying connection gets closed. * is closed before completion.
*/ */
public final class HttpChannelTaskHandler { public class RestCancellableNodeClient extends FilterClient {
private static final Map<HttpChannel, CloseListener> httpChannels = new ConcurrentHashMap<>();
public static final HttpChannelTaskHandler INSTANCE = new HttpChannelTaskHandler(); private final NodeClient client;
//package private for testing private final HttpChannel httpChannel;
final Map<HttpChannel, CloseListener> httpChannels = new ConcurrentHashMap<>();
private HttpChannelTaskHandler() { public RestCancellableNodeClient(NodeClient client, HttpChannel httpChannel) {
super(client);
this.client = client;
this.httpChannel = httpChannel;
} }
<Response extends ActionResponse> void execute(NodeClient client, HttpChannel httpChannel, ActionRequest request, /**
ActionType<Response> actionType, ActionListener<Response> listener) { * Returns the number of channels tracked globally.
*/
public static int getNumChannels() {
return httpChannels.size();
}
CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener(client)); /**
* Returns the number of tasks tracked globally.
*/
static int getNumTasks() {
return httpChannels.values().stream()
.mapToInt(CloseListener::getNumTasks)
.sum();
}
/**
* Returns the number of tasks tracked by the provided {@link HttpChannel}.
*/
static int getNumTasks(HttpChannel channel) {
CloseListener listener = httpChannels.get(channel);
return listener == null ? 0 : listener.getNumTasks();
}
@Override
public <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
ActionType<Response> action, Request request, ActionListener<Response> listener) {
CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener());
TaskHolder taskHolder = new TaskHolder(); TaskHolder taskHolder = new TaskHolder();
Task task = client.executeLocally(actionType, request, Task task = client.executeLocally(action, request,
new ActionListener<Response>() { new ActionListener<Response>() {
@Override @Override
public void onResponse(Response searchResponse) { public void onResponse(Response response) {
try { try {
closeListener.unregisterTask(taskHolder); closeListener.unregisterTask(taskHolder);
} finally { } finally {
listener.onResponse(searchResponse); listener.onResponse(response);
} }
} }
@ -77,32 +107,35 @@ public final class HttpChannelTaskHandler {
} }
} }
}); });
closeListener.registerTask(taskHolder, new TaskId(client.getLocalNodeId(), task.getId())); final TaskId taskId = new TaskId(client.getLocalNodeId(), task.getId());
closeListener.registerTask(taskHolder, taskId);
closeListener.maybeRegisterChannel(httpChannel); closeListener.maybeRegisterChannel(httpChannel);
} }
public int getNumChannels() { private void cancelTask(TaskId taskId) {
return httpChannels.size(); CancelTasksRequest req = new CancelTasksRequest()
.setTaskId(taskId)
.setReason("channel closed");
// force the origin to execute the cancellation as a system user
new OriginSettingClient(client, TASKS_ORIGIN).admin().cluster().cancelTasks(req, ActionListener.wrap(() -> {}));
} }
final class CloseListener implements ActionListener<Void> { private class CloseListener implements ActionListener<Void> {
private final Client client;
private final AtomicReference<HttpChannel> channel = new AtomicReference<>(); private final AtomicReference<HttpChannel> channel = new AtomicReference<>();
private final Set<TaskId> taskIds = new HashSet<>(); private final Set<TaskId> tasks = new HashSet<>();
CloseListener(Client client) { CloseListener() {
this.client = client;
} }
int getNumTasks() { synchronized int getNumTasks() {
return taskIds.size(); return tasks.size();
} }
void maybeRegisterChannel(HttpChannel httpChannel) { void maybeRegisterChannel(HttpChannel httpChannel) {
if (channel.compareAndSet(null, httpChannel)) { if (channel.compareAndSet(null, httpChannel)) {
//In case the channel is already closed when we register the listener, the listener will be immediately executed which will //In case the channel is already closed when we register the listener, the listener will be immediately executed which will
//remove the channel from the map straight-away. That is why we first create the CloseListener and later we associate it //remove the channel from the map straight-away. That is why we first create the CloseListener and later we associate it
//with the channel. This guarantees that the close listener is already in the map when the it gets registered to its //with the channel. This guarantees that the close listener is already in the map when it gets registered to its
//corresponding channel, hence it is always found in the map when it gets invoked if the channel gets closed. //corresponding channel, hence it is always found in the map when it gets invoked if the channel gets closed.
httpChannel.addCloseListener(this); httpChannel.addCloseListener(this);
} }
@ -111,34 +144,31 @@ public final class HttpChannelTaskHandler {
synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) { synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) {
taskHolder.taskId = taskId; taskHolder.taskId = taskId;
if (taskHolder.completed == false) { if (taskHolder.completed == false) {
this.taskIds.add(taskId); this.tasks.add(taskId);
} }
} }
synchronized void unregisterTask(TaskHolder taskHolder) { synchronized void unregisterTask(TaskHolder taskHolder) {
if (taskHolder.taskId != null) { if (taskHolder.taskId != null) {
this.taskIds.remove(taskHolder.taskId); this.tasks.remove(taskHolder.taskId);
} }
taskHolder.completed = true; taskHolder.completed = true;
} }
@Override @Override
public synchronized void onResponse(Void aVoid) { public void onResponse(Void aVoid) {
//When the channel gets closed it won't be reused: we can remove it from the map and forget about it. final HttpChannel httpChannel = channel.get();
CloseListener closeListener = httpChannels.remove(channel.get()); assert httpChannel != null : "channel not registered";
// when the channel gets closed it won't be reused: we can remove it from the map and forget about it.
CloseListener closeListener = httpChannels.remove(httpChannel);
assert closeListener != null : "channel not found in the map of tracked channels"; assert closeListener != null : "channel not found in the map of tracked channels";
for (TaskId taskId : taskIds) { final List<TaskId> toCancel;
ThreadContext threadContext = client.threadPool().getThreadContext(); synchronized (this) {
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { toCancel = new ArrayList<>(tasks);
// we stash any context here since this is an internal execution and should not leak any existing context information tasks.clear();
threadContext.markAsSystemContext(); }
ContextPreservingActionListener<CancelTasksResponse> contextPreservingListener = new ContextPreservingActionListener<>( for (TaskId taskId : toCancel) {
threadContext.newRestorableContext(false), ActionListener.wrap(r -> {}, e -> {})); cancelTask(taskId);
CancelTasksRequest cancelTasksRequest = new CancelTasksRequest();
cancelTasksRequest.setTaskId(taskId);
//We don't wait for cancel tasks to come back. Task cancellation is just best effort.
client.admin().cluster().cancelTasks(cancelTasksRequest, contextPreservingListener);
}
} }
} }

View File

@ -22,7 +22,6 @@ package org.elasticsearch.rest.action.search;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.Booleans; import org.elasticsearch.common.Booleans;
@ -34,6 +33,7 @@ import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestActions; import org.elasticsearch.rest.action.RestActions;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.rest.action.RestStatusToXContentListener; import org.elasticsearch.rest.action.RestStatusToXContentListener;
import org.elasticsearch.search.Scroll; import org.elasticsearch.search.Scroll;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
@ -110,8 +110,8 @@ public class RestSearchAction extends BaseRestHandler {
parseSearchRequest(searchRequest, request, parser, setSize)); parseSearchRequest(searchRequest, request, parser, setSize));
return channel -> { return channel -> {
RestStatusToXContentListener<SearchResponse> listener = new RestStatusToXContentListener<>(channel); RestCancellableNodeClient cancelClient = new RestCancellableNodeClient(client, request.getHttpChannel());
HttpChannelTaskHandler.INSTANCE.execute(client, request.getHttpChannel(), searchRequest, SearchAction.INSTANCE, listener); cancelClient.execute(SearchAction.INSTANCE, searchRequest, new RestStatusToXContentListener<>(channel));
}; };
} }

View File

@ -17,7 +17,7 @@
* under the License. * under the License.
*/ */
package org.elasticsearch.rest.action.search; package org.elasticsearch.rest.action;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequest;
@ -45,7 +45,6 @@ import java.net.InetSocketAddress;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
@ -56,13 +55,13 @@ import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
public class HttpChannelTaskHandlerTests extends ESTestCase { public class RestCancellableNodeClientTests extends ESTestCase {
private ThreadPool threadPool; private ThreadPool threadPool;
@Before @Before
public void createThreadPool() { public void createThreadPool() {
threadPool = new TestThreadPool(HttpChannelTaskHandlerTests.class.getName()); threadPool = new TestThreadPool(RestCancellableNodeClientTests.class.getName());
} }
@After @After
@ -77,8 +76,7 @@ public class HttpChannelTaskHandlerTests extends ESTestCase {
*/ */
public void testCompletedTasks() throws Exception { public void testCompletedTasks() throws Exception {
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, false)) { try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, false)) {
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE; int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
int totalSearches = 0; int totalSearches = 0;
List<Future<?>> futures = new ArrayList<>(); List<Future<?>> futures = new ArrayList<>();
int numChannels = randomIntBetween(1, 30); int numChannels = randomIntBetween(1, 30);
@ -88,8 +86,8 @@ public class HttpChannelTaskHandlerTests extends ESTestCase {
totalSearches += numTasks; totalSearches += numTasks;
for (int j = 0; j < numTasks; j++) { for (int j = 0; j < numTasks; j++) {
PlainListenableActionFuture<SearchResponse> actionFuture = PlainListenableActionFuture.newListenableFuture(); PlainListenableActionFuture<SearchResponse> actionFuture = PlainListenableActionFuture.newListenableFuture();
threadPool.generic().submit(() -> httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
SearchAction.INSTANCE, actionFuture)); threadPool.generic().submit(() -> client.execute(SearchAction.INSTANCE, new SearchRequest(), actionFuture));
futures.add(actionFuture); futures.add(actionFuture);
} }
} }
@ -97,10 +95,8 @@ public class HttpChannelTaskHandlerTests extends ESTestCase {
future.get(); future.get();
} }
//no channels get closed in this test, hence we expect as many channels as we created in the map //no channels get closed in this test, hence we expect as many channels as we created in the map
assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels()); assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels());
for (Map.Entry<HttpChannel, HttpChannelTaskHandler.CloseListener> entry : httpChannelTaskHandler.httpChannels.entrySet()) { assertEquals(0, RestCancellableNodeClient.getNumTasks());
assertEquals(0, entry.getValue().getNumTasks());
}
assertEquals(totalSearches, testClient.searchRequests.get()); assertEquals(totalSearches, testClient.searchRequests.get());
} }
} }
@ -110,9 +106,8 @@ public class HttpChannelTaskHandlerTests extends ESTestCase {
* removed and all of its corresponding tasks get cancelled. * removed and all of its corresponding tasks get cancelled.
*/ */
public void testCancelledTasks() throws Exception { public void testCancelledTasks() throws Exception {
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) { try (TestClient nodeClient = new TestClient(Settings.EMPTY, threadPool, true)) {
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE; int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
int numChannels = randomIntBetween(1, 30); int numChannels = randomIntBetween(1, 30);
int totalSearches = 0; int totalSearches = 0;
List<TestHttpChannel> channels = new ArrayList<>(numChannels); List<TestHttpChannel> channels = new ArrayList<>(numChannels);
@ -121,18 +116,19 @@ public class HttpChannelTaskHandlerTests extends ESTestCase {
channels.add(channel); channels.add(channel);
int numTasks = randomIntBetween(1, 30); int numTasks = randomIntBetween(1, 30);
totalSearches += numTasks; totalSearches += numTasks;
RestCancellableNodeClient client = new RestCancellableNodeClient(nodeClient, channel);
for (int j = 0; j < numTasks; j++) { for (int j = 0; j < numTasks; j++) {
httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null); client.execute(SearchAction.INSTANCE, new SearchRequest(), null);
} }
assertEquals(numTasks, httpChannelTaskHandler.httpChannels.get(channel).getNumTasks()); assertEquals(numTasks, RestCancellableNodeClient.getNumTasks(channel));
} }
assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels()); assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels());
for (TestHttpChannel channel : channels) { for (TestHttpChannel channel : channels) {
channel.awaitClose(); channel.awaitClose();
} }
assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels()); assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
assertEquals(totalSearches, testClient.searchRequests.get()); assertEquals(totalSearches, nodeClient.searchRequests.get());
assertEquals(totalSearches, testClient.cancelledTasks.size()); assertEquals(totalSearches, nodeClient.cancelledTasks.size());
} }
} }
@ -144,8 +140,7 @@ public class HttpChannelTaskHandlerTests extends ESTestCase {
*/ */
public void testChannelAlreadyClosed() { public void testChannelAlreadyClosed() {
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) { try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) {
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE; int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
int numChannels = randomIntBetween(1, 30); int numChannels = randomIntBetween(1, 30);
int totalSearches = 0; int totalSearches = 0;
for (int i = 0; i < numChannels; i++) { for (int i = 0; i < numChannels; i++) {
@ -154,12 +149,13 @@ public class HttpChannelTaskHandlerTests extends ESTestCase {
channel.close(); channel.close();
int numTasks = randomIntBetween(1, 5); int numTasks = randomIntBetween(1, 5);
totalSearches += numTasks; totalSearches += numTasks;
RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
for (int j = 0; j < numTasks; j++) { for (int j = 0; j < numTasks; j++) {
//here the channel will be first registered, then straight-away removed from the map as the close listener is invoked //here the channel will be first registered, then straight-away removed from the map as the close listener is invoked
httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null); client.execute(SearchAction.INSTANCE, new SearchRequest(), null);
} }
} }
assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels()); assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
assertEquals(totalSearches, testClient.searchRequests.get()); assertEquals(totalSearches, testClient.searchRequests.get());
assertEquals(totalSearches, testClient.cancelledTasks.size()); assertEquals(totalSearches, testClient.cancelledTasks.size());
} }

View File

@ -125,8 +125,8 @@ import org.elasticsearch.node.NodeMocksPlugin;
import org.elasticsearch.plugins.NetworkPlugin; import org.elasticsearch.plugins.NetworkPlugin;
import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.script.ScriptMetaData; import org.elasticsearch.script.ScriptMetaData;
import org.elasticsearch.rest.action.search.HttpChannelTaskHandler;
import org.elasticsearch.script.ScriptService; import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.MockSearchService; import org.elasticsearch.search.MockSearchService;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
@ -536,9 +536,11 @@ public abstract class ESIntegTestCase extends ESTestCase {
restClient.close(); restClient.close();
restClient = null; restClient = null;
} }
assertBusy(() -> assertEquals(HttpChannelTaskHandler.INSTANCE.getNumChannels() + " channels still being tracked in " + assertBusy(() -> {
HttpChannelTaskHandler.class.getSimpleName() + " while there should be none", 0, int numChannels = RestCancellableNodeClient.getNumChannels();
HttpChannelTaskHandler.INSTANCE.getNumChannels())); assertEquals( numChannels+ " channels still being tracked in " + RestCancellableNodeClient.class.getSimpleName()
+ " while there should be none", 0, numChannels);
});
} }
private void afterInternal(boolean afterClass) throws Exception { private void afterInternal(boolean afterClass) throws Exception {