diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandler.java b/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java similarity index 52% rename from server/src/main/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandler.java rename to server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java index efda0c55f28..b5026f31924 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandler.java +++ b/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java @@ -17,54 +17,84 @@ * under the License. */ -package org.elasticsearch.rest.action.search; +package org.elasticsearch.rest.action; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; -import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse; -import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.client.Client; +import org.elasticsearch.client.FilterClient; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.client.node.NodeClient; -import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.http.HttpChannel; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; +import java.util.ArrayList; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; +import static org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN; + /** - * This class executes a request and associates the corresponding {@link Task} with the {@link HttpChannel} that it was originated from, - * so that the tasks associated with a certain channel get cancelled when the underlying connection gets closed. + * A {@linkplain Client} that cancels tasks executed locally when the provided {@link HttpChannel} + * is closed before completion. */ -public final class HttpChannelTaskHandler { +public class RestCancellableNodeClient extends FilterClient { + private static final Map httpChannels = new ConcurrentHashMap<>(); - public static final HttpChannelTaskHandler INSTANCE = new HttpChannelTaskHandler(); - //package private for testing - final Map httpChannels = new ConcurrentHashMap<>(); + private final NodeClient client; + private final HttpChannel httpChannel; - private HttpChannelTaskHandler() { + public RestCancellableNodeClient(NodeClient client, HttpChannel httpChannel) { + super(client); + this.client = client; + this.httpChannel = httpChannel; } - void execute(NodeClient client, HttpChannel httpChannel, ActionRequest request, - ActionType actionType, ActionListener listener) { + /** + * Returns the number of channels tracked globally. + */ + public static int getNumChannels() { + return httpChannels.size(); + } - CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener(client)); + /** + * Returns the number of tasks tracked globally. + */ + static int getNumTasks() { + return httpChannels.values().stream() + .mapToInt(CloseListener::getNumTasks) + .sum(); + } + + /** + * Returns the number of tasks tracked by the provided {@link HttpChannel}. + */ + static int getNumTasks(HttpChannel channel) { + CloseListener listener = httpChannels.get(channel); + return listener == null ? 0 : listener.getNumTasks(); + } + + @Override + public void doExecute( + ActionType action, Request request, ActionListener listener) { + CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener()); TaskHolder taskHolder = new TaskHolder(); - Task task = client.executeLocally(actionType, request, + Task task = client.executeLocally(action, request, new ActionListener() { @Override - public void onResponse(Response searchResponse) { + public void onResponse(Response response) { try { closeListener.unregisterTask(taskHolder); } finally { - listener.onResponse(searchResponse); + listener.onResponse(response); } } @@ -77,32 +107,35 @@ public final class HttpChannelTaskHandler { } } }); - closeListener.registerTask(taskHolder, new TaskId(client.getLocalNodeId(), task.getId())); + final TaskId taskId = new TaskId(client.getLocalNodeId(), task.getId()); + closeListener.registerTask(taskHolder, taskId); closeListener.maybeRegisterChannel(httpChannel); } - public int getNumChannels() { - return httpChannels.size(); + private void cancelTask(TaskId taskId) { + CancelTasksRequest req = new CancelTasksRequest() + .setTaskId(taskId) + .setReason("channel closed"); + // force the origin to execute the cancellation as a system user + new OriginSettingClient(client, TASKS_ORIGIN).admin().cluster().cancelTasks(req, ActionListener.wrap(() -> {})); } - final class CloseListener implements ActionListener { - private final Client client; + private class CloseListener implements ActionListener { private final AtomicReference channel = new AtomicReference<>(); - private final Set taskIds = new HashSet<>(); + private final Set tasks = new HashSet<>(); - CloseListener(Client client) { - this.client = client; + CloseListener() { } - int getNumTasks() { - return taskIds.size(); + synchronized int getNumTasks() { + return tasks.size(); } void maybeRegisterChannel(HttpChannel httpChannel) { if (channel.compareAndSet(null, httpChannel)) { //In case the channel is already closed when we register the listener, the listener will be immediately executed which will //remove the channel from the map straight-away. That is why we first create the CloseListener and later we associate it - //with the channel. This guarantees that the close listener is already in the map when the it gets registered to its + //with the channel. This guarantees that the close listener is already in the map when it gets registered to its //corresponding channel, hence it is always found in the map when it gets invoked if the channel gets closed. httpChannel.addCloseListener(this); } @@ -111,34 +144,31 @@ public final class HttpChannelTaskHandler { synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) { taskHolder.taskId = taskId; if (taskHolder.completed == false) { - this.taskIds.add(taskId); + this.tasks.add(taskId); } } synchronized void unregisterTask(TaskHolder taskHolder) { if (taskHolder.taskId != null) { - this.taskIds.remove(taskHolder.taskId); + this.tasks.remove(taskHolder.taskId); } taskHolder.completed = true; } @Override - public synchronized void onResponse(Void aVoid) { - //When the channel gets closed it won't be reused: we can remove it from the map and forget about it. - CloseListener closeListener = httpChannels.remove(channel.get()); + public void onResponse(Void aVoid) { + final HttpChannel httpChannel = channel.get(); + assert httpChannel != null : "channel not registered"; + // when the channel gets closed it won't be reused: we can remove it from the map and forget about it. + CloseListener closeListener = httpChannels.remove(httpChannel); assert closeListener != null : "channel not found in the map of tracked channels"; - for (TaskId taskId : taskIds) { - ThreadContext threadContext = client.threadPool().getThreadContext(); - try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - // we stash any context here since this is an internal execution and should not leak any existing context information - threadContext.markAsSystemContext(); - ContextPreservingActionListener contextPreservingListener = new ContextPreservingActionListener<>( - threadContext.newRestorableContext(false), ActionListener.wrap(r -> {}, e -> {})); - CancelTasksRequest cancelTasksRequest = new CancelTasksRequest(); - cancelTasksRequest.setTaskId(taskId); - //We don't wait for cancel tasks to come back. Task cancellation is just best effort. - client.admin().cluster().cancelTasks(cancelTasksRequest, contextPreservingListener); - } + final List toCancel; + synchronized (this) { + toCancel = new ArrayList<>(tasks); + tasks.clear(); + } + for (TaskId taskId : toCancel) { + cancelTask(taskId); } } diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java b/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java index c238f47f765..8f4e052e8d3 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java @@ -22,7 +22,6 @@ package org.elasticsearch.rest.action.search; import org.apache.logging.log4j.LogManager; import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.common.Booleans; @@ -34,6 +33,7 @@ import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestActions; +import org.elasticsearch.rest.action.RestCancellableNodeClient; import org.elasticsearch.rest.action.RestStatusToXContentListener; import org.elasticsearch.search.Scroll; import org.elasticsearch.search.builder.SearchSourceBuilder; @@ -110,8 +110,8 @@ public class RestSearchAction extends BaseRestHandler { parseSearchRequest(searchRequest, request, parser, setSize)); return channel -> { - RestStatusToXContentListener listener = new RestStatusToXContentListener<>(channel); - HttpChannelTaskHandler.INSTANCE.execute(client, request.getHttpChannel(), searchRequest, SearchAction.INSTANCE, listener); + RestCancellableNodeClient cancelClient = new RestCancellableNodeClient(client, request.getHttpChannel()); + cancelClient.execute(SearchAction.INSTANCE, searchRequest, new RestStatusToXContentListener<>(channel)); }; } diff --git a/server/src/test/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandlerTests.java b/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java similarity index 84% rename from server/src/test/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandlerTests.java rename to server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java index 103981abdc4..8121b315475 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java @@ -17,7 +17,7 @@ * under the License. */ -package org.elasticsearch.rest.action.search; +package org.elasticsearch.rest.action; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; @@ -45,7 +45,6 @@ import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Set; import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.CountDownLatch; @@ -56,13 +55,13 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; -public class HttpChannelTaskHandlerTests extends ESTestCase { +public class RestCancellableNodeClientTests extends ESTestCase { private ThreadPool threadPool; @Before public void createThreadPool() { - threadPool = new TestThreadPool(HttpChannelTaskHandlerTests.class.getName()); + threadPool = new TestThreadPool(RestCancellableNodeClientTests.class.getName()); } @After @@ -77,8 +76,7 @@ public class HttpChannelTaskHandlerTests extends ESTestCase { */ public void testCompletedTasks() throws Exception { try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, false)) { - HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE; - int initialHttpChannels = httpChannelTaskHandler.getNumChannels(); + int initialHttpChannels = RestCancellableNodeClient.getNumChannels(); int totalSearches = 0; List> futures = new ArrayList<>(); int numChannels = randomIntBetween(1, 30); @@ -88,8 +86,8 @@ public class HttpChannelTaskHandlerTests extends ESTestCase { totalSearches += numTasks; for (int j = 0; j < numTasks; j++) { PlainListenableActionFuture actionFuture = PlainListenableActionFuture.newListenableFuture(); - threadPool.generic().submit(() -> httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), - SearchAction.INSTANCE, actionFuture)); + RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel); + threadPool.generic().submit(() -> client.execute(SearchAction.INSTANCE, new SearchRequest(), actionFuture)); futures.add(actionFuture); } } @@ -97,10 +95,8 @@ public class HttpChannelTaskHandlerTests extends ESTestCase { future.get(); } //no channels get closed in this test, hence we expect as many channels as we created in the map - assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels()); - for (Map.Entry entry : httpChannelTaskHandler.httpChannels.entrySet()) { - assertEquals(0, entry.getValue().getNumTasks()); - } + assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels()); + assertEquals(0, RestCancellableNodeClient.getNumTasks()); assertEquals(totalSearches, testClient.searchRequests.get()); } } @@ -110,9 +106,8 @@ public class HttpChannelTaskHandlerTests extends ESTestCase { * removed and all of its corresponding tasks get cancelled. */ public void testCancelledTasks() throws Exception { - try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) { - HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE; - int initialHttpChannels = httpChannelTaskHandler.getNumChannels(); + try (TestClient nodeClient = new TestClient(Settings.EMPTY, threadPool, true)) { + int initialHttpChannels = RestCancellableNodeClient.getNumChannels(); int numChannels = randomIntBetween(1, 30); int totalSearches = 0; List channels = new ArrayList<>(numChannels); @@ -121,18 +116,19 @@ public class HttpChannelTaskHandlerTests extends ESTestCase { channels.add(channel); int numTasks = randomIntBetween(1, 30); totalSearches += numTasks; + RestCancellableNodeClient client = new RestCancellableNodeClient(nodeClient, channel); for (int j = 0; j < numTasks; j++) { - httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null); + client.execute(SearchAction.INSTANCE, new SearchRequest(), null); } - assertEquals(numTasks, httpChannelTaskHandler.httpChannels.get(channel).getNumTasks()); + assertEquals(numTasks, RestCancellableNodeClient.getNumTasks(channel)); } - assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels()); + assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels()); for (TestHttpChannel channel : channels) { channel.awaitClose(); } - assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels()); - assertEquals(totalSearches, testClient.searchRequests.get()); - assertEquals(totalSearches, testClient.cancelledTasks.size()); + assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels()); + assertEquals(totalSearches, nodeClient.searchRequests.get()); + assertEquals(totalSearches, nodeClient.cancelledTasks.size()); } } @@ -144,8 +140,7 @@ public class HttpChannelTaskHandlerTests extends ESTestCase { */ public void testChannelAlreadyClosed() { try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) { - HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE; - int initialHttpChannels = httpChannelTaskHandler.getNumChannels(); + int initialHttpChannels = RestCancellableNodeClient.getNumChannels(); int numChannels = randomIntBetween(1, 30); int totalSearches = 0; for (int i = 0; i < numChannels; i++) { @@ -154,12 +149,13 @@ public class HttpChannelTaskHandlerTests extends ESTestCase { channel.close(); int numTasks = randomIntBetween(1, 5); totalSearches += numTasks; + RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel); for (int j = 0; j < numTasks; j++) { //here the channel will be first registered, then straight-away removed from the map as the close listener is invoked - httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null); + client.execute(SearchAction.INSTANCE, new SearchRequest(), null); } } - assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels()); + assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels()); assertEquals(totalSearches, testClient.searchRequests.get()); assertEquals(totalSearches, testClient.cancelledTasks.size()); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java index 062c2f970b6..20240b26209 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java @@ -125,8 +125,8 @@ import org.elasticsearch.node.NodeMocksPlugin; import org.elasticsearch.plugins.NetworkPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.rest.action.RestCancellableNodeClient; import org.elasticsearch.script.ScriptMetaData; -import org.elasticsearch.rest.action.search.HttpChannelTaskHandler; import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.MockSearchService; import org.elasticsearch.search.SearchHit; @@ -536,9 +536,11 @@ public abstract class ESIntegTestCase extends ESTestCase { restClient.close(); restClient = null; } - assertBusy(() -> assertEquals(HttpChannelTaskHandler.INSTANCE.getNumChannels() + " channels still being tracked in " + - HttpChannelTaskHandler.class.getSimpleName() + " while there should be none", 0, - HttpChannelTaskHandler.INSTANCE.getNumChannels())); + assertBusy(() -> { + int numChannels = RestCancellableNodeClient.getNumChannels(); + assertEquals( numChannels+ " channels still being tracked in " + RestCancellableNodeClient.class.getSimpleName() + + " while there should be none", 0, numChannels); + }); } private void afterInternal(boolean afterClass) throws Exception {