diff --git a/server/src/main/java/org/elasticsearch/persistent/PersistentActionCoordinator.java b/server/src/main/java/org/elasticsearch/persistent/PersistentActionCoordinator.java index c6491a36d37..b2c0b66f7be 100644 --- a/server/src/main/java/org/elasticsearch/persistent/PersistentActionCoordinator.java +++ b/server/src/main/java/org/elasticsearch/persistent/PersistentActionCoordinator.java @@ -82,7 +82,7 @@ public class PersistentActionCoordinator extends AbstractComponent implements Cl String localNodeId = event.state().getNodes().getLocalNodeId(); Set notVisitedTasks = new HashSet<>(runningTasks.keySet()); if (tasks != null) { - for (PersistentTaskInProgress taskInProgress : tasks.entries()) { + for (PersistentTaskInProgress taskInProgress : tasks.tasks()) { if (localNodeId.equals(taskInProgress.getExecutorNode())) { PersistentTaskId persistentTaskId = new PersistentTaskId(taskInProgress.getId(), taskInProgress.getAllocationId()); RunningPersistentTask persistentTask = runningTasks.get(persistentTaskId); diff --git a/server/src/main/java/org/elasticsearch/persistent/PersistentTaskClusterService.java b/server/src/main/java/org/elasticsearch/persistent/PersistentTaskClusterService.java index 53c8ecefb64..c99b934eca5 100644 --- a/server/src/main/java/org/elasticsearch/persistent/PersistentTaskClusterService.java +++ b/server/src/main/java/org/elasticsearch/persistent/PersistentTaskClusterService.java @@ -30,11 +30,12 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.component.AbstractComponent; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.tasks.Task; -import org.elasticsearch.persistent.PersistentTasksInProgress.PersistentTaskInProgress; import org.elasticsearch.transport.TransportResponse.Empty; +import org.elasticsearch.persistent.PersistentTasksInProgress.PersistentTaskInProgress; -import java.util.ArrayList; -import java.util.List; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; @@ -69,17 +70,13 @@ public class PersistentTaskClusterService extends AbstractComponent implements C public ClusterState execute(ClusterState currentState) throws Exception { final String executorNodeId = executorNode(action, currentState, request); PersistentTasksInProgress tasksInProgress = currentState.custom(PersistentTasksInProgress.TYPE); - final List> currentTasks = new ArrayList<>(); - final long nextId; + long nextId; if (tasksInProgress != null) { nextId = tasksInProgress.getCurrentId() + 1; - currentTasks.addAll(tasksInProgress.entries()); } else { nextId = 1; } - currentTasks.add(new PersistentTaskInProgress<>(nextId, action, request, executorNodeId)); - ClusterState.Builder builder = ClusterState.builder(currentState); - return builder.putCustom(PersistentTasksInProgress.TYPE, new PersistentTasksInProgress(nextId, currentTasks)).build(); + return createPersistentTask(currentState, new PersistentTaskInProgress<>(nextId, action, request, executorNodeId)); } @Override @@ -118,23 +115,18 @@ public class PersistentTaskClusterService extends AbstractComponent implements C // Nothing to do, the task was already deleted return currentState; } - - boolean found = false; - final List> currentTasks = new ArrayList<>(); - for (PersistentTaskInProgress taskInProgress : tasksInProgress.entries()) { - if (taskInProgress.getId() == id) { - assert found == false; - found = true; - if (failure != null) { - // If the task failed - we need to restart it on another node, otherwise we just remove it - String executorNode = executorNode(taskInProgress.getAction(), currentState, taskInProgress.getRequest()); - currentTasks.add(new PersistentTaskInProgress<>(taskInProgress, executorNode)); - } - } else { - currentTasks.add(taskInProgress); + if (failure != null) { + // If the task failed - we need to restart it on another node, otherwise we just remove it + PersistentTaskInProgress taskInProgress = tasksInProgress.getTask(id); + if (taskInProgress != null) { + String executorNode = executorNode(taskInProgress.getAction(), currentState, taskInProgress.getRequest()); + return updatePersistentTask(currentState, new PersistentTaskInProgress<>(taskInProgress, executorNode)); } + return currentState; + } else { + return removePersistentTask(currentState, id); } - return rebuildClusterStateIfNeeded(found, currentState, currentTasks); + } @Override @@ -165,19 +157,11 @@ public class PersistentTaskClusterService extends AbstractComponent implements C // Nothing to do, the task no longer exists return currentState; } - - boolean found = false; - final List> currentTasks = new ArrayList<>(); - for (PersistentTaskInProgress taskInProgress : tasksInProgress.entries()) { - if (taskInProgress.getId() == id) { - assert found == false; - found = true; - currentTasks.add(new PersistentTaskInProgress<>(taskInProgress, status)); - } else { - currentTasks.add(taskInProgress); - } + PersistentTaskInProgress task = tasksInProgress.getTask(id); + if (task != null) { + return updatePersistentTask(currentState, new PersistentTaskInProgress<>(task, status)); } - return rebuildClusterStateIfNeeded(found, currentState, currentTasks); + return currentState; } @Override @@ -192,14 +176,40 @@ public class PersistentTaskClusterService extends AbstractComponent implements C }); } - private ClusterState rebuildClusterStateIfNeeded(boolean rebuild, ClusterState oldState, - List> currentTasks) { - if (rebuild) { + private ClusterState updatePersistentTask(ClusterState oldState, PersistentTaskInProgress newTask) { + PersistentTasksInProgress oldTasks = oldState.custom(PersistentTasksInProgress.TYPE); + Map> taskMap = new HashMap<>(); + taskMap.putAll(oldTasks.taskMap()); + taskMap.put(newTask.getId(), newTask); + ClusterState.Builder builder = ClusterState.builder(oldState); + PersistentTasksInProgress newTasks = new PersistentTasksInProgress(oldTasks.getCurrentId(), Collections.unmodifiableMap(taskMap)); + return builder.putCustom(PersistentTasksInProgress.TYPE, newTasks).build(); + } + + private ClusterState createPersistentTask(ClusterState oldState, PersistentTaskInProgress newTask) { + PersistentTasksInProgress oldTasks = oldState.custom(PersistentTasksInProgress.TYPE); + Map> taskMap = new HashMap<>(); + if (oldTasks != null) { + taskMap.putAll(oldTasks.taskMap()); + } + taskMap.put(newTask.getId(), newTask); + ClusterState.Builder builder = ClusterState.builder(oldState); + PersistentTasksInProgress newTasks = new PersistentTasksInProgress(newTask.getId(), Collections.unmodifiableMap(taskMap)); + return builder.putCustom(PersistentTasksInProgress.TYPE, newTasks).build(); + } + + private ClusterState removePersistentTask(ClusterState oldState, long taskId) { + PersistentTasksInProgress oldTasks = oldState.custom(PersistentTasksInProgress.TYPE); + if (oldTasks != null) { + Map> taskMap = new HashMap<>(); ClusterState.Builder builder = ClusterState.builder(oldState); - PersistentTasksInProgress oldTasks = oldState.custom(PersistentTasksInProgress.TYPE); - PersistentTasksInProgress tasks = new PersistentTasksInProgress(oldTasks.getCurrentId(), currentTasks); - return builder.putCustom(PersistentTasksInProgress.TYPE, tasks).build(); + taskMap.putAll(oldTasks.taskMap()); + taskMap.remove(taskId); + PersistentTasksInProgress newTasks = + new PersistentTasksInProgress(oldTasks.getCurrentId(), Collections.unmodifiableMap(taskMap)); + return builder.putCustom(PersistentTasksInProgress.TYPE, newTasks).build(); } else { + // no tasks - nothing to do return oldState; } } @@ -227,7 +237,7 @@ public class PersistentTaskClusterService extends AbstractComponent implements C // We need to check if removed nodes were running any of the tasks and reassign them boolean reassignmentRequired = false; Set removedNodes = event.nodesDelta().removedNodes().stream().map(DiscoveryNode::getId).collect(Collectors.toSet()); - for (PersistentTaskInProgress taskInProgress : tasks.entries()) { + for (PersistentTaskInProgress taskInProgress : tasks.tasks()) { if (taskInProgress.getExecutorNode() == null) { // there is an unassigned task - we need to try assigning it reassignmentRequired = true; @@ -258,22 +268,12 @@ public class PersistentTaskClusterService extends AbstractComponent implements C DiscoveryNodes nodes = currentState.nodes(); if (tasks != null) { // We need to check if removed nodes were running any of the tasks and reassign them - for (PersistentTaskInProgress task : tasks.entries()) { + for (PersistentTaskInProgress task : tasks.tasks()) { if (task.getExecutorNode() == null || nodes.nodeExists(task.getExecutorNode()) == false) { // there is an unassigned task - we need to try assigning it String executorNode = executorNode(task.getAction(), currentState, task.getRequest()); if (Objects.equals(executorNode, task.getExecutorNode()) == false) { - PersistentTasksInProgress tasksInProgress = newClusterState.custom(PersistentTasksInProgress.TYPE); - final List> currentTasks = new ArrayList<>(); - for (PersistentTaskInProgress taskInProgress : tasksInProgress.entries()) { - if (task.getId() == taskInProgress.getId()) { - currentTasks.add(new PersistentTaskInProgress<>(task, executorNode)); - } else { - currentTasks.add(taskInProgress); - } - } - newClusterState = ClusterState.builder(newClusterState).putCustom(PersistentTasksInProgress.TYPE, - new PersistentTasksInProgress(tasksInProgress.getCurrentId(), currentTasks)).build(); + newClusterState = updatePersistentTask(newClusterState, new PersistentTaskInProgress<>(task, executorNode)); } } } diff --git a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksInProgress.java b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksInProgress.java index cd29a8e484b..5f6a4b14e58 100644 --- a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksInProgress.java +++ b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksInProgress.java @@ -33,7 +33,7 @@ import org.elasticsearch.tasks.Task.Status; import java.io.IOException; import java.util.Collection; -import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -44,29 +44,37 @@ import java.util.stream.Collectors; public final class PersistentTasksInProgress extends AbstractNamedDiffable implements ClusterState.Custom { public static final String TYPE = "persistent_tasks"; - // TODO: Implement custom Diff for entries - private final List> entries; + // TODO: Implement custom Diff for tasks + private final Map> tasks; private final long currentId; - public PersistentTasksInProgress(long currentId, List> entries) { + public PersistentTasksInProgress(long currentId, Map> tasks) { this.currentId = currentId; - this.entries = entries; + this.tasks = tasks; } - public List> entries() { - return this.entries; + public Collection> tasks() { + return this.tasks.values(); } - public Collection> findEntries(String actionName, Predicate> predicate) { - return this.entries().stream() + public Map> taskMap() { + return this.tasks; + } + + public PersistentTaskInProgress getTask(long id) { + return this.tasks.get(id); + } + + public Collection> findTasks(String actionName, Predicate> predicate) { + return this.tasks().stream() .filter(p -> actionName.equals(p.getAction())) .filter(predicate) .collect(Collectors.toList()); } - public boolean entriesExist(String actionName, Predicate> predicate) { - return this.entries().stream() + public boolean tasksExist(String actionName, Predicate> predicate) { + return this.tasks().stream() .filter(p -> actionName.equals(p.getAction())) .anyMatch(predicate); } @@ -77,16 +85,16 @@ public final class PersistentTasksInProgress extends AbstractNamedDiffable action.equals(task.action) && nodeId.equals(task.executorNode)).count(); + return tasks.values().stream().filter(task -> action.equals(task.action) && nodeId.equals(task.executorNode)).count(); } @Override @@ -97,7 +105,7 @@ public final class PersistentTasksInProgress extends AbstractNamedDiffable implements Writeable { + public static class PersistentTaskInProgress implements Writeable, ToXContent { private final long id; private final long allocationId; private final String action; @@ -196,6 +204,28 @@ public final class PersistentTasksInProgress extends AbstractNamedDiffable { + value.writeTo(stream); + }); } public static NamedDiff readDiffFrom(StreamInput in) throws IOException { @@ -227,25 +259,11 @@ public final class PersistentTasksInProgress extends AbstractNamedDiffable entry : entries) { - toXContent(entry, builder, params); + for (PersistentTaskInProgress entry : tasks.values()) { + entry.toXContent(builder, params); } builder.endArray(); return builder; } - public void toXContent(PersistentTaskInProgress entry, XContentBuilder builder, ToXContent.Params params) throws IOException { - builder.startObject(); - { - builder.field("uuid", entry.id); - builder.field("action", entry.action); - builder.field("request"); - entry.request.toXContent(builder, params); - if (entry.status != null) { - builder.field("status", entry.status, params); - } - builder.field("executor_node", entry.executorNode); - } - builder.endObject(); - } } diff --git a/server/src/test/java/org/elasticsearch/persistent/PersistentActionCoordinatorTests.java b/server/src/test/java/org/elasticsearch/persistent/PersistentActionCoordinatorTests.java index 2454ac5d890..6a510a55108 100644 --- a/server/src/test/java/org/elasticsearch/persistent/PersistentActionCoordinatorTests.java +++ b/server/src/test/java/org/elasticsearch/persistent/PersistentActionCoordinatorTests.java @@ -41,7 +41,9 @@ import org.elasticsearch.transport.TransportResponse.Empty; import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; @@ -85,17 +87,17 @@ public class PersistentActionCoordinatorTests extends ESTestCase { ClusterState state = ClusterState.builder(clusterService.state()).nodes(createTestNodes(nonLocalNodesCount, Settings.EMPTY)) .build(); - List> tasks = new ArrayList<>(); + Map> tasks = new HashMap<>(); long taskId = randomLong(); boolean added = false; if (nonLocalNodesCount > 0) { for (int i = 0; i < randomInt(5); i++) { - tasks.add(new PersistentTaskInProgress<>(taskId, "test_action", new TestRequest("other_" + i), + tasks.put(taskId, new PersistentTaskInProgress<>(taskId, "test_action", new TestRequest("other_" + i), "other_node_" + randomInt(nonLocalNodesCount))); taskId++; if (added == false && randomBoolean()) { added = true; - tasks.add(new PersistentTaskInProgress<>(taskId, "test", new TestRequest("this_param"), "this_node")); + tasks.put(taskId, new PersistentTaskInProgress<>(taskId, "test", new TestRequest("this_param"), "this_node")); taskId++; } } @@ -302,38 +304,33 @@ public class PersistentActionCoordinatorTests extends ESTestCase { private ClusterState addTask(ClusterState state, String action, Request request, String node) { PersistentTasksInProgress prevTasks = state.custom(PersistentTasksInProgress.TYPE); - List> tasks = prevTasks == null ? new ArrayList<>() : new ArrayList<>(prevTasks.entries()); - tasks.add(new PersistentTaskInProgress<>(prevTasks == null ? 0 : prevTasks.getCurrentId(), action, request, node)); + Map> tasks = prevTasks == null ? new HashMap<>() : new HashMap<>(prevTasks.taskMap()); + long id = prevTasks == null ? 0 : prevTasks.getCurrentId(); + tasks.put(id, new PersistentTaskInProgress<>(id, action, request, node)); return ClusterState.builder(state).putCustom(PersistentTasksInProgress.TYPE, new PersistentTasksInProgress(prevTasks == null ? 1 : prevTasks.getCurrentId() + 1, tasks)).build(); } private ClusterState reallocateTask(ClusterState state, long taskId, String node) { PersistentTasksInProgress prevTasks = state.custom(PersistentTasksInProgress.TYPE); - List> tasks = prevTasks == null ? new ArrayList<>() : new ArrayList<>(prevTasks.entries()); - for (int i = 0; i < tasks.size(); i++) { - if (tasks.get(i).getId() == taskId) { - tasks.set(i, new PersistentTaskInProgress<>(tasks.get(i), node)); - return ClusterState.builder(state).putCustom(PersistentTasksInProgress.TYPE, - new PersistentTasksInProgress(prevTasks == null ? 1 : prevTasks.getCurrentId() + 1, tasks)).build(); - } - } - fail("didn't find task with id " + taskId); - return null; + assertNotNull(prevTasks); + Map> tasks = new HashMap<>(prevTasks.taskMap()); + PersistentTaskInProgress prevTask = tasks.get(taskId); + assertNotNull(prevTask); + tasks.put(prevTask.getId(), new PersistentTaskInProgress<>(prevTask, node)); + return ClusterState.builder(state).putCustom(PersistentTasksInProgress.TYPE, + new PersistentTasksInProgress(prevTasks.getCurrentId(), tasks)).build(); } private ClusterState removeTask(ClusterState state, long taskId) { PersistentTasksInProgress prevTasks = state.custom(PersistentTasksInProgress.TYPE); - List> tasks = prevTasks == null ? new ArrayList<>() : new ArrayList<>(prevTasks.entries()); - for (int i = 0; i < tasks.size(); i++) { - if (tasks.get(i).getId() == taskId) { - tasks.remove(i); - return ClusterState.builder(state).putCustom(PersistentTasksInProgress.TYPE, - new PersistentTasksInProgress(prevTasks == null ? 1 : prevTasks.getCurrentId() + 1, tasks)).build(); - } - } - fail("didn't find task with id " + taskId); - return null; + assertNotNull(prevTasks); + Map> tasks = new HashMap<>(prevTasks.taskMap()); + PersistentTaskInProgress prevTask = tasks.get(taskId); + assertNotNull(prevTask); + tasks.remove(prevTask.getId()); + return ClusterState.builder(state).putCustom(PersistentTasksInProgress.TYPE, + new PersistentTasksInProgress(prevTasks.getCurrentId(), tasks)).build(); } private class Execution { diff --git a/server/src/test/java/org/elasticsearch/persistent/PersistentActionIT.java b/server/src/test/java/org/elasticsearch/persistent/PersistentActionIT.java index 2c1b7ccefb9..4d0163321b6 100644 --- a/server/src/test/java/org/elasticsearch/persistent/PersistentActionIT.java +++ b/server/src/test/java/org/elasticsearch/persistent/PersistentActionIT.java @@ -25,6 +25,7 @@ import org.elasticsearch.tasks.TaskInfo; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.persistent.TestPersistentActionPlugin.TestPersistentAction; import org.elasticsearch.persistent.TestPersistentActionPlugin.TestTasksRequestBuilder; +import org.junit.After; import java.util.Collection; import java.util.Collections; @@ -66,6 +67,11 @@ public class PersistentActionIT extends ESIntegTestCase { .build(); } + @After + public void cleanup() throws Exception { + assertNoRunningTasks(); + } + public void testPersistentActionRestart() throws Exception { long taskId = TestPersistentAction.INSTANCE.newRequestBuilder(client()).testParam("Blah").get().getTaskId(); assertBusy(() -> { @@ -133,8 +139,6 @@ public class PersistentActionIT extends ESIntegTestCase { .get().getTasks().size(), equalTo(1)); } - - assertNoRunningTasks(); } public void testPersistentActionWithNoAvailableNode() throws Exception { @@ -179,8 +183,8 @@ public class PersistentActionIT extends ESIntegTestCase { .get().getTasks().get(0); PersistentTasksInProgress tasksInProgress = internalCluster().clusterService().state().custom(PersistentTasksInProgress.TYPE); - assertThat(tasksInProgress.entries().size(), equalTo(1)); - assertThat(tasksInProgress.entries().get(0).getStatus(), nullValue()); + assertThat(tasksInProgress.tasks().size(), equalTo(1)); + assertThat(tasksInProgress.tasks().iterator().next().getStatus(), nullValue()); int numberOfUpdates = randomIntBetween(1, 10); for (int i = 0; i < numberOfUpdates; i++) { @@ -192,9 +196,9 @@ public class PersistentActionIT extends ESIntegTestCase { int finalI = i; assertBusy(() -> { PersistentTasksInProgress tasks = internalCluster().clusterService().state().custom(PersistentTasksInProgress.TYPE); - assertThat(tasks.entries().size(), equalTo(1)); - assertThat(tasks.entries().get(0).getStatus(), notNullValue()); - assertThat(tasks.entries().get(0).getStatus().toString(), equalTo("{\"phase\":\"phase " + (finalI + 1) + "\"}")); + assertThat(tasks.tasks().size(), equalTo(1)); + assertThat(tasks.tasks().iterator().next().getStatus(), notNullValue()); + assertThat(tasks.tasks().iterator().next().getStatus().toString(), equalTo("{\"phase\":\"phase " + (finalI + 1) + "\"}")); }); } @@ -203,8 +207,6 @@ public class PersistentActionIT extends ESIntegTestCase { // Complete the running task and make sure it finishes properly assertThat(new TestTasksRequestBuilder(client()).setOperation("finish").setTaskId(firstRunningTask.getTaskId()) .get().getTasks().size(), equalTo(1)); - - assertNoRunningTasks(); } private void assertNoRunningTasks() throws Exception { @@ -217,7 +219,7 @@ public class PersistentActionIT extends ESIntegTestCase { // Make sure the task is removed from the cluster state assertThat(((PersistentTasksInProgress) internalCluster().clusterService().state().custom(PersistentTasksInProgress.TYPE)) - .entries(), empty()); + .tasks(), empty()); }); } diff --git a/server/src/test/java/org/elasticsearch/persistent/PersistentTasksInProgressTests.java b/server/src/test/java/org/elasticsearch/persistent/PersistentTasksInProgressTests.java index e9f91740507..582eac2d884 100644 --- a/server/src/test/java/org/elasticsearch/persistent/PersistentTasksInProgressTests.java +++ b/server/src/test/java/org/elasticsearch/persistent/PersistentTasksInProgressTests.java @@ -27,16 +27,16 @@ import org.elasticsearch.persistent.PersistentTasksInProgress.PersistentTaskInPr import org.elasticsearch.persistent.TestPersistentActionPlugin.Status; import org.elasticsearch.persistent.TestPersistentActionPlugin.TestPersistentAction; -import java.util.ArrayList; import java.util.Arrays; -import java.util.List; +import java.util.HashMap; +import java.util.Map; public class PersistentTasksInProgressTests extends AbstractWireSerializingTestCase { @Override protected PersistentTasksInProgress createTestInstance() { int numberOfTasks = randomInt(10); - List> entries = new ArrayList<>(); + Map> entries = new HashMap<>(); for (int i = 0; i < numberOfTasks; i++) { PersistentTaskInProgress taskInProgress = new PersistentTaskInProgress<>( randomLong(), randomAsciiOfLength(10), new TestPersistentActionPlugin.TestRequest(randomAsciiOfLength(10)), @@ -45,7 +45,7 @@ public class PersistentTasksInProgressTests extends AbstractWireSerializingTestC // From time to time update status taskInProgress = new PersistentTaskInProgress<>(taskInProgress, new Status(randomAsciiOfLength(10))); } - entries.add(taskInProgress); + entries.put(taskInProgress.getId(), taskInProgress); } return new PersistentTasksInProgress(randomLong(), entries); }