Expose the logic to cancel task when the rest channel is closed (#51423)
This commit moves the logic that cancels search requests when the rest channel is closed to a generic client that can be used by other APIs. This will be useful for any rest action that wants to cancel the execution of a task if the underlying rest channel is closed by the client before completion. Relates #49931 Relates #50990 Relates #50990
This commit is contained in:
parent
aae93a7578
commit
77f4aafaa2
|
@ -17,54 +17,84 @@
|
|||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.rest.action.search;
|
||||
package org.elasticsearch.rest.action;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.ActionRequest;
|
||||
import org.elasticsearch.action.ActionResponse;
|
||||
import org.elasticsearch.action.ActionType;
|
||||
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
|
||||
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
|
||||
import org.elasticsearch.action.support.ContextPreservingActionListener;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.client.FilterClient;
|
||||
import org.elasticsearch.client.OriginSettingClient;
|
||||
import org.elasticsearch.client.node.NodeClient;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
import org.elasticsearch.http.HttpChannel;
|
||||
import org.elasticsearch.tasks.Task;
|
||||
import org.elasticsearch.tasks.TaskId;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import static org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN;
|
||||
|
||||
/**
|
||||
* This class executes a request and associates the corresponding {@link Task} with the {@link HttpChannel} that it was originated from,
|
||||
* so that the tasks associated with a certain channel get cancelled when the underlying connection gets closed.
|
||||
* A {@linkplain Client} that cancels tasks executed locally when the provided {@link HttpChannel}
|
||||
* is closed before completion.
|
||||
*/
|
||||
public final class HttpChannelTaskHandler {
|
||||
public class RestCancellableNodeClient extends FilterClient {
|
||||
private static final Map<HttpChannel, CloseListener> httpChannels = new ConcurrentHashMap<>();
|
||||
|
||||
public static final HttpChannelTaskHandler INSTANCE = new HttpChannelTaskHandler();
|
||||
//package private for testing
|
||||
final Map<HttpChannel, CloseListener> httpChannels = new ConcurrentHashMap<>();
|
||||
private final NodeClient client;
|
||||
private final HttpChannel httpChannel;
|
||||
|
||||
private HttpChannelTaskHandler() {
|
||||
public RestCancellableNodeClient(NodeClient client, HttpChannel httpChannel) {
|
||||
super(client);
|
||||
this.client = client;
|
||||
this.httpChannel = httpChannel;
|
||||
}
|
||||
|
||||
<Response extends ActionResponse> void execute(NodeClient client, HttpChannel httpChannel, ActionRequest request,
|
||||
ActionType<Response> actionType, ActionListener<Response> listener) {
|
||||
/**
|
||||
* Returns the number of channels tracked globally.
|
||||
*/
|
||||
public static int getNumChannels() {
|
||||
return httpChannels.size();
|
||||
}
|
||||
|
||||
CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener(client));
|
||||
/**
|
||||
* Returns the number of tasks tracked globally.
|
||||
*/
|
||||
static int getNumTasks() {
|
||||
return httpChannels.values().stream()
|
||||
.mapToInt(CloseListener::getNumTasks)
|
||||
.sum();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the number of tasks tracked by the provided {@link HttpChannel}.
|
||||
*/
|
||||
static int getNumTasks(HttpChannel channel) {
|
||||
CloseListener listener = httpChannels.get(channel);
|
||||
return listener == null ? 0 : listener.getNumTasks();
|
||||
}
|
||||
|
||||
@Override
|
||||
public <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
|
||||
ActionType<Response> action, Request request, ActionListener<Response> listener) {
|
||||
CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener());
|
||||
TaskHolder taskHolder = new TaskHolder();
|
||||
Task task = client.executeLocally(actionType, request,
|
||||
Task task = client.executeLocally(action, request,
|
||||
new ActionListener<Response>() {
|
||||
@Override
|
||||
public void onResponse(Response searchResponse) {
|
||||
public void onResponse(Response response) {
|
||||
try {
|
||||
closeListener.unregisterTask(taskHolder);
|
||||
} finally {
|
||||
listener.onResponse(searchResponse);
|
||||
listener.onResponse(response);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -77,32 +107,35 @@ public final class HttpChannelTaskHandler {
|
|||
}
|
||||
}
|
||||
});
|
||||
closeListener.registerTask(taskHolder, new TaskId(client.getLocalNodeId(), task.getId()));
|
||||
final TaskId taskId = new TaskId(client.getLocalNodeId(), task.getId());
|
||||
closeListener.registerTask(taskHolder, taskId);
|
||||
closeListener.maybeRegisterChannel(httpChannel);
|
||||
}
|
||||
|
||||
public int getNumChannels() {
|
||||
return httpChannels.size();
|
||||
private void cancelTask(TaskId taskId) {
|
||||
CancelTasksRequest req = new CancelTasksRequest()
|
||||
.setTaskId(taskId)
|
||||
.setReason("channel closed");
|
||||
// force the origin to execute the cancellation as a system user
|
||||
new OriginSettingClient(client, TASKS_ORIGIN).admin().cluster().cancelTasks(req, ActionListener.wrap(() -> {}));
|
||||
}
|
||||
|
||||
final class CloseListener implements ActionListener<Void> {
|
||||
private final Client client;
|
||||
private class CloseListener implements ActionListener<Void> {
|
||||
private final AtomicReference<HttpChannel> channel = new AtomicReference<>();
|
||||
private final Set<TaskId> taskIds = new HashSet<>();
|
||||
private final Set<TaskId> tasks = new HashSet<>();
|
||||
|
||||
CloseListener(Client client) {
|
||||
this.client = client;
|
||||
CloseListener() {
|
||||
}
|
||||
|
||||
int getNumTasks() {
|
||||
return taskIds.size();
|
||||
synchronized int getNumTasks() {
|
||||
return tasks.size();
|
||||
}
|
||||
|
||||
void maybeRegisterChannel(HttpChannel httpChannel) {
|
||||
if (channel.compareAndSet(null, httpChannel)) {
|
||||
//In case the channel is already closed when we register the listener, the listener will be immediately executed which will
|
||||
//remove the channel from the map straight-away. That is why we first create the CloseListener and later we associate it
|
||||
//with the channel. This guarantees that the close listener is already in the map when the it gets registered to its
|
||||
//with the channel. This guarantees that the close listener is already in the map when it gets registered to its
|
||||
//corresponding channel, hence it is always found in the map when it gets invoked if the channel gets closed.
|
||||
httpChannel.addCloseListener(this);
|
||||
}
|
||||
|
@ -111,34 +144,31 @@ public final class HttpChannelTaskHandler {
|
|||
synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) {
|
||||
taskHolder.taskId = taskId;
|
||||
if (taskHolder.completed == false) {
|
||||
this.taskIds.add(taskId);
|
||||
this.tasks.add(taskId);
|
||||
}
|
||||
}
|
||||
|
||||
synchronized void unregisterTask(TaskHolder taskHolder) {
|
||||
if (taskHolder.taskId != null) {
|
||||
this.taskIds.remove(taskHolder.taskId);
|
||||
this.tasks.remove(taskHolder.taskId);
|
||||
}
|
||||
taskHolder.completed = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized void onResponse(Void aVoid) {
|
||||
//When the channel gets closed it won't be reused: we can remove it from the map and forget about it.
|
||||
CloseListener closeListener = httpChannels.remove(channel.get());
|
||||
public void onResponse(Void aVoid) {
|
||||
final HttpChannel httpChannel = channel.get();
|
||||
assert httpChannel != null : "channel not registered";
|
||||
// when the channel gets closed it won't be reused: we can remove it from the map and forget about it.
|
||||
CloseListener closeListener = httpChannels.remove(httpChannel);
|
||||
assert closeListener != null : "channel not found in the map of tracked channels";
|
||||
for (TaskId taskId : taskIds) {
|
||||
ThreadContext threadContext = client.threadPool().getThreadContext();
|
||||
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
|
||||
// we stash any context here since this is an internal execution and should not leak any existing context information
|
||||
threadContext.markAsSystemContext();
|
||||
ContextPreservingActionListener<CancelTasksResponse> contextPreservingListener = new ContextPreservingActionListener<>(
|
||||
threadContext.newRestorableContext(false), ActionListener.wrap(r -> {}, e -> {}));
|
||||
CancelTasksRequest cancelTasksRequest = new CancelTasksRequest();
|
||||
cancelTasksRequest.setTaskId(taskId);
|
||||
//We don't wait for cancel tasks to come back. Task cancellation is just best effort.
|
||||
client.admin().cluster().cancelTasks(cancelTasksRequest, contextPreservingListener);
|
||||
}
|
||||
final List<TaskId> toCancel;
|
||||
synchronized (this) {
|
||||
toCancel = new ArrayList<>(tasks);
|
||||
tasks.clear();
|
||||
}
|
||||
for (TaskId taskId : toCancel) {
|
||||
cancelTask(taskId);
|
||||
}
|
||||
}
|
||||
|
|
@ -22,7 +22,6 @@ package org.elasticsearch.rest.action.search;
|
|||
import org.apache.logging.log4j.LogManager;
|
||||
import org.elasticsearch.action.search.SearchAction;
|
||||
import org.elasticsearch.action.search.SearchRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.support.IndicesOptions;
|
||||
import org.elasticsearch.client.node.NodeClient;
|
||||
import org.elasticsearch.common.Booleans;
|
||||
|
@ -34,6 +33,7 @@ import org.elasticsearch.rest.BaseRestHandler;
|
|||
import org.elasticsearch.rest.RestController;
|
||||
import org.elasticsearch.rest.RestRequest;
|
||||
import org.elasticsearch.rest.action.RestActions;
|
||||
import org.elasticsearch.rest.action.RestCancellableNodeClient;
|
||||
import org.elasticsearch.rest.action.RestStatusToXContentListener;
|
||||
import org.elasticsearch.search.Scroll;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
|
@ -110,8 +110,8 @@ public class RestSearchAction extends BaseRestHandler {
|
|||
parseSearchRequest(searchRequest, request, parser, setSize));
|
||||
|
||||
return channel -> {
|
||||
RestStatusToXContentListener<SearchResponse> listener = new RestStatusToXContentListener<>(channel);
|
||||
HttpChannelTaskHandler.INSTANCE.execute(client, request.getHttpChannel(), searchRequest, SearchAction.INSTANCE, listener);
|
||||
RestCancellableNodeClient cancelClient = new RestCancellableNodeClient(client, request.getHttpChannel());
|
||||
cancelClient.execute(SearchAction.INSTANCE, searchRequest, new RestStatusToXContentListener<>(channel));
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.rest.action.search;
|
||||
package org.elasticsearch.rest.action;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.ActionRequest;
|
||||
|
@ -45,7 +45,6 @@ import java.net.InetSocketAddress;
|
|||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CopyOnWriteArraySet;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
@ -56,13 +55,13 @@ import java.util.concurrent.atomic.AtomicInteger;
|
|||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
public class HttpChannelTaskHandlerTests extends ESTestCase {
|
||||
public class RestCancellableNodeClientTests extends ESTestCase {
|
||||
|
||||
private ThreadPool threadPool;
|
||||
|
||||
@Before
|
||||
public void createThreadPool() {
|
||||
threadPool = new TestThreadPool(HttpChannelTaskHandlerTests.class.getName());
|
||||
threadPool = new TestThreadPool(RestCancellableNodeClientTests.class.getName());
|
||||
}
|
||||
|
||||
@After
|
||||
|
@ -77,8 +76,7 @@ public class HttpChannelTaskHandlerTests extends ESTestCase {
|
|||
*/
|
||||
public void testCompletedTasks() throws Exception {
|
||||
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, false)) {
|
||||
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
|
||||
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
|
||||
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
|
||||
int totalSearches = 0;
|
||||
List<Future<?>> futures = new ArrayList<>();
|
||||
int numChannels = randomIntBetween(1, 30);
|
||||
|
@ -88,8 +86,8 @@ public class HttpChannelTaskHandlerTests extends ESTestCase {
|
|||
totalSearches += numTasks;
|
||||
for (int j = 0; j < numTasks; j++) {
|
||||
PlainListenableActionFuture<SearchResponse> actionFuture = PlainListenableActionFuture.newListenableFuture();
|
||||
threadPool.generic().submit(() -> httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(),
|
||||
SearchAction.INSTANCE, actionFuture));
|
||||
RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
|
||||
threadPool.generic().submit(() -> client.execute(SearchAction.INSTANCE, new SearchRequest(), actionFuture));
|
||||
futures.add(actionFuture);
|
||||
}
|
||||
}
|
||||
|
@ -97,10 +95,8 @@ public class HttpChannelTaskHandlerTests extends ESTestCase {
|
|||
future.get();
|
||||
}
|
||||
//no channels get closed in this test, hence we expect as many channels as we created in the map
|
||||
assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels());
|
||||
for (Map.Entry<HttpChannel, HttpChannelTaskHandler.CloseListener> entry : httpChannelTaskHandler.httpChannels.entrySet()) {
|
||||
assertEquals(0, entry.getValue().getNumTasks());
|
||||
}
|
||||
assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels());
|
||||
assertEquals(0, RestCancellableNodeClient.getNumTasks());
|
||||
assertEquals(totalSearches, testClient.searchRequests.get());
|
||||
}
|
||||
}
|
||||
|
@ -110,9 +106,8 @@ public class HttpChannelTaskHandlerTests extends ESTestCase {
|
|||
* removed and all of its corresponding tasks get cancelled.
|
||||
*/
|
||||
public void testCancelledTasks() throws Exception {
|
||||
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) {
|
||||
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
|
||||
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
|
||||
try (TestClient nodeClient = new TestClient(Settings.EMPTY, threadPool, true)) {
|
||||
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
|
||||
int numChannels = randomIntBetween(1, 30);
|
||||
int totalSearches = 0;
|
||||
List<TestHttpChannel> channels = new ArrayList<>(numChannels);
|
||||
|
@ -121,18 +116,19 @@ public class HttpChannelTaskHandlerTests extends ESTestCase {
|
|||
channels.add(channel);
|
||||
int numTasks = randomIntBetween(1, 30);
|
||||
totalSearches += numTasks;
|
||||
RestCancellableNodeClient client = new RestCancellableNodeClient(nodeClient, channel);
|
||||
for (int j = 0; j < numTasks; j++) {
|
||||
httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null);
|
||||
client.execute(SearchAction.INSTANCE, new SearchRequest(), null);
|
||||
}
|
||||
assertEquals(numTasks, httpChannelTaskHandler.httpChannels.get(channel).getNumTasks());
|
||||
assertEquals(numTasks, RestCancellableNodeClient.getNumTasks(channel));
|
||||
}
|
||||
assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels());
|
||||
assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels());
|
||||
for (TestHttpChannel channel : channels) {
|
||||
channel.awaitClose();
|
||||
}
|
||||
assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels());
|
||||
assertEquals(totalSearches, testClient.searchRequests.get());
|
||||
assertEquals(totalSearches, testClient.cancelledTasks.size());
|
||||
assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
|
||||
assertEquals(totalSearches, nodeClient.searchRequests.get());
|
||||
assertEquals(totalSearches, nodeClient.cancelledTasks.size());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -144,8 +140,7 @@ public class HttpChannelTaskHandlerTests extends ESTestCase {
|
|||
*/
|
||||
public void testChannelAlreadyClosed() {
|
||||
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) {
|
||||
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
|
||||
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
|
||||
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
|
||||
int numChannels = randomIntBetween(1, 30);
|
||||
int totalSearches = 0;
|
||||
for (int i = 0; i < numChannels; i++) {
|
||||
|
@ -154,12 +149,13 @@ public class HttpChannelTaskHandlerTests extends ESTestCase {
|
|||
channel.close();
|
||||
int numTasks = randomIntBetween(1, 5);
|
||||
totalSearches += numTasks;
|
||||
RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
|
||||
for (int j = 0; j < numTasks; j++) {
|
||||
//here the channel will be first registered, then straight-away removed from the map as the close listener is invoked
|
||||
httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null);
|
||||
client.execute(SearchAction.INSTANCE, new SearchRequest(), null);
|
||||
}
|
||||
}
|
||||
assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels());
|
||||
assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
|
||||
assertEquals(totalSearches, testClient.searchRequests.get());
|
||||
assertEquals(totalSearches, testClient.cancelledTasks.size());
|
||||
}
|
|
@ -125,8 +125,8 @@ import org.elasticsearch.node.NodeMocksPlugin;
|
|||
import org.elasticsearch.plugins.NetworkPlugin;
|
||||
import org.elasticsearch.plugins.Plugin;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.rest.action.RestCancellableNodeClient;
|
||||
import org.elasticsearch.script.ScriptMetaData;
|
||||
import org.elasticsearch.rest.action.search.HttpChannelTaskHandler;
|
||||
import org.elasticsearch.script.ScriptService;
|
||||
import org.elasticsearch.search.MockSearchService;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
|
@ -536,9 +536,11 @@ public abstract class ESIntegTestCase extends ESTestCase {
|
|||
restClient.close();
|
||||
restClient = null;
|
||||
}
|
||||
assertBusy(() -> assertEquals(HttpChannelTaskHandler.INSTANCE.getNumChannels() + " channels still being tracked in " +
|
||||
HttpChannelTaskHandler.class.getSimpleName() + " while there should be none", 0,
|
||||
HttpChannelTaskHandler.INSTANCE.getNumChannels()));
|
||||
assertBusy(() -> {
|
||||
int numChannels = RestCancellableNodeClient.getNumChannels();
|
||||
assertEquals( numChannels+ " channels still being tracked in " + RestCancellableNodeClient.class.getSimpleName()
|
||||
+ " while there should be none", 0, numChannels);
|
||||
});
|
||||
}
|
||||
|
||||
private void afterInternal(boolean afterClass) throws Exception {
|
||||
|
|
Loading…
Reference in New Issue