Cancel search task on connection close (#43332)

This PR introduces a mechanism to cancel a search task when its corresponding connection gets closed. That relieves users from having to manually deal with tasks and cancel them if needed. In particular, finding the task_id requires calling the get tasks API, which needs to call every node in the cluster.

The implementation is based on associating each http channel with its currently running search task, and cancelling the task when the previously registered close listener gets called.
This commit is contained in:
Luca Cavanna 2019-08-21 19:01:37 +02:00
parent 824f1090a9
commit a47ade3e64
4 changed files with 445 additions and 1 deletions

View File

@ -0,0 +1,155 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.rest.action.search;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
/**
 * This class executes a request and associates the corresponding {@link Task} with the {@link HttpChannel} that it was originated from,
 * so that the tasks associated with a certain channel get cancelled when the underlying connection gets closed.
 *
 * Channels are tracked in a concurrent map, and the set of running tasks of each channel is guarded by the monitor of
 * its {@link CloseListener}.
 */
public final class HttpChannelTaskHandler {

    public static final HttpChannelTaskHandler INSTANCE = new HttpChannelTaskHandler();

    //package private for testing
    final Map<HttpChannel, CloseListener> httpChannels = new ConcurrentHashMap<>();

    private HttpChannelTaskHandler() {
    }

    /**
     * Executes the request locally through the provided client and tracks the resulting task against the given channel,
     * until either the request completes or the channel gets closed.
     *
     * @param client the client used to execute the request, and to cancel its task if the channel gets closed
     * @param httpChannel the channel that the request was received from
     * @param request the request to execute
     * @param actionType the action to execute the request against
     * @param listener the listener notified of the outcome, after the task has been unregistered from the channel
     */
    <Response extends ActionResponse> void execute(NodeClient client, HttpChannel httpChannel, ActionRequest request,
                                                   ActionType<Response> actionType, ActionListener<Response> listener) {
        CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener(client));
        //The task id is only known once executeLocally returns, at which point the request may already have completed.
        //The holder carries both the id and the completed flag so that register and unregister can happen in either order.
        TaskHolder taskHolder = new TaskHolder();
        Task task = client.executeLocally(actionType, request,
            new ActionListener<>() {
                @Override
                public void onResponse(Response searchResponse) {
                    try {
                        closeListener.unregisterTask(taskHolder);
                    } finally {
                        listener.onResponse(searchResponse);
                    }
                }

                @Override
                public void onFailure(Exception e) {
                    try {
                        closeListener.unregisterTask(taskHolder);
                    } finally {
                        listener.onFailure(e);
                    }
                }
            });
        closeListener.registerTask(taskHolder, new TaskId(client.getLocalNodeId(), task.getId()));
        closeListener.maybeRegisterChannel(httpChannel);
    }

    /**
     * Returns the number of channels that are currently being tracked.
     */
    public int getNumChannels() {
        return httpChannels.size();
    }

    /**
     * Listener registered to a channel, which holds the ids of the tasks currently running on behalf of that channel
     * and requests their cancellation once the channel gets closed.
     */
    final class CloseListener implements ActionListener<Void> {
        private final Client client;
        private final AtomicReference<HttpChannel> channel = new AtomicReference<>();
        //guarded by this
        private final Set<TaskId> taskIds = new HashSet<>();

        CloseListener(Client client) {
            this.client = client;
        }

        /**
         * Returns the number of tasks currently registered. Synchronized like every other access to the set of tasks,
         * as the set itself is not thread-safe.
         */
        synchronized int getNumTasks() {
            return taskIds.size();
        }

        /**
         * Associates this listener with the given channel and registers it as its close listener, the first time it is called.
         * Subsequent calls are no-ops.
         */
        void maybeRegisterChannel(HttpChannel httpChannel) {
            if (channel.compareAndSet(null, httpChannel)) {
                //In case the channel is already closed when we register the listener, the listener will be immediately executed which will
                //remove the channel from the map straight-away. That is why we first create the CloseListener and later we associate it
                //with the channel. This guarantees that the close listener is already in the map when it gets registered to its
                //corresponding channel, hence it is always found in the map when it gets invoked if the channel gets closed.
                httpChannel.addCloseListener(this);
            }
        }

        /**
         * Records the id of the task in the holder, and starts tracking it unless the request has already completed.
         */
        synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) {
            taskHolder.taskId = taskId;
            if (taskHolder.completed == false) {
                this.taskIds.add(taskId);
            }
        }

        /**
         * Stops tracking the task held by the holder, if its id is already known, and marks the holder as completed
         * so that a later call to {@link #registerTask(TaskHolder, TaskId)} does not start tracking it.
         */
        synchronized void unregisterTask(TaskHolder taskHolder) {
            if (taskHolder.taskId != null) {
                this.taskIds.remove(taskHolder.taskId);
            }
            taskHolder.completed = true;
        }

        /**
         * Called when the channel gets closed: stops tracking the channel and requests the cancellation of all of the
         * tasks that are registered at that time.
         */
        @Override
        public void onResponse(Void aVoid) {
            //When the channel gets closed it won't be reused: we can remove it from the map and forget about it.
            CloseListener closeListener = httpChannels.remove(channel.get());
            assert closeListener != null : "channel not found in the map of tracked channels";
            //Take a snapshot of the tasks and release the monitor before calling the client. Should a cancellation cause a
            //tracked request to complete on the calling thread, unregisterTask would otherwise modify the set while it is
            //being iterated over. It also avoids holding the lock while executing code that we don't control.
            final TaskId[] toCancel;
            synchronized (this) {
                toCancel = taskIds.toArray(new TaskId[0]);
                taskIds.clear();
            }
            for (TaskId taskId : toCancel) {
                cancelTask(taskId);
            }
        }

        /**
         * A failure while closing the channel is handled like a successful close.
         */
        @Override
        public void onFailure(Exception e) {
            onResponse(null);
        }

        /**
         * Sends a cancel tasks request for the given task, without waiting for the outcome.
         */
        private void cancelTask(TaskId taskId) {
            ThreadContext threadContext = client.threadPool().getThreadContext();
            try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
                // we stash any context here since this is an internal execution and should not leak any existing context information
                threadContext.markAsSystemContext();
                ContextPreservingActionListener<CancelTasksResponse> contextPreservingListener = new ContextPreservingActionListener<>(
                    threadContext.newRestorableContext(false), ActionListener.wrap(r -> {}, e -> {}));
                CancelTasksRequest cancelTasksRequest = new CancelTasksRequest();
                cancelTasksRequest.setTaskId(taskId);
                //We don't wait for cancel tasks to come back. Task cancellation is just best effort.
                client.admin().cluster().cancelTasks(cancelTasksRequest, contextPreservingListener);
            }
        }
    }

    /**
     * Mutable state of a single request: the id of its task, once known, and whether the request has completed.
     * Its fields are only accessed from the synchronized methods of {@link CloseListener}.
     */
    private static class TaskHolder {
        private TaskId taskId;
        private boolean completed = false;
    }
}

