diff --git a/core/src/test/java/org/elasticsearch/cluster/ClusterServiceIT.java b/core/src/test/java/org/elasticsearch/cluster/ClusterServiceIT.java index 1fb6c06a73c..60e7fb29041 100644 --- a/core/src/test/java/org/elasticsearch/cluster/ClusterServiceIT.java +++ b/core/src/test/java/org/elasticsearch/cluster/ClusterServiceIT.java @@ -44,9 +44,12 @@ import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.threadpool.ThreadPool; import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import static org.elasticsearch.common.settings.Settings.settingsBuilder; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; @@ -708,18 +711,18 @@ public class ClusterServiceIT extends ESIntegTestCase { Priority[] priorities = Priority.values(); // will hold all the tasks in the order in which they were executed - List tasks = new ArrayList<>(taskCount); + List tasks = new ArrayList<>(taskCount); CountDownLatch latch = new CountDownLatch(taskCount); for (int i = 0; i < taskCount; i++) { Priority priority = priorities[randomIntBetween(0, priorities.length - 1)]; - clusterService.submitStateUpdateTask("test", new PrioritiezedTask(priority, latch, tasks)); + clusterService.submitStateUpdateTask("test", new PrioritizedTask(priority, latch, tasks)); } block.release(); latch.await(); Priority prevPriority = null; - for (PrioritiezedTask task : tasks) { + for (PrioritizedTask task : tasks) { if (prevPriority == null) { prevPriority = task.priority(); } else { @@ -728,6 +731,120 @@ public class ClusterServiceIT extends ESIntegTestCase { } } + public void testClusterStateBatchedUpdates() throws InterruptedException { + Settings settings = settingsBuilder() + .put("discovery.type", "local") + .build(); + internalCluster().startNode(settings); + ClusterService clusterService = internalCluster().getInstance(ClusterService.class); + + AtomicInteger counter = new AtomicInteger(); + class Task { + private AtomicBoolean state = new AtomicBoolean(); + + public void execute() { + if (!state.compareAndSet(false, true)) { + throw new IllegalStateException(); + } else { + counter.incrementAndGet(); + } + } + } + + class TaskExecutor implements ClusterStateTaskExecutor { + private AtomicInteger counter = new AtomicInteger(); + + @Override + public Result execute(ClusterState currentState, List tasks) throws Exception { + tasks.forEach(task -> task.execute()); + counter.addAndGet(tasks.size()); + return new Result(currentState, tasks.size()); + } + + @Override + public boolean runOnlyOnMaster() { + return false; + } + } + int numberOfThreads = randomIntBetween(2, 256); + int tasksSubmittedPerThread = randomIntBetween(1, 1024); + + ConcurrentMap counters = new ConcurrentHashMap<>(); + CountDownLatch latch = new CountDownLatch(numberOfThreads * tasksSubmittedPerThread); + ClusterStateTaskListener listener = new ClusterStateTaskListener() { + @Override + public void onFailure(String source, Throwable t) { + assert false; + } + + @Override + public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { + counters.computeIfAbsent(source, key -> new AtomicInteger()).incrementAndGet(); + latch.countDown(); + } + }; + + int numberOfExecutors = Math.max(1, numberOfThreads / 4); + List executors = new ArrayList<>(); + for (int i = 0; i < numberOfExecutors; i++) { + executors.add(new TaskExecutor()); + } + + // randomly assign tasks to executors + List assignments = new ArrayList<>(); + for (int i = 0; i < numberOfThreads; i++) { + for (int j = 0; j < tasksSubmittedPerThread; j++) { + assignments.add(randomFrom(executors)); + } + } + + Map counts = new HashMap<>(); + for (TaskExecutor executor : assignments) { + counts.merge(executor, 1, (previous, one) -> previous + one); + } + + CountDownLatch startingGun = new CountDownLatch(1 + numberOfThreads); + List threads = new ArrayList<>(); + for (int i = 0; i < numberOfThreads; i++) { + final int index = i; + Thread thread = new Thread(() -> { + startingGun.countDown(); + for (int j = 0; j < tasksSubmittedPerThread; j++) { + ClusterStateTaskExecutor executor = assignments.get(index * tasksSubmittedPerThread + j); + clusterService.submitStateUpdateTask( + Thread.currentThread().getName(), + new Task(), + ClusterStateTaskConfig.build(Priority.NORMAL), + executor, + listener); + } + }); + threads.add(thread); + thread.start(); + } + + startingGun.countDown(); + for (Thread thread : threads) { + thread.join(); + } + + // wait until all the cluster state updates have been processed + latch.await(); + + // assert the number of executed tasks is correct + assertEquals(numberOfThreads * tasksSubmittedPerThread, counter.get()); + + // assert each executor executed the correct number of tasks + for (TaskExecutor executor : executors) { + assertEquals((int)counts.get(executor), executor.counter.get()); + } + + // assert the correct number of clusterStateProcessed events were triggered + for (Map.Entry entry : counters.entrySet()) { + assertEquals(entry.getValue().get(), tasksSubmittedPerThread); + } + } + @TestLogging("cluster:TRACE") // To ensure that we log cluster state events on TRACE level public void testClusterStateUpdateLogging() throws Exception { Settings settings = settingsBuilder() @@ -958,12 +1075,12 @@ public class ClusterServiceIT extends ESIntegTestCase { } - private static class PrioritiezedTask extends ClusterStateUpdateTask { + private static class PrioritizedTask extends ClusterStateUpdateTask { private final CountDownLatch latch; - private final List tasks; + private final List tasks; - private PrioritiezedTask(Priority priority, CountDownLatch latch, List tasks) { + private PrioritizedTask(Priority priority, CountDownLatch latch, List tasks) { super(priority); this.latch = latch; this.tasks = tasks;