Extract batch executor out of cluster service (#24102)

Refactoring that extracts the task batching functionality from ClusterService and makes it a reusable component that can be tested in isolation.
This commit is contained in:
Yannick Welsch 2017-04-20 17:28:43 +02:00 committed by GitHub
parent a427d1dc5e
commit 22e0795990
9 changed files with 1050 additions and 604 deletions

View File

@ -61,19 +61,16 @@ import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.FutureUtils;
import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor;
import org.elasticsearch.common.util.concurrent.PrioritizedRunnable;
import org.elasticsearch.common.util.iterable.Iterables;
import org.elasticsearch.discovery.Discovery;
import org.elasticsearch.discovery.DiscoverySettings;
import org.elasticsearch.threadpool.ThreadPool;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
@ -85,7 +82,6 @@ import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.UnaryOperator;
@ -114,6 +110,7 @@ public class ClusterService extends AbstractLifecycleComponent {
private TimeValue slowTaskLoggingThreshold;
private volatile PrioritizedEsThreadPoolExecutor threadPoolExecutor;
private volatile ClusterServiceTaskBatcher taskBatcher;
/**
* Those 3 state listeners are changing infrequently - CopyOnWriteArrayList is just fine
@ -121,7 +118,6 @@ public class ClusterService extends AbstractLifecycleComponent {
private final Collection<ClusterStateApplier> highPriorityStateAppliers = new CopyOnWriteArrayList<>();
private final Collection<ClusterStateApplier> normalPriorityStateAppliers = new CopyOnWriteArrayList<>();
private final Collection<ClusterStateApplier> lowPriorityStateAppliers = new CopyOnWriteArrayList<>();
final Map<ClusterStateTaskExecutor, LinkedHashSet<UpdateTask>> updateTasksPerExecutor = new HashMap<>();
private final Iterable<ClusterStateApplier> clusterStateAppliers = Iterables.concat(highPriorityStateAppliers,
normalPriorityStateAppliers, lowPriorityStateAppliers);
@ -219,8 +215,9 @@ public class ClusterService extends AbstractLifecycleComponent {
DiscoveryNodes nodes = DiscoveryNodes.builder(state.nodes()).add(localNode).localNodeId(localNode.getId()).build();
return ClusterState.builder(state).nodes(nodes).blocks(initialBlocks).build();
});
this.threadPoolExecutor = EsExecutors.newSinglePrioritizing(UPDATE_THREAD_NAME, daemonThreadFactory(settings, UPDATE_THREAD_NAME),
threadPool.getThreadContext());
this.threadPoolExecutor = EsExecutors.newSinglePrioritizing(UPDATE_THREAD_NAME,
daemonThreadFactory(settings, UPDATE_THREAD_NAME), threadPool.getThreadContext(), threadPool.scheduler());
this.taskBatcher = new ClusterServiceTaskBatcher(logger, threadPoolExecutor);
}
@Override
@ -244,6 +241,44 @@ public class ClusterService extends AbstractLifecycleComponent {
protected synchronized void doClose() {
}
class ClusterServiceTaskBatcher extends TaskBatcher {
ClusterServiceTaskBatcher(Logger logger, PrioritizedEsThreadPoolExecutor threadExecutor) {
super(logger, threadExecutor);
}
@Override
protected void onTimeout(List<? extends BatchedTask> tasks, TimeValue timeout) {
threadPool.generic().execute(
() -> tasks.forEach(
task -> ((UpdateTask) task).listener.onFailure(task.source,
new ProcessClusterEventTimeoutException(timeout, task.source))));
}
@Override
protected void run(Object batchingKey, List<? extends BatchedTask> tasks, String tasksSummary) {
ClusterStateTaskExecutor<Object> taskExecutor = (ClusterStateTaskExecutor<Object>) batchingKey;
List<UpdateTask> updateTasks = (List<UpdateTask>) tasks;
runTasks(new ClusterService.TaskInputs(taskExecutor, updateTasks, tasksSummary));
}
class UpdateTask extends BatchedTask {
final ClusterStateTaskListener listener;
UpdateTask(Priority priority, String source, Object task, ClusterStateTaskListener listener,
ClusterStateTaskExecutor<?> executor) {
super(priority, source, executor, task);
this.listener = listener;
}
@Override
public String describeTasks(List<? extends BatchedTask> tasks) {
return ((ClusterStateTaskExecutor<Object>) batchingKey).describeTasks(
tasks.stream().map(BatchedTask::getTask).collect(Collectors.toList()));
}
}
}
/**
* The local node.
*/
@ -350,6 +385,7 @@ public class ClusterService extends AbstractLifecycleComponent {
listener.onClose();
return;
}
// call the post added notification on the same event thread
try {
threadPoolExecutor.execute(new SourcePrioritizedRunnable(Priority.HIGH, "_add_listener_") {
@ -432,38 +468,11 @@ public class ClusterService extends AbstractLifecycleComponent {
if (!lifecycle.started()) {
return;
}
if (tasks.isEmpty()) {
return;
}
try {
@SuppressWarnings("unchecked")
ClusterStateTaskExecutor<Object> taskExecutor = (ClusterStateTaskExecutor<Object>) executor;
// convert to an identity map to check for dups based on update tasks semantics of using identity instead of equal
final IdentityHashMap<Object, ClusterStateTaskListener> tasksIdentity = new IdentityHashMap<>(tasks);
final List<UpdateTask> updateTasks = tasksIdentity.entrySet().stream().map(
entry -> new UpdateTask(source, entry.getKey(), config.priority(), taskExecutor, safe(entry.getValue(), logger))
).collect(Collectors.toList());
synchronized (updateTasksPerExecutor) {
LinkedHashSet<UpdateTask> existingTasks = updateTasksPerExecutor.computeIfAbsent(executor,
k -> new LinkedHashSet<>(updateTasks.size()));
for (UpdateTask existing : existingTasks) {
if (tasksIdentity.containsKey(existing.task)) {
throw new IllegalStateException("task [" + taskExecutor.describeTasks(Collections.singletonList(existing.task)) +
"] with source [" + source + "] is already queued");
}
}
existingTasks.addAll(updateTasks);
}
final UpdateTask firstTask = updateTasks.get(0);
final TimeValue timeout = config.timeout();
if (timeout != null) {
threadPoolExecutor.execute(firstTask, threadPool.scheduler(), timeout, () -> onTimeout(updateTasks, source, timeout));
} else {
threadPoolExecutor.execute(firstTask);
}
List<ClusterServiceTaskBatcher.UpdateTask> safeTasks = tasks.entrySet().stream()
.map(e -> taskBatcher.new UpdateTask(config.priority(), source, e.getKey(), safe(e.getValue(), logger), executor))
.collect(Collectors.toList());
taskBatcher.submitTasks(safeTasks, config.timeout());
} catch (EsRejectedExecutionException e) {
// ignore cases where we are shutting down..., there is really nothing interesting
// to be done here...
@ -473,60 +482,17 @@ public class ClusterService extends AbstractLifecycleComponent {
}
}
private void onTimeout(List<UpdateTask> updateTasks, String source, TimeValue timeout) {
threadPool.generic().execute(() -> {
final ArrayList<UpdateTask> toRemove = new ArrayList<>();
for (UpdateTask task : updateTasks) {
if (task.processed.getAndSet(true) == false) {
logger.debug("cluster state update task [{}] timed out after [{}]", source, timeout);
toRemove.add(task);
}
}
if (toRemove.isEmpty() == false) {
ClusterStateTaskExecutor<Object> clusterStateTaskExecutor = toRemove.get(0).executor;
synchronized (updateTasksPerExecutor) {
LinkedHashSet<UpdateTask> existingTasks = updateTasksPerExecutor.get(clusterStateTaskExecutor);
if (existingTasks != null) {
existingTasks.removeAll(toRemove);
if (existingTasks.isEmpty()) {
updateTasksPerExecutor.remove(clusterStateTaskExecutor);
}
}
}
for (UpdateTask task : toRemove) {
task.listener.onFailure(source, new ProcessClusterEventTimeoutException(timeout, source));
}
}
});
}
/**
* Returns the tasks that are pending.
*/
public List<PendingClusterTask> pendingTasks() {
PrioritizedEsThreadPoolExecutor.Pending[] pendings = threadPoolExecutor.getPending();
List<PendingClusterTask> pendingClusterTasks = new ArrayList<>(pendings.length);
for (PrioritizedEsThreadPoolExecutor.Pending pending : pendings) {
final String source;
final long timeInQueue;
// we have to capture the task as it will be nulled after execution and we don't want to change while we check things here.
final Object task = pending.task;
if (task == null) {
continue;
} else if (task instanceof SourcePrioritizedRunnable) {
SourcePrioritizedRunnable runnable = (SourcePrioritizedRunnable) task;
source = runnable.source();
timeInQueue = runnable.getAgeInMillis();
} else {
assert false : "expected SourcePrioritizedRunnable got " + task.getClass();
source = "unknown [" + task.getClass() + "]";
timeInQueue = 0;
}
pendingClusterTasks.add(
new PendingClusterTask(pending.insertionOrder, pending.priority, new Text(source), timeInQueue, pending.executing));
}
return pendingClusterTasks;
return Arrays.stream(threadPoolExecutor.getPending()).map(pending -> {
assert pending.task instanceof SourcePrioritizedRunnable :
"thread pool executor should only use SourcePrioritizedRunnable instances but found: " + pending.task.getClass().getName();
SourcePrioritizedRunnable task = (SourcePrioritizedRunnable) pending.task;
return new PendingClusterTask(pending.insertionOrder, pending.priority, new Text(task.source()),
task.getAgeInMillis(), pending.executing);
}).collect(Collectors.toList());
}
/**
@ -585,19 +551,6 @@ public class ClusterService extends AbstractLifecycleComponent {
this.discoverySettings = discoverySettings;
}
abstract static class SourcePrioritizedRunnable extends PrioritizedRunnable {
protected final String source;
SourcePrioritizedRunnable(Priority priority, String source) {
super(priority);
this.source = source;
}
public String source() {
return source;
}
}
void runTasks(TaskInputs taskInputs) {
if (!lifecycle.started()) {
logger.debug("processing [{}]: ignoring, cluster service not started", taskInputs.summary);
@ -657,8 +610,8 @@ public class ClusterService extends AbstractLifecycleComponent {
public TaskOutputs calculateTaskOutputs(TaskInputs taskInputs, ClusterState previousClusterState, long startTimeNS) {
ClusterTasksResult<Object> clusterTasksResult = executeTasks(taskInputs, startTimeNS, previousClusterState);
// extract those that are waiting for results
List<UpdateTask> nonFailedTasks = new ArrayList<>();
for (UpdateTask updateTask : taskInputs.updateTasks) {
List<ClusterServiceTaskBatcher.UpdateTask> nonFailedTasks = new ArrayList<>();
for (ClusterServiceTaskBatcher.UpdateTask updateTask : taskInputs.updateTasks) {
assert clusterTasksResult.executionResults.containsKey(updateTask.task) : "missing " + updateTask;
final ClusterStateTaskExecutor.TaskResult taskResult =
clusterTasksResult.executionResults.get(updateTask.task);
@ -675,7 +628,8 @@ public class ClusterService extends AbstractLifecycleComponent {
private ClusterTasksResult<Object> executeTasks(TaskInputs taskInputs, long startTimeNS, ClusterState previousClusterState) {
ClusterTasksResult<Object> clusterTasksResult;
try {
List<Object> inputs = taskInputs.updateTasks.stream().map(tUpdateTask -> tUpdateTask.task).collect(Collectors.toList());
List<Object> inputs = taskInputs.updateTasks.stream()
.map(ClusterServiceTaskBatcher.UpdateTask::getTask).collect(Collectors.toList());
clusterTasksResult = taskInputs.executor.execute(previousClusterState, inputs);
} catch (Exception e) {
TimeValue executionTime = TimeValue.timeValueMillis(Math.max(0, TimeValue.nsecToMSec(currentTimeInNanos() - startTimeNS)));
@ -693,7 +647,7 @@ public class ClusterService extends AbstractLifecycleComponent {
}
warnAboutSlowTaskIfNeeded(executionTime, taskInputs.summary);
clusterTasksResult = ClusterTasksResult.builder()
.failures(taskInputs.updateTasks.stream().map(updateTask -> updateTask.task)::iterator, e)
.failures(taskInputs.updateTasks.stream().map(ClusterServiceTaskBatcher.UpdateTask::getTask)::iterator, e)
.build(previousClusterState);
}
@ -704,7 +658,7 @@ public class ClusterService extends AbstractLifecycleComponent {
boolean assertsEnabled = false;
assert (assertsEnabled = true);
if (assertsEnabled) {
for (UpdateTask updateTask : taskInputs.updateTasks) {
for (ClusterServiceTaskBatcher.UpdateTask updateTask : taskInputs.updateTasks) {
assert clusterTasksResult.executionResults.containsKey(updateTask.task) :
"missing task result for " + updateTask;
}
@ -870,10 +824,10 @@ public class ClusterService extends AbstractLifecycleComponent {
*/
class TaskInputs {
public final String summary;
public final ArrayList<UpdateTask> updateTasks;
public final List<ClusterServiceTaskBatcher.UpdateTask> updateTasks;
public final ClusterStateTaskExecutor<Object> executor;
TaskInputs(ClusterStateTaskExecutor<Object> executor, ArrayList<UpdateTask> updateTasks, String summary) {
TaskInputs(ClusterStateTaskExecutor<Object> executor, List<ClusterServiceTaskBatcher.UpdateTask> updateTasks, String summary) {
this.summary = summary;
this.executor = executor;
this.updateTasks = updateTasks;
@ -895,11 +849,11 @@ public class ClusterService extends AbstractLifecycleComponent {
public final TaskInputs taskInputs;
public final ClusterState previousClusterState;
public final ClusterState newClusterState;
public final List<UpdateTask> nonFailedTasks;
public final List<ClusterServiceTaskBatcher.UpdateTask> nonFailedTasks;
public final Map<Object, ClusterStateTaskExecutor.TaskResult> executionResults;
TaskOutputs(TaskInputs taskInputs, ClusterState previousClusterState,
ClusterState newClusterState, List<UpdateTask> nonFailedTasks,
ClusterState newClusterState, List<ClusterServiceTaskBatcher.UpdateTask> nonFailedTasks,
Map<Object, ClusterStateTaskExecutor.TaskResult> executionResults) {
this.taskInputs = taskInputs;
this.previousClusterState = previousClusterState;
@ -951,7 +905,7 @@ public class ClusterService extends AbstractLifecycleComponent {
public void notifyFailedTasks() {
// fail all tasks that have failed
for (UpdateTask updateTask : taskInputs.updateTasks) {
for (ClusterServiceTaskBatcher.UpdateTask updateTask : taskInputs.updateTasks) {
assert executionResults.containsKey(updateTask.task) : "missing " + updateTask;
final ClusterStateTaskExecutor.TaskResult taskResult = executionResults.get(updateTask.task);
if (taskResult.isSuccess() == false) {
@ -1071,65 +1025,6 @@ public class ClusterService extends AbstractLifecycleComponent {
}
}
class UpdateTask extends SourcePrioritizedRunnable {
public final Object task;
public final ClusterStateTaskListener listener;
private final ClusterStateTaskExecutor<Object> executor;
public final AtomicBoolean processed = new AtomicBoolean();
UpdateTask(String source, Object task, Priority priority, ClusterStateTaskExecutor<Object> executor,
ClusterStateTaskListener listener) {
super(priority, source);
this.task = task;
this.executor = executor;
this.listener = listener;
}
@Override
public void run() {
// if this task is already processed, the executor shouldn't execute other tasks (that arrived later),
// to give other executors a chance to execute their tasks.
if (processed.get() == false) {
final ArrayList<UpdateTask> toExecute = new ArrayList<>();
final Map<String, ArrayList<Object>> processTasksBySource = new HashMap<>();
synchronized (updateTasksPerExecutor) {
LinkedHashSet<UpdateTask> pending = updateTasksPerExecutor.remove(executor);
if (pending != null) {
for (UpdateTask task : pending) {
if (task.processed.getAndSet(true) == false) {
logger.trace("will process {}", task);
toExecute.add(task);
processTasksBySource.computeIfAbsent(task.source, s -> new ArrayList<>()).add(task.task);
} else {
logger.trace("skipping {}, already processed", task);
}
}
}
}
if (toExecute.isEmpty() == false) {
final String tasksSummary = processTasksBySource.entrySet().stream().map(entry -> {
String tasks = executor.describeTasks(entry.getValue());
return tasks.isEmpty() ? entry.getKey() : entry.getKey() + "[" + tasks + "]";
}).reduce((s1, s2) -> s1 + ", " + s2).orElse("");
runTasks(new TaskInputs(executor, toExecute, tasksSummary));
}
}
}
@Override
public String toString() {
String taskDescription = executor.describeTasks(Collections.singletonList(task));
if (taskDescription.isEmpty()) {
return "[" + source + "]";
} else {
return "[" + source + "[" + taskDescription + "]]";
}
}
}
private void warnAboutSlowTaskIfNeeded(TimeValue executionTime, String source) {
if (executionTime.getMillis() > slowTaskLoggingThreshold.getMillis()) {
logger.warn("cluster state update task [{}] took [{}] above the warn threshold of {}", source, executionTime,

View File

@ -0,0 +1,44 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.cluster.service;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.util.concurrent.PrioritizedRunnable;
/**
* PrioritizedRunnable that also has a source string
*/
public abstract class SourcePrioritizedRunnable extends PrioritizedRunnable {
protected final String source;
public SourcePrioritizedRunnable(Priority priority, String source) {
super(priority);
this.source = source;
}
public String source() {
return source;
}
@Override
public String toString() {
return "[" + source + "]";
}
}

View File

@ -0,0 +1,207 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.cluster.service;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
* Batching support for {@link PrioritizedEsThreadPoolExecutor}
* Tasks that share the same batching key are batched (see {@link BatchedTask#batchingKey})
*/
public abstract class TaskBatcher {
private final Logger logger;
private final PrioritizedEsThreadPoolExecutor threadExecutor;
// package visible for tests
final Map<Object, LinkedHashSet<BatchedTask>> tasksPerBatchingKey = new HashMap<>();
public TaskBatcher(Logger logger, PrioritizedEsThreadPoolExecutor threadExecutor) {
this.logger = logger;
this.threadExecutor = threadExecutor;
}
public void submitTasks(List<? extends BatchedTask> tasks, @Nullable TimeValue timeout) throws EsRejectedExecutionException {
if (tasks.isEmpty()) {
return;
}
final BatchedTask firstTask = tasks.get(0);
assert tasks.stream().allMatch(t -> t.batchingKey == firstTask.batchingKey) :
"tasks submitted in a batch should share the same batching key: " + tasks;
// convert to an identity map to check for dups based on task identity
final Map<Object, BatchedTask> tasksIdentity = tasks.stream().collect(Collectors.toMap(
BatchedTask::getTask,
Function.identity(),
(a, b) -> { throw new IllegalStateException("cannot add duplicate task: " + a); },
IdentityHashMap::new));
synchronized (tasksPerBatchingKey) {
LinkedHashSet<BatchedTask> existingTasks = tasksPerBatchingKey.computeIfAbsent(firstTask.batchingKey,
k -> new LinkedHashSet<>(tasks.size()));
for (BatchedTask existing : existingTasks) {
// check that there won't be two tasks with the same identity for the same batching key
BatchedTask duplicateTask = tasksIdentity.get(existing.getTask());
if (duplicateTask != null) {
throw new IllegalStateException("task [" + duplicateTask.describeTasks(
Collections.singletonList(existing)) + "] with source [" + duplicateTask.source + "] is already queued");
}
}
existingTasks.addAll(tasks);
}
if (timeout != null) {
threadExecutor.execute(firstTask, timeout, () -> onTimeoutInternal(tasks, timeout));
} else {
threadExecutor.execute(firstTask);
}
}
private void onTimeoutInternal(List<? extends BatchedTask> tasks, TimeValue timeout) {
final ArrayList<BatchedTask> toRemove = new ArrayList<>();
for (BatchedTask task : tasks) {
if (task.processed.getAndSet(true) == false) {
logger.debug("task [{}] timed out after [{}]", task.source, timeout);
toRemove.add(task);
}
}
if (toRemove.isEmpty() == false) {
BatchedTask firstTask = toRemove.get(0);
Object batchingKey = firstTask.batchingKey;
assert tasks.stream().allMatch(t -> t.batchingKey == batchingKey) :
"tasks submitted in a batch should share the same batching key: " + tasks;
synchronized (tasksPerBatchingKey) {
LinkedHashSet<BatchedTask> existingTasks = tasksPerBatchingKey.get(batchingKey);
if (existingTasks != null) {
existingTasks.removeAll(toRemove);
if (existingTasks.isEmpty()) {
tasksPerBatchingKey.remove(batchingKey);
}
}
}
onTimeout(toRemove, timeout);
}
}
/**
* Action to be implemented by the specific batching implementation.
* All tasks have the same batching key.
*/
protected abstract void onTimeout(List<? extends BatchedTask> tasks, TimeValue timeout);
void runIfNotProcessed(BatchedTask updateTask) {
// if this task is already processed, it shouldn't execute other tasks with same batching key that arrived later,
// to give other tasks with different batching key a chance to execute.
if (updateTask.processed.get() == false) {
final List<BatchedTask> toExecute = new ArrayList<>();
final Map<String, List<BatchedTask>> processTasksBySource = new HashMap<>();
synchronized (tasksPerBatchingKey) {
LinkedHashSet<BatchedTask> pending = tasksPerBatchingKey.remove(updateTask.batchingKey);
if (pending != null) {
for (BatchedTask task : pending) {
if (task.processed.getAndSet(true) == false) {
logger.trace("will process {}", task);
toExecute.add(task);
processTasksBySource.computeIfAbsent(task.source, s -> new ArrayList<>()).add(task);
} else {
logger.trace("skipping {}, already processed", task);
}
}
}
}
if (toExecute.isEmpty() == false) {
final String tasksSummary = processTasksBySource.entrySet().stream().map(entry -> {
String tasks = updateTask.describeTasks(entry.getValue());
return tasks.isEmpty() ? entry.getKey() : entry.getKey() + "[" + tasks + "]";
}).reduce((s1, s2) -> s1 + ", " + s2).orElse("");
run(updateTask.batchingKey, toExecute, tasksSummary);
}
}
}
/**
* Action to be implemented by the specific batching implementation
* All tasks have the given batching key.
*/
protected abstract void run(Object batchingKey, List<? extends BatchedTask> tasks, String tasksSummary);
/**
* Represents a runnable task that supports batching.
* Implementors of TaskBatcher can subclass this to add a payload to the task.
*/
protected abstract class BatchedTask extends SourcePrioritizedRunnable {
/**
* whether the task has been processed already
*/
protected final AtomicBoolean processed = new AtomicBoolean();
/**
* the object that is used as batching key
*/
protected final Object batchingKey;
/**
* the task object that is wrapped
*/
protected final Object task;
protected BatchedTask(Priority priority, String source, Object batchingKey, Object task) {
super(priority, source);
this.batchingKey = batchingKey;
this.task = task;
}
@Override
public void run() {
runIfNotProcessed(this);
}
@Override
public String toString() {
String taskDescription = describeTasks(Collections.singletonList(this));
if (taskDescription.isEmpty()) {
return "[" + source + "]";
} else {
return "[" + source + "[" + taskDescription + "]]";
}
}
public abstract String describeTasks(List<? extends BatchedTask> tasks);
public Object getTask() {
return task;
}
}
}

View File

@ -30,6 +30,7 @@ import java.util.concurrent.AbstractExecutorService;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
@ -57,8 +58,8 @@ public class EsExecutors {
return PROCESSORS_SETTING.get(settings);
}
public static PrioritizedEsThreadPoolExecutor newSinglePrioritizing(String name, ThreadFactory threadFactory, ThreadContext contextHolder) {
return new PrioritizedEsThreadPoolExecutor(name, 1, 1, 0L, TimeUnit.MILLISECONDS, threadFactory, contextHolder);
public static PrioritizedEsThreadPoolExecutor newSinglePrioritizing(String name, ThreadFactory threadFactory, ThreadContext contextHolder, ScheduledExecutorService timer) {
return new PrioritizedEsThreadPoolExecutor(name, 1, 1, 0L, TimeUnit.MILLISECONDS, threadFactory, contextHolder, timer);
}
public static EsThreadPoolExecutor newScaling(String name, int min, int max, long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory, ThreadContext contextHolder) {

View File

@ -44,11 +44,14 @@ import java.util.concurrent.atomic.AtomicLong;
public class PrioritizedEsThreadPoolExecutor extends EsThreadPoolExecutor {
private static final TimeValue NO_WAIT_TIME_VALUE = TimeValue.timeValueMillis(0);
private AtomicLong insertionOrder = new AtomicLong();
private Queue<Runnable> current = ConcurrentCollections.newQueue();
private final AtomicLong insertionOrder = new AtomicLong();
private final Queue<Runnable> current = ConcurrentCollections.newQueue();
private final ScheduledExecutorService timer;
PrioritizedEsThreadPoolExecutor(String name, int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory, ThreadContext contextHolder) {
PrioritizedEsThreadPoolExecutor(String name, int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit,
ThreadFactory threadFactory, ThreadContext contextHolder, ScheduledExecutorService timer) {
super(name, corePoolSize, maximumPoolSize, keepAliveTime, unit, new PriorityBlockingQueue<>(), threadFactory, contextHolder);
this.timer = timer;
}
public Pending[] getPending() {
@ -111,7 +114,7 @@ public class PrioritizedEsThreadPoolExecutor extends EsThreadPoolExecutor {
current.remove(r);
}
public void execute(Runnable command, final ScheduledExecutorService timer, final TimeValue timeout, final Runnable timeoutCallback) {
public void execute(Runnable command, final TimeValue timeout, final Runnable timeoutCallback) {
command = wrapRunnable(command);
doExecute(command);
if (timeout.nanos() >= 0) {

View File

@ -20,7 +20,6 @@ package org.elasticsearch.cluster.service;
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.logging.log4j.util.Supplier;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
@ -39,7 +38,6 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
@ -59,8 +57,6 @@ import org.junit.Before;
import org.junit.BeforeClass;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@ -69,7 +65,6 @@ import java.util.Set;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Semaphore;
@ -77,17 +72,14 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.elasticsearch.test.ClusterServiceUtils.setState;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.hasToString;
import static org.hamcrest.Matchers.is;
public class ClusterServiceTests extends ESTestCase {
@ -151,118 +143,6 @@ public class ClusterServiceTests extends ESTestCase {
return timedClusterService;
}
public void testTimedOutUpdateTaskCleanedUp() throws Exception {
final CountDownLatch block = new CountDownLatch(1);
final CountDownLatch blockCompleted = new CountDownLatch(1);
clusterService.submitStateUpdateTask("block-task", new ClusterStateUpdateTask() {
@Override
public ClusterState execute(ClusterState currentState) {
try {
block.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
blockCompleted.countDown();
return currentState;
}
@Override
public void onFailure(String source, Exception e) {
throw new RuntimeException(e);
}
});
final CountDownLatch block2 = new CountDownLatch(1);
clusterService.submitStateUpdateTask("test", new ClusterStateUpdateTask() {
@Override
public ClusterState execute(ClusterState currentState) {
block2.countDown();
return currentState;
}
@Override
public TimeValue timeout() {
return TimeValue.ZERO;
}
@Override
public void onFailure(String source, Exception e) {
block2.countDown();
}
});
block.countDown();
block2.await();
blockCompleted.await();
synchronized (clusterService.updateTasksPerExecutor) {
assertTrue("expected empty map but was " + clusterService.updateTasksPerExecutor,
clusterService.updateTasksPerExecutor.isEmpty());
}
}
public void testTimeoutUpdateTask() throws Exception {
final CountDownLatch block = new CountDownLatch(1);
clusterService.submitStateUpdateTask("test1", new ClusterStateUpdateTask() {
@Override
public ClusterState execute(ClusterState currentState) {
try {
block.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
return currentState;
}
@Override
public void onFailure(String source, Exception e) {
throw new RuntimeException(e);
}
});
final CountDownLatch timedOut = new CountDownLatch(1);
final AtomicBoolean executeCalled = new AtomicBoolean();
clusterService.submitStateUpdateTask("test2", new ClusterStateUpdateTask() {
@Override
public TimeValue timeout() {
return TimeValue.timeValueMillis(2);
}
@Override
public void onFailure(String source, Exception e) {
timedOut.countDown();
}
@Override
public ClusterState execute(ClusterState currentState) {
executeCalled.set(true);
return currentState;
}
@Override
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
}
});
timedOut.await();
block.countDown();
final CountDownLatch allProcessed = new CountDownLatch(1);
clusterService.submitStateUpdateTask("test3", new ClusterStateUpdateTask() {
@Override
public void onFailure(String source, Exception e) {
throw new RuntimeException(e);
}
@Override
public ClusterState execute(ClusterState currentState) {
allProcessed.countDown();
return currentState;
}
});
allProcessed.await(); // executed another task to double check that execute on the timed out update task is not called...
assertThat(executeCalled.get(), equalTo(false));
}
public void testMasterAwareExecution() throws Exception {
ClusterService nonMaster = createTimedClusterService(false);
@ -394,164 +274,6 @@ public class ClusterServiceTests extends ESTestCase {
assertThat(assertionRef.get().getMessage(), containsString("not be the cluster state update thread. Reason: [Blocking operation]"));
}
public void testOneExecutorDontStarveAnother() throws InterruptedException {
final List<String> executionOrder = Collections.synchronizedList(new ArrayList<>());
final Semaphore allowProcessing = new Semaphore(0);
final Semaphore startedProcessing = new Semaphore(0);
class TaskExecutor implements ClusterStateTaskExecutor<String> {
@Override
public ClusterTasksResult<String> execute(ClusterState currentState, List<String> tasks) throws Exception {
executionOrder.addAll(tasks); // do this first, so startedProcessing can be used as a notification that this is done.
startedProcessing.release(tasks.size());
allowProcessing.acquire(tasks.size());
return ClusterTasksResult.<String>builder().successes(tasks).build(ClusterState.builder(currentState).build());
}
}
TaskExecutor executorA = new TaskExecutor();
TaskExecutor executorB = new TaskExecutor();
final ClusterStateTaskConfig config = ClusterStateTaskConfig.build(Priority.NORMAL);
final ClusterStateTaskListener noopListener = (source, e) -> { throw new AssertionError(source, e); };
// this blocks the cluster state queue, so we can set it up right
clusterService.submitStateUpdateTask("0", "A0", config, executorA, noopListener);
// wait to be processed
startedProcessing.acquire(1);
assertThat(executionOrder, equalTo(Arrays.asList("A0")));
// these will be the first batch
clusterService.submitStateUpdateTask("1", "A1", config, executorA, noopListener);
clusterService.submitStateUpdateTask("2", "A2", config, executorA, noopListener);
// release the first 0 task, but not the second
allowProcessing.release(1);
startedProcessing.acquire(2);
assertThat(executionOrder, equalTo(Arrays.asList("A0", "A1", "A2")));
// setup the queue with pending tasks for another executor same priority
clusterService.submitStateUpdateTask("3", "B3", config, executorB, noopListener);
clusterService.submitStateUpdateTask("4", "B4", config, executorB, noopListener);
clusterService.submitStateUpdateTask("5", "A5", config, executorA, noopListener);
clusterService.submitStateUpdateTask("6", "A6", config, executorA, noopListener);
// now release the processing
allowProcessing.release(6);
// wait for last task to be processed
startedProcessing.acquire(4);
assertThat(executionOrder, equalTo(Arrays.asList("A0", "A1", "A2", "B3", "B4", "A5", "A6")));
}
// test that for a single thread, tasks are executed in the order
// that they are submitted
public void testClusterStateUpdateTasksAreExecutedInOrder() throws BrokenBarrierException, InterruptedException {
class TaskExecutor implements ClusterStateTaskExecutor<Integer> {
List<Integer> tasks = new ArrayList<>();
@Override
public ClusterTasksResult<Integer> execute(ClusterState currentState, List<Integer> tasks) throws Exception {
this.tasks.addAll(tasks);
return ClusterTasksResult.<Integer>builder().successes(tasks).build(ClusterState.builder(currentState).build());
}
}
int numberOfThreads = randomIntBetween(2, 8);
TaskExecutor[] executors = new TaskExecutor[numberOfThreads];
for (int i = 0; i < numberOfThreads; i++) {
executors[i] = new TaskExecutor();
}
int tasksSubmittedPerThread = randomIntBetween(2, 1024);
CopyOnWriteArrayList<Tuple<String, Throwable>> failures = new CopyOnWriteArrayList<>();
CountDownLatch updateLatch = new CountDownLatch(numberOfThreads * tasksSubmittedPerThread);
ClusterStateTaskListener listener = new ClusterStateTaskListener() {
@Override
public void onFailure(String source, Exception e) {
logger.error((Supplier<?>) () -> new ParameterizedMessage("unexpected failure: [{}]", source), e);
failures.add(new Tuple<>(source, e));
updateLatch.countDown();
}
@Override
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
updateLatch.countDown();
}
};
CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
for (int i = 0; i < numberOfThreads; i++) {
final int index = i;
Thread thread = new Thread(() -> {
try {
barrier.await();
for (int j = 0; j < tasksSubmittedPerThread; j++) {
clusterService.submitStateUpdateTask("[" + index + "][" + j + "]", j,
ClusterStateTaskConfig.build(randomFrom(Priority.values())), executors[index], listener);
}
barrier.await();
} catch (InterruptedException | BrokenBarrierException e) {
throw new AssertionError(e);
}
});
thread.start();
}
// wait for all threads to be ready
barrier.await();
// wait for all threads to finish
barrier.await();
updateLatch.await();
assertThat(failures, empty());
for (int i = 0; i < numberOfThreads; i++) {
assertEquals(tasksSubmittedPerThread, executors[i].tasks.size());
for (int j = 0; j < tasksSubmittedPerThread; j++) {
assertNotNull(executors[i].tasks.get(j));
assertEquals("cluster state update task executed out of order", j, (int) executors[i].tasks.get(j));
}
}
}
public void testSingleBatchSubmission() throws InterruptedException {
Map<Integer, ClusterStateTaskListener> tasks = new HashMap<>();
final int numOfTasks = randomInt(10);
final CountDownLatch latch = new CountDownLatch(numOfTasks);
for (int i = 0; i < numOfTasks; i++) {
while (null != tasks.put(randomInt(1024), new ClusterStateTaskListener() {
@Override
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
latch.countDown();
}
@Override
public void onFailure(String source, Exception e) {
fail(ExceptionsHelper.detailedMessage(e));
}
})) ;
}
clusterService.submitStateUpdateTasks("test", tasks, ClusterStateTaskConfig.build(Priority.LANGUID),
(currentState, taskList) -> {
assertThat(taskList.size(), equalTo(tasks.size()));
assertThat(taskList.stream().collect(Collectors.toSet()), equalTo(tasks.keySet()));
return ClusterStateTaskExecutor.ClusterTasksResult.<Integer>builder().successes(taskList).build(currentState);
});
latch.await();
}
public void testClusterStateBatchedUpdates() throws BrokenBarrierException, InterruptedException {
AtomicInteger counter = new AtomicInteger();
class Task {
@ -745,76 +467,6 @@ public class ClusterServiceTests extends ESTestCase {
}
}
/**
* Note, this test can only work as long as we have a single thread executor executing the state update tasks!
*/
public void testPrioritizedTasks() throws Exception {
BlockingTask block = new BlockingTask(Priority.IMMEDIATE);
clusterService.submitStateUpdateTask("test", block);
int taskCount = randomIntBetween(5, 20);
// will hold all the tasks in the order in which they were executed
List<PrioritizedTask> tasks = new ArrayList<>(taskCount);
CountDownLatch latch = new CountDownLatch(taskCount);
for (int i = 0; i < taskCount; i++) {
Priority priority = randomFrom(Priority.values());
clusterService.submitStateUpdateTask("test", new PrioritizedTask(priority, latch, tasks));
}
block.close();
latch.await();
Priority prevPriority = null;
for (PrioritizedTask task : tasks) {
if (prevPriority == null) {
prevPriority = task.priority();
} else {
assertThat(task.priority().sameOrAfter(prevPriority), is(true));
}
}
}
public void testDuplicateSubmission() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(2);
try (BlockingTask blockingTask = new BlockingTask(Priority.IMMEDIATE)) {
clusterService.submitStateUpdateTask("blocking", blockingTask);
ClusterStateTaskExecutor<SimpleTask> executor = (currentState, tasks) ->
ClusterStateTaskExecutor.ClusterTasksResult.<SimpleTask>builder().successes(tasks).build(currentState);
SimpleTask task = new SimpleTask(1);
ClusterStateTaskListener listener = new ClusterStateTaskListener() {
@Override
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
latch.countDown();
}
@Override
public void onFailure(String source, Exception e) {
fail(ExceptionsHelper.detailedMessage(e));
}
};
clusterService.submitStateUpdateTask("first time", task, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener);
final IllegalStateException e =
expectThrows(
IllegalStateException.class,
() -> clusterService.submitStateUpdateTask(
"second time",
task,
ClusterStateTaskConfig.build(Priority.NORMAL),
executor, listener));
assertThat(e, hasToString(containsString("task [1] with source [second time] is already queued")));
clusterService.submitStateUpdateTask("third time a charm", new SimpleTask(1),
ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener);
assertThat(latch.getCount(), equalTo(2L));
}
latch.await();
}
@TestLogging("org.elasticsearch.cluster.service:TRACE") // To ensure that we log cluster state events on TRACE level
public void testClusterStateUpdateLogging() throws Exception {
MockLogAppender mockAppender = new MockLogAppender();
@ -1250,77 +902,6 @@ public class ClusterServiceTests extends ESTestCase {
assertTrue(applierCalled.get());
}
private static class SimpleTask {
private final int id;
private SimpleTask(int id) {
this.id = id;
}
@Override
public int hashCode() {
return super.hashCode();
}
@Override
public boolean equals(Object obj) {
return super.equals(obj);
}
@Override
public String toString() {
return Integer.toString(id);
}
}
private static class BlockingTask extends ClusterStateUpdateTask implements Releasable {
private final CountDownLatch latch = new CountDownLatch(1);
BlockingTask(Priority priority) {
super(priority);
}
@Override
public ClusterState execute(ClusterState currentState) throws Exception {
latch.await();
return currentState;
}
@Override
public void onFailure(String source, Exception e) {
}
public void close() {
latch.countDown();
}
}
private static class PrioritizedTask extends ClusterStateUpdateTask {
private final CountDownLatch latch;
private final List<PrioritizedTask> tasks;
private PrioritizedTask(Priority priority, CountDownLatch latch, List<PrioritizedTask> tasks) {
super(priority);
this.latch = latch;
this.tasks = tasks;
}
@Override
public ClusterState execute(ClusterState currentState) throws Exception {
tasks.add(this);
latch.countDown();
return currentState;
}
@Override
public void onFailure(String source, Exception e) {
latch.countDown();
}
}
static class TimedClusterService extends ClusterService {
public volatile Long currentTimeOverride = null;

View File

@ -0,0 +1,350 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.cluster.service;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.logging.log4j.util.Supplier;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.cluster.ClusterStateTaskConfig;
import org.elasticsearch.cluster.metadata.ProcessClusterEventTimeoutException;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor;
import org.junit.Before;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Semaphore;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasToString;
public class TaskBatcherTests extends TaskExecutorTests {
protected TestTaskBatcher taskBatcher;
@Before
public void setUpBatchingTaskExecutor() throws Exception {
taskBatcher = new TestTaskBatcher(logger, threadExecutor);
}
class TestTaskBatcher extends TaskBatcher {
TestTaskBatcher(Logger logger, PrioritizedEsThreadPoolExecutor threadExecutor) {
super(logger, threadExecutor);
}
@Override
protected void run(Object batchingKey, List<? extends BatchedTask> tasks, String tasksSummary) {
List<UpdateTask> updateTasks = (List) tasks;
((TestExecutor) batchingKey).execute(updateTasks.stream().map(t -> t.task).collect(Collectors.toList()));
updateTasks.forEach(updateTask -> updateTask.listener.processed(updateTask.source));
}
@Override
protected void onTimeout(List<? extends BatchedTask> tasks, TimeValue timeout) {
threadPool.generic().execute(
() -> tasks.forEach(
task -> ((UpdateTask) task).listener.onFailure(task.source,
new ProcessClusterEventTimeoutException(timeout, task.source))));
}
class UpdateTask extends BatchedTask {
final TestListener listener;
UpdateTask(Priority priority, String source, Object task, TestListener listener, TestExecutor<?> executor) {
super(priority, source, executor, task);
this.listener = listener;
}
@Override
public String describeTasks(List<? extends BatchedTask> tasks) {
return ((TestExecutor<Object>) batchingKey).describeTasks(
tasks.stream().map(BatchedTask::getTask).collect(Collectors.toList()));
}
}
}
@Override
protected void submitTask(String source, TestTask testTask) {
submitTask(source, testTask, testTask, testTask, testTask);
}
private <T> void submitTask(String source, T task, ClusterStateTaskConfig config, TestExecutor<T> executor,
TestListener listener) {
submitTasks(source, Collections.singletonMap(task, listener), config, executor);
}
private <T> void submitTasks(final String source,
final Map<T, TestListener> tasks, final ClusterStateTaskConfig config,
final TestExecutor<T> executor) {
List<TestTaskBatcher.UpdateTask> safeTasks = tasks.entrySet().stream()
.map(e -> taskBatcher.new UpdateTask(config.priority(), source, e.getKey(), e.getValue(), executor))
.collect(Collectors.toList());
taskBatcher.submitTasks(safeTasks, config.timeout());
}
@Override
public void testTimedOutTaskCleanedUp() throws Exception {
super.testTimedOutTaskCleanedUp();
synchronized (taskBatcher.tasksPerBatchingKey) {
assertTrue("expected empty map but was " + taskBatcher.tasksPerBatchingKey,
taskBatcher.tasksPerBatchingKey.isEmpty());
}
}
public void testOneExecutorDoesntStarveAnother() throws InterruptedException {
final List<String> executionOrder = Collections.synchronizedList(new ArrayList<>());
final Semaphore allowProcessing = new Semaphore(0);
final Semaphore startedProcessing = new Semaphore(0);
class TaskExecutor implements TestExecutor<String> {
@Override
public void execute(List<String> tasks) {
executionOrder.addAll(tasks); // do this first, so startedProcessing can be used as a notification that this is done.
startedProcessing.release(tasks.size());
try {
allowProcessing.acquire(tasks.size());
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}
TaskExecutor executorA = new TaskExecutor();
TaskExecutor executorB = new TaskExecutor();
final ClusterStateTaskConfig config = ClusterStateTaskConfig.build(Priority.NORMAL);
final TestListener noopListener = (source, e) -> {
throw new AssertionError(e);
};
// this blocks the cluster state queue, so we can set it up right
submitTask("0", "A0", config, executorA, noopListener);
// wait to be processed
startedProcessing.acquire(1);
assertThat(executionOrder, equalTo(Arrays.asList("A0")));
// these will be the first batch
submitTask("1", "A1", config, executorA, noopListener);
submitTask("2", "A2", config, executorA, noopListener);
// release the first 0 task, but not the second
allowProcessing.release(1);
startedProcessing.acquire(2);
assertThat(executionOrder, equalTo(Arrays.asList("A0", "A1", "A2")));
// setup the queue with pending tasks for another executor same priority
submitTask("3", "B3", config, executorB, noopListener);
submitTask("4", "B4", config, executorB, noopListener);
submitTask("5", "A5", config, executorA, noopListener);
submitTask("6", "A6", config, executorA, noopListener);
// now release the processing
allowProcessing.release(6);
// wait for last task to be processed
startedProcessing.acquire(4);
assertThat(executionOrder, equalTo(Arrays.asList("A0", "A1", "A2", "B3", "B4", "A5", "A6")));
}
static class TaskExecutor implements TestExecutor<Integer> {
List<Integer> tasks = new ArrayList<>();
@Override
public void execute(List<Integer> tasks) {
this.tasks.addAll(tasks);
}
}
// test that for a single thread, tasks are executed in the order
// that they are submitted
public void testTasksAreExecutedInOrder() throws BrokenBarrierException, InterruptedException {
int numberOfThreads = randomIntBetween(2, 8);
TaskExecutor[] executors = new TaskExecutor[numberOfThreads];
for (int i = 0; i < numberOfThreads; i++) {
executors[i] = new TaskExecutor();
}
int tasksSubmittedPerThread = randomIntBetween(2, 1024);
CopyOnWriteArrayList<Tuple<String, Throwable>> failures = new CopyOnWriteArrayList<>();
CountDownLatch updateLatch = new CountDownLatch(numberOfThreads * tasksSubmittedPerThread);
final TestListener listener = new TestListener() {
@Override
public void onFailure(String source, Exception e) {
logger.error((Supplier<?>) () -> new ParameterizedMessage("unexpected failure: [{}]", source), e);
failures.add(new Tuple<>(source, e));
updateLatch.countDown();
}
@Override
public void processed(String source) {
updateLatch.countDown();
}
};
CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
for (int i = 0; i < numberOfThreads; i++) {
final int index = i;
Thread thread = new Thread(() -> {
try {
barrier.await();
for (int j = 0; j < tasksSubmittedPerThread; j++) {
submitTask("[" + index + "][" + j + "]", j,
ClusterStateTaskConfig.build(randomFrom(Priority.values())), executors[index], listener);
}
barrier.await();
} catch (InterruptedException | BrokenBarrierException e) {
throw new AssertionError(e);
}
});
thread.start();
}
// wait for all threads to be ready
barrier.await();
// wait for all threads to finish
barrier.await();
updateLatch.await();
assertThat(failures, empty());
for (int i = 0; i < numberOfThreads; i++) {
assertEquals(tasksSubmittedPerThread, executors[i].tasks.size());
for (int j = 0; j < tasksSubmittedPerThread; j++) {
assertNotNull(executors[i].tasks.get(j));
assertEquals("cluster state update task executed out of order", j, (int) executors[i].tasks.get(j));
}
}
}
public void testSingleBatchSubmission() throws InterruptedException {
Map<Integer, TestListener> tasks = new HashMap<>();
final int numOfTasks = randomInt(10);
final CountDownLatch latch = new CountDownLatch(numOfTasks);
for (int i = 0; i < numOfTasks; i++) {
while (null != tasks.put(randomInt(1024), new TestListener() {
@Override
public void processed(String source) {
latch.countDown();
}
@Override
public void onFailure(String source, Exception e) {
fail(ExceptionsHelper.detailedMessage(e));
}
})) ;
}
TestExecutor<Integer> executor = taskList -> {
assertThat(taskList.size(), equalTo(tasks.size()));
assertThat(taskList.stream().collect(Collectors.toSet()), equalTo(tasks.keySet()));
};
submitTasks("test", tasks, ClusterStateTaskConfig.build(Priority.LANGUID), executor);
latch.await();
}
public void testDuplicateSubmission() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(2);
try (BlockingTask blockingTask = new BlockingTask(Priority.IMMEDIATE)) {
submitTask("blocking", blockingTask);
TestExecutor<SimpleTask> executor = tasks -> {};
SimpleTask task = new SimpleTask(1);
TestListener listener = new TestListener() {
@Override
public void processed(String source) {
latch.countDown();
}
@Override
public void onFailure(String source, Exception e) {
fail(ExceptionsHelper.detailedMessage(e));
}
};
submitTask("first time", task, ClusterStateTaskConfig.build(Priority.NORMAL), executor,
listener);
final IllegalStateException e =
expectThrows(
IllegalStateException.class,
() -> submitTask(
"second time",
task,
ClusterStateTaskConfig.build(Priority.NORMAL),
executor, listener));
assertThat(e, hasToString(containsString("task [1] with source [second time] is already queued")));
submitTask("third time a charm", new SimpleTask(1),
ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener);
assertThat(latch.getCount(), equalTo(2L));
}
latch.await();
}
private static class SimpleTask {
private final int id;
private SimpleTask(int id) {
this.id = id;
}
@Override
public int hashCode() {
return super.hashCode();
}
@Override
public boolean equals(Object obj) {
return super.equals(obj);
}
@Override
public String toString() {
return Integer.toString(id);
}
}
}

View File

@ -0,0 +1,365 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.cluster.service;
import org.elasticsearch.cluster.ClusterStateTaskConfig;
import org.elasticsearch.cluster.metadata.ProcessClusterEventTimeoutException;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadFactory;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.core.Is.is;
public class TaskExecutorTests extends ESTestCase {
protected static ThreadPool threadPool;
protected PrioritizedEsThreadPoolExecutor threadExecutor;
@BeforeClass
public static void createThreadPool() {
threadPool = new TestThreadPool(getTestClass().getName());
}
@AfterClass
public static void stopThreadPool() {
if (threadPool != null) {
threadPool.shutdownNow();
threadPool = null;
}
}
@Before
public void setUpExecutor() {
threadExecutor = EsExecutors.newSinglePrioritizing("test_thread",
daemonThreadFactory(Settings.EMPTY, "test_thread"), threadPool.getThreadContext(), threadPool.scheduler());
}
@After
public void shutDownThreadExecutor() {
ThreadPool.terminate(threadExecutor, 10, TimeUnit.SECONDS);
}
protected interface TestListener {
void onFailure(String source, Exception e);
default void processed(String source) {
// do nothing by default
}
}
protected interface TestExecutor<T> {
void execute(List<T> tasks);
default String describeTasks(List<T> tasks) {
return tasks.stream().map(T::toString).reduce((s1,s2) -> {
if (s1.isEmpty()) {
return s2;
} else if (s2.isEmpty()) {
return s1;
} else {
return s1 + ", " + s2;
}
}).orElse("");
}
}
/**
* Task class that works for single tasks as well as batching (see {@link TaskBatcherTests})
*/
protected abstract static class TestTask implements TestExecutor<TestTask>, TestListener, ClusterStateTaskConfig {
@Override
public void execute(List<TestTask> tasks) {
tasks.forEach(TestTask::run);
}
@Nullable
@Override
public TimeValue timeout() {
return null;
}
@Override
public Priority priority() {
return Priority.NORMAL;
}
public abstract void run();
}
class UpdateTask extends SourcePrioritizedRunnable {
final TestTask testTask;
UpdateTask(String source, TestTask testTask) {
super(testTask.priority(), source);
this.testTask = testTask;
}
@Override
public void run() {
logger.trace("will process {}", source);
testTask.execute(Collections.singletonList(testTask));
testTask.processed(source);
}
}
// can be overridden by TaskBatcherTests
protected void submitTask(String source, TestTask testTask) {
SourcePrioritizedRunnable task = new UpdateTask(source, testTask);
TimeValue timeout = testTask.timeout();
if (timeout != null) {
threadExecutor.execute(task, timeout, () -> threadPool.generic().execute(() -> {
logger.debug("task [{}] timed out after [{}]", task, timeout);
testTask.onFailure(source, new ProcessClusterEventTimeoutException(timeout, source));
}));
} else {
threadExecutor.execute(task);
}
}
public void testTimedOutTaskCleanedUp() throws Exception {
final CountDownLatch block = new CountDownLatch(1);
final CountDownLatch blockCompleted = new CountDownLatch(1);
TestTask blockTask = new TestTask() {
@Override
public void run() {
try {
block.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
blockCompleted.countDown();
}
@Override
public void onFailure(String source, Exception e) {
throw new RuntimeException(e);
}
};
submitTask("block-task", blockTask);
final CountDownLatch block2 = new CountDownLatch(1);
TestTask unblockTask = new TestTask() {
@Override
public void run() {
block2.countDown();
}
@Override
public void onFailure(String source, Exception e) {
block2.countDown();
}
@Override
public TimeValue timeout() {
return TimeValue.ZERO;
}
};
submitTask("unblock-task", unblockTask);
block.countDown();
block2.await();
blockCompleted.await();
}
public void testTimeoutTask() throws Exception {
final CountDownLatch block = new CountDownLatch(1);
TestTask test1 = new TestTask() {
@Override
public void run() {
try {
block.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
@Override
public void onFailure(String source, Exception e) {
throw new RuntimeException(e);
}
};
submitTask("block-task", test1);
final CountDownLatch timedOut = new CountDownLatch(1);
final AtomicBoolean executeCalled = new AtomicBoolean();
TestTask test2 = new TestTask() {
@Override
public TimeValue timeout() {
return TimeValue.timeValueMillis(2);
}
@Override
public void run() {
executeCalled.set(true);
}
@Override
public void onFailure(String source, Exception e) {
timedOut.countDown();
}
};
submitTask("block-task", test2);
timedOut.await();
block.countDown();
final CountDownLatch allProcessed = new CountDownLatch(1);
TestTask test3 = new TestTask() {
@Override
public void run() {
allProcessed.countDown();
}
@Override
public void onFailure(String source, Exception e) {
throw new RuntimeException(e);
}
};
submitTask("block-task", test3);
allProcessed.await(); // executed another task to double check that execute on the timed out update task is not called...
assertThat(executeCalled.get(), equalTo(false));
}
static class TaskExecutor implements TestExecutor<Integer> {
List<Integer> tasks = new ArrayList<>();
@Override
public void execute(List<Integer> tasks) {
this.tasks.addAll(tasks);
}
}
/**
* Note, this test can only work as long as we have a single thread executor executing the state update tasks!
*/
public void testPrioritizedTasks() throws Exception {
BlockingTask block = new BlockingTask(Priority.IMMEDIATE);
submitTask("test", block);
int taskCount = randomIntBetween(5, 20);
// will hold all the tasks in the order in which they were executed
List<PrioritizedTask> tasks = new ArrayList<>(taskCount);
CountDownLatch latch = new CountDownLatch(taskCount);
for (int i = 0; i < taskCount; i++) {
Priority priority = randomFrom(Priority.values());
PrioritizedTask task = new PrioritizedTask(priority, latch, tasks);
submitTask("test", task);
}
block.close();
latch.await();
Priority prevPriority = null;
for (PrioritizedTask task : tasks) {
if (prevPriority == null) {
prevPriority = task.priority();
} else {
assertThat(task.priority().sameOrAfter(prevPriority), is(true));
}
}
}
protected static class BlockingTask extends TestTask implements Releasable {
private final CountDownLatch latch = new CountDownLatch(1);
private final Priority priority;
BlockingTask(Priority priority) {
super();
this.priority = priority;
}
@Override
public void run() {
try {
latch.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
@Override
public void onFailure(String source, Exception e) {
}
@Override
public Priority priority() {
return priority;
}
public void close() {
latch.countDown();
}
}
protected static class PrioritizedTask extends TestTask {
private final CountDownLatch latch;
private final List<PrioritizedTask> tasks;
private final Priority priority;
private PrioritizedTask(Priority priority, CountDownLatch latch, List<PrioritizedTask> tasks) {
super();
this.latch = latch;
this.tasks = tasks;
this.priority = priority;
}
@Override
public void run() {
tasks.add(this);
latch.countDown();
}
@Override
public Priority priority() {
return priority;
}
@Override
public void onFailure(String source, Exception e) {
latch.countDown();
}
}
}

View File

@ -65,7 +65,7 @@ public class PrioritizedExecutorsTests extends ESTestCase {
}
public void testSubmitPrioritizedExecutorWithRunnables() throws Exception {
ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder);
ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder, null);
List<Integer> results = new ArrayList<>(8);
CountDownLatch awaitingLatch = new CountDownLatch(1);
CountDownLatch finishedLatch = new CountDownLatch(8);
@ -94,7 +94,7 @@ public class PrioritizedExecutorsTests extends ESTestCase {
}
public void testExecutePrioritizedExecutorWithRunnables() throws Exception {
ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder);
ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder, null);
List<Integer> results = new ArrayList<>(8);
CountDownLatch awaitingLatch = new CountDownLatch(1);
CountDownLatch finishedLatch = new CountDownLatch(8);
@ -123,7 +123,7 @@ public class PrioritizedExecutorsTests extends ESTestCase {
}
public void testSubmitPrioritizedExecutorWithCallables() throws Exception {
ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder);
ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder, null);
List<Integer> results = new ArrayList<>(8);
CountDownLatch awaitingLatch = new CountDownLatch(1);
CountDownLatch finishedLatch = new CountDownLatch(8);
@ -152,7 +152,7 @@ public class PrioritizedExecutorsTests extends ESTestCase {
}
public void testSubmitPrioritizedExecutorWithMixed() throws Exception {
ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder);
ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder, null);
List<Integer> results = new ArrayList<>(8);
CountDownLatch awaitingLatch = new CountDownLatch(1);
CountDownLatch finishedLatch = new CountDownLatch(8);
@ -182,7 +182,7 @@ public class PrioritizedExecutorsTests extends ESTestCase {
public void testTimeout() throws Exception {
ScheduledExecutorService timer = Executors.newSingleThreadScheduledExecutor(EsExecutors.daemonThreadFactory(getTestName()));
PrioritizedEsThreadPoolExecutor executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder);
PrioritizedEsThreadPoolExecutor executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder, timer);
final CountDownLatch invoked = new CountDownLatch(1);
final CountDownLatch block = new CountDownLatch(1);
executor.execute(new Runnable() {
@ -219,7 +219,7 @@ public class PrioritizedExecutorsTests extends ESTestCase {
public String toString() {
return "the waiting";
}
}, timer, TimeValue.timeValueMillis(100) /* enough timeout to catch them in the pending list... */, new Runnable() {
}, TimeValue.timeValueMillis(100) /* enough timeout to catch them in the pending list... */, new Runnable() {
@Override
public void run() {
timedOut.countDown();
@ -245,14 +245,14 @@ public class PrioritizedExecutorsTests extends ESTestCase {
ThreadPool threadPool = new TestThreadPool("test");
final ScheduledThreadPoolExecutor timer = (ScheduledThreadPoolExecutor) threadPool.scheduler();
final AtomicBoolean timeoutCalled = new AtomicBoolean();
PrioritizedEsThreadPoolExecutor executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder);
PrioritizedEsThreadPoolExecutor executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder, timer);
final CountDownLatch invoked = new CountDownLatch(1);
executor.execute(new Runnable() {
@Override
public void run() {
invoked.countDown();
}
}, timer, TimeValue.timeValueHours(1), new Runnable() {
}, TimeValue.timeValueHours(1), new Runnable() {
@Override
public void run() {
// We should never get here