View File

@ -20,7 +20,9 @@
package org.elasticsearch.rest.action.search;
import org.apache.logging.log4j.LogManager;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.Booleans;
@ -107,7 +109,10 @@ public class RestSearchAction extends BaseRestHandler {
request.withContentOrSourceParamParserOrNull(parser ->
parseSearchRequest(searchRequest, request, parser, setSize));
return channel -> client.search(searchRequest, new RestStatusToXContentListener<>(channel));
return channel -> {
RestStatusToXContentListener<SearchResponse> listener = new RestStatusToXContentListener<>(channel);
HttpChannelTaskHandler.INSTANCE.execute(client, request.getHttpChannel(), searchRequest, SearchAction.INSTANCE, listener);
};
}
/**

View File

@ -0,0 +1,280 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.rest.action.search;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksAction;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.PlainListenableActionFuture;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
/**
 * Unit tests for {@link HttpChannelTaskHandler}, based on a {@link NodeClient} that counts search requests and records
 * cancelled tasks, and on an {@link HttpChannel} whose close listener is notified from a thread of the test thread pool.
 */
public class HttpChannelTaskHandlerTests extends ESTestCase {

    private ThreadPool threadPool;

    @Before
    public void createThreadPool() {
        threadPool = new TestThreadPool(HttpChannelTaskHandlerTests.class.getName());
    }

    @After
    public void stopThreadPool() {
        ThreadPool.terminate(threadPool, 5, TimeUnit.SECONDS);
    }

