diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandler.java b/server/src/main/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandler.java new file mode 100644 index 00000000000..5864551854f --- /dev/null +++ b/server/src/main/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandler.java @@ -0,0 +1,155 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.rest.action.search; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; +import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse; +import org.elasticsearch.action.support.ContextPreservingActionListener; +import org.elasticsearch.client.Client; +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; + +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; + +/** + * This class executes a request and associates the corresponding {@link Task} with the {@link HttpChannel} that it was originated from, + * so that the tasks associated with a certain channel get cancelled when the underlying connection gets closed. + */ +public final class HttpChannelTaskHandler { + + public static final HttpChannelTaskHandler INSTANCE = new HttpChannelTaskHandler(); + //package private for testing + final Map httpChannels = new ConcurrentHashMap<>(); + + private HttpChannelTaskHandler() { + } + + void execute(NodeClient client, HttpChannel httpChannel, ActionRequest request, + ActionType actionType, ActionListener listener) { + + CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener(client)); + TaskHolder taskHolder = new TaskHolder(); + Task task = client.executeLocally(actionType, request, + new ActionListener<>() { + @Override + public void onResponse(Response searchResponse) { + try { + closeListener.unregisterTask(taskHolder); + } finally { + listener.onResponse(searchResponse); + } + } + + @Override + public void onFailure(Exception e) { + try { + closeListener.unregisterTask(taskHolder); + } finally { + listener.onFailure(e); + } + } + }); + closeListener.registerTask(taskHolder, new TaskId(client.getLocalNodeId(), task.getId())); + closeListener.maybeRegisterChannel(httpChannel); + } + + public int getNumChannels() { + return httpChannels.size(); + } + + final class CloseListener implements ActionListener { + private final Client client; + private final AtomicReference channel = new AtomicReference<>(); + private final Set taskIds = new HashSet<>(); + + CloseListener(Client client) { + this.client = client; + } + + int getNumTasks() { + return taskIds.size(); + } + + void maybeRegisterChannel(HttpChannel httpChannel) { + if (channel.compareAndSet(null, httpChannel)) { + //In case the channel is already closed when we register the listener, the listener will be immediately executed which will + //remove the channel from the map straight-away. That is why we first create the CloseListener and later we associate it + //with the channel. This guarantees that the close listener is already in the map when the it gets registered to its + //corresponding channel, hence it is always found in the map when it gets invoked if the channel gets closed. + httpChannel.addCloseListener(this); + } + } + + synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) { + taskHolder.taskId = taskId; + if (taskHolder.completed == false) { + this.taskIds.add(taskId); + } + } + + synchronized void unregisterTask(TaskHolder taskHolder) { + if (taskHolder.taskId != null) { + this.taskIds.remove(taskHolder.taskId); + } + taskHolder.completed = true; + } + + @Override + public synchronized void onResponse(Void aVoid) { + //When the channel gets closed it won't be reused: we can remove it from the map and forget about it. + CloseListener closeListener = httpChannels.remove(channel.get()); + assert closeListener != null : "channel not found in the map of tracked channels"; + for (TaskId taskId : taskIds) { + ThreadContext threadContext = client.threadPool().getThreadContext(); + try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + // we stash any context here since this is an internal execution and should not leak any existing context information + threadContext.markAsSystemContext(); + ContextPreservingActionListener contextPreservingListener = new ContextPreservingActionListener<>( + threadContext.newRestorableContext(false), ActionListener.wrap(r -> {}, e -> {})); + CancelTasksRequest cancelTasksRequest = new CancelTasksRequest(); + cancelTasksRequest.setTaskId(taskId); + //We don't wait for cancel tasks to come back. Task cancellation is just best effort. + client.admin().cluster().cancelTasks(cancelTasksRequest, contextPreservingListener); + } + } + } + + @Override + public void onFailure(Exception e) { + onResponse(null); + } + } + + private static class TaskHolder { + private TaskId taskId; + private boolean completed = false; + } +} diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java b/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java index 4e935211dba..20dbbd4b55c 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java @@ -20,7 +20,9 @@ package org.elasticsearch.rest.action.search; import org.apache.logging.log4j.LogManager; +import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.common.Booleans; @@ -107,7 +109,10 @@ public class RestSearchAction extends BaseRestHandler { request.withContentOrSourceParamParserOrNull(parser -> parseSearchRequest(searchRequest, request, parser, setSize)); - return channel -> client.search(searchRequest, new RestStatusToXContentListener<>(channel)); + return channel -> { + RestStatusToXContentListener listener = new RestStatusToXContentListener<>(channel); + HttpChannelTaskHandler.INSTANCE.execute(client, request.getHttpChannel(), searchRequest, SearchAction.INSTANCE, listener); + }; } /** diff --git a/server/src/test/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandlerTests.java b/server/src/test/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandlerTests.java new file mode 100644 index 00000000000..103981abdc4 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandlerTests.java @@ -0,0 +1,280 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.rest.action.search; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksAction; +import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.PlainListenableActionFuture; +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.HttpResponse; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; + +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CopyOnWriteArraySet; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +public class HttpChannelTaskHandlerTests extends ESTestCase { + + private ThreadPool threadPool; + + @Before + public void createThreadPool() { + threadPool = new TestThreadPool(HttpChannelTaskHandlerTests.class.getName()); + } + + @After + public void stopThreadPool() { + ThreadPool.terminate(threadPool, 5, TimeUnit.SECONDS); + } + + /** + * This test verifies that no tasks are left in the map where channels and their corresponding tasks are tracked. + * Through the {@link TestClient} we simulate a scenario where the task may complete even before it has been + * associated with its corresponding channel. Either way, we need to make sure that no tasks are left in the map. + */ + public void testCompletedTasks() throws Exception { + try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, false)) { + HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE; + int initialHttpChannels = httpChannelTaskHandler.getNumChannels(); + int totalSearches = 0; + List> futures = new ArrayList<>(); + int numChannels = randomIntBetween(1, 30); + for (int i = 0; i < numChannels; i++) { + int numTasks = randomIntBetween(1, 30); + TestHttpChannel channel = new TestHttpChannel(); + totalSearches += numTasks; + for (int j = 0; j < numTasks; j++) { + PlainListenableActionFuture actionFuture = PlainListenableActionFuture.newListenableFuture(); + threadPool.generic().submit(() -> httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), + SearchAction.INSTANCE, actionFuture)); + futures.add(actionFuture); + } + } + for (Future future : futures) { + future.get(); + } + //no channels get closed in this test, hence we expect as many channels as we created in the map + assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels()); + for (Map.Entry entry : httpChannelTaskHandler.httpChannels.entrySet()) { + assertEquals(0, entry.getValue().getNumTasks()); + } + assertEquals(totalSearches, testClient.searchRequests.get()); + } + } + + /** + * This test verifies the behaviour when the channel gets closed. The channel is expected to be + * removed and all of its corresponding tasks get cancelled. + */ + public void testCancelledTasks() throws Exception { + try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) { + HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE; + int initialHttpChannels = httpChannelTaskHandler.getNumChannels(); + int numChannels = randomIntBetween(1, 30); + int totalSearches = 0; + List channels = new ArrayList<>(numChannels); + for (int i = 0; i < numChannels; i++) { + TestHttpChannel channel = new TestHttpChannel(); + channels.add(channel); + int numTasks = randomIntBetween(1, 30); + totalSearches += numTasks; + for (int j = 0; j < numTasks; j++) { + httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null); + } + assertEquals(numTasks, httpChannelTaskHandler.httpChannels.get(channel).getNumTasks()); + } + assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels()); + for (TestHttpChannel channel : channels) { + channel.awaitClose(); + } + assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels()); + assertEquals(totalSearches, testClient.searchRequests.get()); + assertEquals(totalSearches, testClient.cancelledTasks.size()); + } + } + + /** + * This test verified what happens when a request comes through yet its corresponding http channel is already closed. + * The close listener is straight-away executed, the task is cancelled. This can even happen multiple times, it's the only case + * where we may end up registering a close listener multiple times to the channel, but the channel is already closed hence only + * the newly added listener will be invoked at registration time. + */ + public void testChannelAlreadyClosed() { + try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) { + HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE; + int initialHttpChannels = httpChannelTaskHandler.getNumChannels(); + int numChannels = randomIntBetween(1, 30); + int totalSearches = 0; + for (int i = 0; i < numChannels; i++) { + TestHttpChannel channel = new TestHttpChannel(); + //no need to wait here, there will be no close listener registered, nothing to wait for. + channel.close(); + int numTasks = randomIntBetween(1, 5); + totalSearches += numTasks; + for (int j = 0; j < numTasks; j++) { + //here the channel will be first registered, then straight-away removed from the map as the close listener is invoked + httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null); + } + } + assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels()); + assertEquals(totalSearches, testClient.searchRequests.get()); + assertEquals(totalSearches, testClient.cancelledTasks.size()); + } + } + + private static class TestClient extends NodeClient { + private final AtomicLong counter = new AtomicLong(0); + private final Set cancelledTasks = new CopyOnWriteArraySet<>(); + private final AtomicInteger searchRequests = new AtomicInteger(0); + private final boolean timeout; + + TestClient(Settings settings, ThreadPool threadPool, boolean timeout) { + super(settings, threadPool); + this.timeout = timeout; + } + + @Override + public Task executeLocally(ActionType action, + Request request, + ActionListener listener) { + switch(action.name()) { + case CancelTasksAction.NAME: + CancelTasksRequest cancelTasksRequest = (CancelTasksRequest) request; + assertTrue("tried to cancel the same task more than once", cancelledTasks.add(cancelTasksRequest.getTaskId())); + Task task = request.createTask(counter.getAndIncrement(), "cancel_task", action.name(), null, Collections.emptyMap()); + if (randomBoolean()) { + listener.onResponse(null); + } else { + //test that cancel tasks is best effort, failure received are not propagated + listener.onFailure(new IllegalStateException()); + } + + return task; + case SearchAction.NAME: + searchRequests.incrementAndGet(); + Task searchTask = request.createTask(counter.getAndIncrement(), "search", action.name(), null, Collections.emptyMap()); + if (timeout == false) { + if (rarely()) { + //make sure that search is sometimes also called from the same thread before the task is returned + listener.onResponse(null); + } else { + threadPool().generic().submit(() -> listener.onResponse(null)); + } + } + return searchTask; + default: + throw new UnsupportedOperationException(); + } + + } + + @Override + public String getLocalNodeId() { + return "node"; + } + } + + private class TestHttpChannel implements HttpChannel { + private final AtomicBoolean open = new AtomicBoolean(true); + private final AtomicReference> closeListener = new AtomicReference<>(); + private final CountDownLatch closeLatch = new CountDownLatch(1); + + @Override + public void sendResponse(HttpResponse response, ActionListener listener) { + } + + @Override + public InetSocketAddress getLocalAddress() { + return null; + } + + @Override + public InetSocketAddress getRemoteAddress() { + return null; + } + + @Override + public void close() { + if (open.compareAndSet(true, false) == false) { + throw new IllegalStateException("channel already closed!"); + } + ActionListener listener = closeListener.get(); + if (listener != null) { + boolean failure = randomBoolean(); + threadPool.generic().submit(() -> { + if (failure) { + listener.onFailure(new IllegalStateException()); + } else { + listener.onResponse(null); + } + closeLatch.countDown(); + }); + } + } + + private void awaitClose() throws InterruptedException { + close(); + closeLatch.await(); + } + + @Override + public boolean isOpen() { + return open.get(); + } + + @Override + public void addCloseListener(ActionListener listener) { + //if the channel is already closed, the listener gets notified immediately, from the same thread. + if (open.get() == false) { + listener.onResponse(null); + } else { + if (closeListener.compareAndSet(null, listener) == false) { + throw new IllegalStateException("close listener already set, only one is allowed!"); + } + } + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java index b8428a23554..2754b1ff414 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java @@ -125,6 +125,7 @@ import org.elasticsearch.plugins.NetworkPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.script.ScriptMetaData; +import org.elasticsearch.rest.action.search.HttpChannelTaskHandler; import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.MockSearchService; import org.elasticsearch.search.SearchHit; @@ -536,6 +537,9 @@ public abstract class ESIntegTestCase extends ESTestCase { restClient.close(); restClient = null; } + assertEquals(HttpChannelTaskHandler.INSTANCE.getNumChannels() + " channels still being tracked in " + + HttpChannelTaskHandler.class.getSimpleName() + " while there should be none", 0, + HttpChannelTaskHandler.INSTANCE.getNumChannels()); } private void afterInternal(boolean afterClass) throws Exception {