    /**
     * This test verifies that no tasks are left in the map where channels and their corresponding tasks are tracked.
     * Through the {@link TestClient} we simulate a scenario where the task may complete even before it has been
     * associated with its corresponding channel. Either way, we need to make sure that no tasks are left in the map.
     */
    public void testCompletedTasks() throws Exception {
        try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, false)) {
            HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
            //the handler is a singleton, hence channels may already be tracked when the test starts
            int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
            int totalSearches = 0;
            List<Future<?>> futures = new ArrayList<>();
            int numChannels = randomIntBetween(1, 30);
            for (int i = 0; i < numChannels; i++) {
                int numTasks = randomIntBetween(1, 30);
                TestHttpChannel channel = new TestHttpChannel();
                totalSearches += numTasks;
                for (int j = 0; j < numTasks; j++) {
                    PlainListenableActionFuture<SearchResponse> actionFuture = PlainListenableActionFuture.newListenableFuture();
                    threadPool.generic().submit(() -> httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(),
                        SearchAction.INSTANCE, actionFuture));
                    futures.add(actionFuture);
                }
            }
            for (Future<?> future : futures) {
                future.get();
            }
            //no channels get closed in this test, hence we expect as many channels as we created in the map
            assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels());
            for (Map.Entry<HttpChannel, HttpChannelTaskHandler.CloseListener> entry : httpChannelTaskHandler.httpChannels.entrySet()) {
                assertEquals(0, entry.getValue().getNumTasks());
            }
            assertEquals(totalSearches, testClient.searchRequests.get());
        }
    }

    /**
     * This test verifies the behaviour when the channel gets closed. The channel is expected to be
     * removed and all of its corresponding tasks get cancelled.
     */
    public void testCancelledTasks() throws Exception {
        try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) {
            HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
            int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
            int numChannels = randomIntBetween(1, 30);
            int totalSearches = 0;
            List<TestHttpChannel> channels = new ArrayList<>(numChannels);
            for (int i = 0; i < numChannels; i++) {
                TestHttpChannel channel = new TestHttpChannel();
                channels.add(channel);
                int numTasks = randomIntBetween(1, 30);
                totalSearches += numTasks;
                for (int j = 0; j < numTasks; j++) {
                    //the client never notifies the listener in this test, hence a null listener can be provided
                    httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null);
                }
                assertEquals(numTasks, httpChannelTaskHandler.httpChannels.get(channel).getNumTasks());
            }
            assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels());
            for (TestHttpChannel channel : channels) {
                channel.awaitClose();
            }
            assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels());
            assertEquals(totalSearches, testClient.searchRequests.get());
            assertEquals(totalSearches, testClient.cancelledTasks.size());
        }
    }

    /**
     * This test verifies what happens when a request comes through yet its corresponding http channel is already closed.
     * The close listener is straight-away executed, the task is cancelled. This can even happen multiple times, it's the only case
     * where we may end up registering a close listener multiple times to the channel, but the channel is already closed hence only
     * the newly added listener will be invoked at registration time.
     */
    public void testChannelAlreadyClosed() {
        try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) {
            HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
            int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
            int numChannels = randomIntBetween(1, 30);
            int totalSearches = 0;
            for (int i = 0; i < numChannels; i++) {
                TestHttpChannel channel = new TestHttpChannel();
                //no need to wait here, there will be no close listener registered, nothing to wait for.
                channel.close();
                int numTasks = randomIntBetween(1, 5);
                totalSearches += numTasks;
                for (int j = 0; j < numTasks; j++) {
                    //here the channel will be first registered, then straight-away removed from the map as the close listener is invoked
                    httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null);
                }
            }
            assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels());
            assertEquals(totalSearches, testClient.searchRequests.get());
            assertEquals(totalSearches, testClient.cancelledTasks.size());
        }
    }

    /**
     * Client that only supports the search and cancel tasks actions: it counts the search requests that it receives
     * and records the ids of the tasks that it is asked to cancel.
     */
    private static class TestClient extends NodeClient {
        private final AtomicLong counter = new AtomicLong(0);
        private final Set<TaskId> cancelledTasks = new CopyOnWriteArraySet<>();
        private final AtomicInteger searchRequests = new AtomicInteger(0);
        //when true, the listener of search requests is never notified, which leaves their tasks registered
        private final boolean timeout;

        TestClient(Settings settings, ThreadPool threadPool, boolean timeout) {
            super(settings, threadPool);
            this.timeout = timeout;
        }

        @Override
        public <Request extends ActionRequest, Response extends ActionResponse> Task executeLocally(ActionType<Response> action,
                                                                                                    Request request,
                                                                                                    ActionListener<Response> listener) {
            switch(action.name()) {
                case CancelTasksAction.NAME:
                    CancelTasksRequest cancelTasksRequest = (CancelTasksRequest) request;
                    assertTrue("tried to cancel the same task more than once", cancelledTasks.add(cancelTasksRequest.getTaskId()));
                    Task task = request.createTask(counter.getAndIncrement(), "cancel_task", action.name(), null, Collections.emptyMap());
                    if (randomBoolean()) {
                        listener.onResponse(null);
                    } else {
                        //test that cancel tasks is best effort, failure received are not propagated
                        listener.onFailure(new IllegalStateException());
                    }
                    return task;
                case SearchAction.NAME:
                    searchRequests.incrementAndGet();
                    Task searchTask = request.createTask(counter.getAndIncrement(), "search", action.name(), null, Collections.emptyMap());
                    if (timeout == false) {
                        if (rarely()) {
                            //make sure that search is sometimes also called from the same thread before the task is returned
                            listener.onResponse(null);
                        } else {
                            threadPool().generic().submit(() -> listener.onResponse(null));
                        }
                    }
                    return searchTask;
                default:
                    throw new UnsupportedOperationException();
            }
        }

        @Override
        public String getLocalNodeId() {
            return "node";
        }
    }

    /**
     * Channel that accepts a single close listener, and notifies it from a thread of the test thread pool when closed.
     */
    private class TestHttpChannel implements HttpChannel {
        private final AtomicBoolean open = new AtomicBoolean(true);
        private final AtomicReference<ActionListener<Void>> closeListener = new AtomicReference<>();
        private final CountDownLatch closeLatch = new CountDownLatch(1);

        @Override
        public void sendResponse(HttpResponse response, ActionListener<Void> listener) {
        }

        @Override
        public InetSocketAddress getLocalAddress() {
            return null;
        }

        @Override
        public InetSocketAddress getRemoteAddress() {
            return null;
        }

        /**
         * Closes the channel and notifies the registered close listener, if any, randomly through either
         * {@code onResponse} or {@code onFailure}.
         *
         * @throws IllegalStateException if the channel was already closed
         */
        @Override
        public void close() {
            if (open.compareAndSet(true, false) == false) {
                throw new IllegalStateException("channel already closed!");
            }
            ActionListener<Void> listener = closeListener.get();
            if (listener == null) {
                //nothing to notify: release the latch straight-away so that awaitClose cannot block forever
                closeLatch.countDown();
                return;
            }
            boolean failure = randomBoolean();
            threadPool.generic().submit(() -> {
                try {
                    if (failure) {
                        listener.onFailure(new IllegalStateException());
                    } else {
                        listener.onResponse(null);
                    }
                } finally {
                    //release the latch also when the listener throws, otherwise awaitClose would hang instead of letting the test fail
                    closeLatch.countDown();
                }
            });
        }

        /**
         * Closes the channel and waits until its close listener, if any, has been notified.
         */
        private void awaitClose() throws InterruptedException {
            close();
            closeLatch.await();
        }

        @Override
        public boolean isOpen() {
            return open.get();
        }

        @Override
        public void addCloseListener(ActionListener<Void> listener) {
            //if the channel is already closed, the listener gets notified immediately, from the same thread.
            if (open.get() == false) {
                listener.onResponse(null);
            } else {
                if (closeListener.compareAndSet(null, listener) == false) {
                    throw new IllegalStateException("close listener already set, only one is allowed!");
                }
            }
        }
    }
}

View File

@ -125,6 +125,7 @@ import org.elasticsearch.plugins.NetworkPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.script.ScriptMetaData;
import org.elasticsearch.rest.action.search.HttpChannelTaskHandler;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.MockSearchService;
import org.elasticsearch.search.SearchHit;
@ -536,6 +537,9 @@ public abstract class ESIntegTestCase extends ESTestCase {
restClient.close();
restClient = null;
}
assertEquals(HttpChannelTaskHandler.INSTANCE.getNumChannels() + " channels still being tracked in " +
HttpChannelTaskHandler.class.getSimpleName() + " while there should be none", 0,
HttpChannelTaskHandler.INSTANCE.getNumChannels());
}
private void afterInternal(boolean afterClass) throws Exception {