Task queue unblock (#12099)

* concurrency: introduce GuardedBy to TaskQueue

* perf: Introduce TaskQueueScaleTest to test performance of TaskQueue with large task counts

This introduces a test case to confirm how long it takes to launch and manage (i.e., shut down)
a large number of tasks in the TaskQueue.

h/t to @gianm for the main implementation.

* perf: improve scalability of TaskQueue with large task counts

* linter fixes, expand test coverage

* PR feedback suggestion; swap to a different linter

* swap to use SuppressWarnings

* Fix TaskQueueScaleTest.

Co-authored-by: Gian Merlino <gian@imply.io>
Jason Koch 2022-05-14 16:44:29 -07:00 committed by GitHub
parent 7ab2170802
commit bb1a6def9d
2 changed files with 689 additions and 76 deletions
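
At its core, the change replaces the Condition attached to the giant lock with a small
BlockingQueue used purely as a wake-up channel, so management can be requested from any
thread without blocking and without holding the lock. A minimal standalone sketch of the
pattern follows; the names are illustrative, not the actual Druid API:

import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;

public class WakeupChannel
{
  // Bounded queue used only as a notification channel; the queued objects carry no meaning.
  private final BlockingQueue<Object> wakeups = new ArrayBlockingQueue<>(8);

  // Non-blocking, so it is safe to call from any thread, even one holding other locks.
  // If the queue is full, a wake-up is already pending and the dropped offer is harmless.
  public void request()
  {
    wakeups.offer(new Object());
  }

  // Called only by the single worker thread: waits until a request arrives or the
  // timeout elapses, then drains duplicates so one wake-up services all pending requests.
  public void await(long timeout, TimeUnit unit) throws InterruptedException
  {
    wakeups.poll(timeout, unit);
    wakeups.clear();
  }
}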

TaskQueue.java

@@ -28,6 +28,8 @@ import com.google.common.util.concurrent.FutureCallback;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import com.google.errorprone.annotations.concurrent.GuardedBy;
+import org.apache.druid.annotations.SuppressFBWarnings;
 import org.apache.druid.indexer.TaskLocation;
 import org.apache.druid.indexer.TaskStatus;
 import org.apache.druid.indexing.common.Counters;
@@ -53,9 +55,12 @@ import org.apache.druid.utils.CollectionUtils;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
@@ -63,7 +68,6 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
-import java.util.concurrent.locks.Condition;
 import java.util.concurrent.locks.ReentrantLock;
 import java.util.stream.Collectors;
@@ -79,8 +83,11 @@ import java.util.stream.Collectors;
 public class TaskQueue
 {
   private final long MANAGEMENT_WAIT_TIMEOUT_NANOS = TimeUnit.SECONDS.toNanos(60);
+  private final long MIN_WAIT_TIME_MS = 100;
 
+  @GuardedBy("giant")
   private final List<Task> tasks = new ArrayList<>();
+  @GuardedBy("giant")
   private final Map<String, ListenableFuture<TaskStatus>> taskFutures = new HashMap<>();
 
   private final TaskLockConfig lockConfig;
@@ -93,7 +100,8 @@ public class TaskQueue
   private final ServiceEmitter emitter;
 
   private final ReentrantLock giant = new ReentrantLock(true);
-  private final Condition managementMayBeNecessary = giant.newCondition();
+  @SuppressWarnings("MismatchedQueryAndUpdateOfCollection")
+  private final BlockingQueue<Object> managementMayBeNecessary = new ArrayBlockingQueue<>(8);
   private final ExecutorService managerExec = Executors.newSingleThreadExecutor(
       new ThreadFactoryBuilder()
           .setDaemon(false)
@@ -111,7 +119,9 @@ public class TaskQueue
   private final ConcurrentHashMap<String, AtomicLong> totalSuccessfulTaskCount = new ConcurrentHashMap<>();
   private final ConcurrentHashMap<String, AtomicLong> totalFailedTaskCount = new ConcurrentHashMap<>();
 
+  @GuardedBy("totalSuccessfulTaskCount")
   private Map<String, Long> prevTotalSuccessfulTaskCount = new HashMap<>();
+  @GuardedBy("totalFailedTaskCount")
   private Map<String, Long> prevTotalFailedTaskCount = new HashMap<>();
 
   public TaskQueue(
@@ -207,7 +217,7 @@ public class TaskQueue
             }
           }
       );
-      managementMayBeNecessary.signalAll();
+      requestManagement();
    }
    finally {
      giant.unlock();
@@ -228,7 +238,7 @@ public class TaskQueue
       active = false;
       managerExec.shutdownNow();
       storageSyncExec.shutdownNow();
-      managementMayBeNecessary.signalAll();
+      requestManagement();
     }
     finally {
       giant.unlock();
@@ -240,6 +250,52 @@ public class TaskQueue
     return active;
   }
 
+  /**
+   * Request management from the management thread. Non-blocking.
+   *
+   * Other callers (such as notifyStatus) should trigger activity on the
+   * TaskQueue thread by requesting management here.
+   */
+  void requestManagement()
+  {
+    // use a BlockingQueue since the offer/poll/wait behaviour is simple
+    // and very easy to reason about.
+    // the request has to be an offer (non-blocking), since someone might request
+    // while already holding the giant lock.
+    // we do not care if the item fits into the queue:
+    // if the queue is already full, a request has been triggered anyway.
+    managementMayBeNecessary.offer(this);
+  }
+
+  /**
+   * Await an event that requires management.
+   *
+   * This should only be called from the management thread to wait for activity.
+   *
+   * @param nanos maximum time to wait, in nanoseconds
+   * @throws InterruptedException if interrupted while waiting
+   */
+  @SuppressFBWarnings(value = "RV_RETURN_VALUE_IGNORED", justification = "using queue as notification mechanism, result has no value")
+  void awaitManagementNanos(long nanos) throws InterruptedException
+  {
+    // mitigate a busy loop; it can get pretty busy when there are a lot of start/stops
+    try {
+      Thread.sleep(MIN_WAIT_TIME_MS);
+    }
+    catch (InterruptedException e) {
+      throw new RuntimeException(e);
+    }
+
+    // wait for an item; if an item arrives (or is already available), complete immediately.
+    // (it does not actually matter what the item is)
+    managementMayBeNecessary.poll(nanos - (TimeUnit.MILLISECONDS.toNanos(MIN_WAIT_TIME_MS)), TimeUnit.NANOSECONDS);
+
+    // there may have been multiple requests; clear them all
+    managementMayBeNecessary.clear();
+  }
+
   /**
    * Main task runner management loop. Meant to run forever, or, at least until we're stopped.
    */
@@ -252,31 +308,54 @@ public class TaskQueue
       taskRunner.restore();
 
       while (active) {
-        giant.lock();
-
-        try {
-          manageInternal();
-
-          // awaitNanos because management may become necessary without this condition signalling,
-          // due to e.g. tasks becoming ready when other folks mess with the TaskLockbox.
-          managementMayBeNecessary.awaitNanos(MANAGEMENT_WAIT_TIMEOUT_NANOS);
-        }
-        finally {
-          giant.unlock();
-        }
+        manageInternal();
+
+        // awaitNanos because management may become necessary without this condition signalling,
+        // due to e.g. tasks becoming ready when other folks mess with the TaskLockbox.
+        awaitManagementNanos(MANAGEMENT_WAIT_TIMEOUT_NANOS);
       }
     }
 
   @VisibleForTesting
   void manageInternal()
+  {
+    Set<String> knownTaskIds = new HashSet<>();
+    Map<String, ListenableFuture<TaskStatus>> runnerTaskFutures = new HashMap<>();
+
+    giant.lock();
+
+    try {
+      manageInternalCritical(knownTaskIds, runnerTaskFutures);
+    }
+    finally {
+      giant.unlock();
+    }
+
+    manageInternalPostCritical(knownTaskIds, runnerTaskFutures);
+  }
+
+  /**
+   * Management loop critical section tasks.
+   *
+   * @param knownTaskIds will be modified - filled with known task IDs
+   * @param runnerTaskFutures will be modified - filled with futures related to getting the running tasks
+   */
+  @GuardedBy("giant")
+  private void manageInternalCritical(
+      final Set<String> knownTaskIds,
+      final Map<String, ListenableFuture<TaskStatus>> runnerTaskFutures
+  )
   {
     // Task futures available from the taskRunner
-    final Map<String, ListenableFuture<TaskStatus>> runnerTaskFutures = new HashMap<>();
     for (final TaskRunnerWorkItem workItem : taskRunner.getKnownTasks()) {
       runnerTaskFutures.put(workItem.getTaskId(), workItem.getResult());
     }
+
     // Attain futures for all active tasks (assuming they are ready to run).
     // Copy tasks list, as notifyStatus may modify it.
     for (final Task task : ImmutableList.copyOf(tasks)) {
+      knownTaskIds.add(task.getId());
+
       if (!taskFutures.containsKey(task.getId())) {
         final ListenableFuture<TaskStatus> runnerTaskFuture;
         if (runnerTaskFutures.containsKey(task.getId())) {
@@ -317,11 +396,15 @@ public class TaskQueue
         taskRunner.run(task);
       }
     }
+  }
+
+  @VisibleForTesting
+  private void manageInternalPostCritical(
+      final Set<String> knownTaskIds,
+      final Map<String, ListenableFuture<TaskStatus>> runnerTaskFutures
+  )
+  {
     // Kill tasks that shouldn't be running
-    final Set<String> knownTaskIds = tasks
-        .stream()
-        .map(Task::getId)
-        .collect(Collectors.toSet());
     final Set<String> tasksToKill = Sets.difference(runnerTaskFutures.keySet(), knownTaskIds);
     if (!tasksToKill.isEmpty()) {
       log.info("Asking taskRunner to clean up %,d tasks.", tasksToKill.size());
@@ -387,7 +470,7 @@ public class TaskQueue
       // insert the task into our queue. So don't catch it.
       taskStorage.insert(task, TaskStatus.running(task.getId()));
       addTaskInternal(task);
-      managementMayBeNecessary.signalAll();
+      requestManagement();
       return true;
     }
     finally {
@@ -396,6 +479,7 @@ public class TaskQueue
   }
 
   // Should always be called after taking giantLock
+  @GuardedBy("giant")
   private void addTaskInternal(final Task task)
   {
     tasks.add(task);
@@ -403,6 +487,7 @@ public class TaskQueue
   }
 
   // Should always be called after taking giantLock
+  @GuardedBy("giant")
   private void removeTaskInternal(final Task task)
   {
     taskLockbox.remove(task);
@@ -473,11 +558,6 @@ public class TaskQueue
    */
   private void notifyStatus(final Task task, final TaskStatus taskStatus, String reasonFormat, Object... args)
   {
-    giant.lock();
-    TaskLocation taskLocation = TaskLocation.unknown();
-
-    try {
-
     Preconditions.checkNotNull(task, "task");
     Preconditions.checkNotNull(taskStatus, "status");
     Preconditions.checkState(active, "Queue is not active!");
@@ -487,7 +567,9 @@ public class TaskQueue
         task.getId(),
         taskStatus.getId()
     );
+
     // Inform taskRunner that this task can be shut down
+    TaskLocation taskLocation = TaskLocation.unknown();
     try {
       taskLocation = taskRunner.getTaskLocation(task.getId());
       taskRunner.shutdown(task.getId(), reasonFormat, args);
@@ -495,8 +577,14 @@ public class TaskQueue
     catch (Exception e) {
       log.warn(e, "TaskRunner failed to cleanup task after completion: %s", task.getId());
     }
-    // Remove from running tasks
+
     int removed = 0;
+
+    ///////// critical section
+    giant.lock();
+    try {
+      // Remove from running tasks
     for (int i = tasks.size() - 1; i >= 0; i--) {
       if (tasks.get(i).getId().equals(task.getId())) {
         removed++;
@@ -504,13 +592,20 @@ public class TaskQueue
         break;
       }
     }
-    if (removed == 0) {
-      log.warn("Unknown task completed: %s", task.getId());
-    } else if (removed > 1) {
-      log.makeAlert("Removed multiple copies of task").addData("count", removed).addData("task", task.getId()).emit();
-    }
 
     // Remove from futures list
     taskFutures.remove(task.getId());
+    }
+    finally {
+      giant.unlock();
+    }
+    ///////// end critical
+
+    if (removed == 0) {
+      log.warn("Unknown task completed: %s", task.getId());
+    }
 
     if (removed > 0) {
       // If we thought this task should be running, save status to DB
       try {
@@ -520,7 +615,7 @@ public class TaskQueue
       } else {
         taskStorage.setStatus(taskStatus.withLocation(taskLocation));
         log.info("Task done: %s", task);
-        managementMayBeNecessary.signalAll();
+        requestManagement();
       }
     }
     catch (Exception e) {
@@ -531,10 +626,6 @@ public class TaskQueue
        }
      }
    }
-    finally {
-      giant.unlock();
-    }
-  }
 
   /**
    * Attach success and failure handlers to a task status future, such that when it completes, we perform the
@@ -655,7 +746,7 @@ public class TaskQueue
           addedTasks.size(),
           removedTasks.size()
       );
-      managementMayBeNecessary.signalAll();
+      requestManagement();
     } else {
       log.info("Not active. Skipping storage sync.");
     }
@@ -688,22 +779,37 @@ public class TaskQueue
   public Map<String, Long> getSuccessfulTaskCount()
   {
     Map<String, Long> total = CollectionUtils.mapValues(totalSuccessfulTaskCount, AtomicLong::get);
+    synchronized (totalSuccessfulTaskCount) {
       Map<String, Long> delta = getDeltaValues(total, prevTotalSuccessfulTaskCount);
       prevTotalSuccessfulTaskCount = total;
       return delta;
     }
+  }
 
   public Map<String, Long> getFailedTaskCount()
   {
     Map<String, Long> total = CollectionUtils.mapValues(totalFailedTaskCount, AtomicLong::get);
+    synchronized (totalFailedTaskCount) {
       Map<String, Long> delta = getDeltaValues(total, prevTotalFailedTaskCount);
       prevTotalFailedTaskCount = total;
       return delta;
     }
+  }
+
+  Map<String, String> getCurrentTaskDatasources()
+  {
+    giant.lock();
+    try {
+      return tasks.stream().collect(Collectors.toMap(Task::getId, Task::getDataSource));
+    }
+    finally {
+      giant.unlock();
+    }
+  }
 
   public Map<String, Long> getRunningTaskCount()
   {
-    Map<String, String> taskDatasources = tasks.stream().collect(Collectors.toMap(Task::getId, Task::getDataSource));
+    Map<String, String> taskDatasources = getCurrentTaskDatasources();
     return taskRunner.getRunningTasks()
                      .stream()
                      .collect(Collectors.toMap(
@@ -715,7 +821,7 @@ public class TaskQueue
   public Map<String, Long> getPendingTaskCount()
   {
-    Map<String, String> taskDatasources = tasks.stream().collect(Collectors.toMap(Task::getId, Task::getDataSource));
+    Map<String, String> taskDatasources = getCurrentTaskDatasources();
     return taskRunner.getPendingTasks()
                      .stream()
                      .collect(Collectors.toMap(
@@ -731,13 +837,26 @@ public class TaskQueue
                      .stream()
                      .map(TaskRunnerWorkItem::getTaskId)
                      .collect(Collectors.toSet());
+    giant.lock();
+    try {
       return tasks.stream().filter(task -> !runnerKnownTaskIds.contains(task.getId()))
                   .collect(Collectors.toMap(Task::getDataSource, task -> 1L, Long::sum));
     }
+    finally {
+      giant.unlock();
+    }
+  }
 
   @VisibleForTesting
   List<Task> getTasks()
   {
-    return tasks;
+    giant.lock();
+    try {
+      return new ArrayList<Task>(tasks);
+    }
+    finally {
+      giant.unlock();
+    }
   }
 }
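
The @GuardedBy annotations added above are enforced statically (for example by
error-prone's GuardedBy checker) rather than at runtime: an access to the annotated
field made without holding the named lock is flagged at compile time. A minimal sketch
of the convention, with illustrative names rather than the TaskQueue fields:

import com.google.errorprone.annotations.concurrent.GuardedBy;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.locks.ReentrantLock;

public class GuardedByExample
{
  private final ReentrantLock lock = new ReentrantLock(true);

  // The checker can verify that every read/write of this field happens while `lock` is held.
  @GuardedBy("lock")
  private final List<String> items = new ArrayList<>();

  public void add(String item)
  {
    lock.lock();
    try {
      items.add(item); // compliant: lock is held
    }
    finally {
      lock.unlock();
    }
  }
}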

TaskQueueScaleTest.java (new file)

@@ -0,0 +1,494 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.indexing.overlord;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import org.apache.druid.indexer.RunnerTaskState;
import org.apache.druid.indexer.TaskLocation;
import org.apache.druid.indexer.TaskStatus;
import org.apache.druid.indexing.common.actions.TaskAction;
import org.apache.druid.indexing.common.actions.TaskActionClient;
import org.apache.druid.indexing.common.actions.TaskActionClientFactory;
import org.apache.druid.indexing.common.config.TaskStorageConfig;
import org.apache.druid.indexing.common.task.NoopTask;
import org.apache.druid.indexing.common.task.Task;
import org.apache.druid.indexing.overlord.autoscaling.ScalingStats;
import org.apache.druid.indexing.overlord.config.DefaultTaskConfig;
import org.apache.druid.indexing.overlord.config.TaskLockConfig;
import org.apache.druid.indexing.overlord.config.TaskQueueConfig;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.concurrent.ScheduledExecutors;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.java.util.emitter.EmittingLogger;
import org.apache.druid.metadata.IndexerSQLMetadataStorageCoordinator;
import org.apache.druid.metadata.TaskLookup;
import org.apache.druid.metadata.TestDerbyConnector;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.server.metrics.NoopServiceEmitter;
import org.joda.time.Duration;
import org.joda.time.Period;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
/**
* Tests that {@link TaskQueue} is able to handle large numbers of concurrently-running tasks.
*/
public class TaskQueueScaleTest
{
private static final String DATASOURCE = "ds";
private final int numTasks = 1000;
@Rule
public final TestDerbyConnector.DerbyConnectorRule derbyConnectorRule = new TestDerbyConnector.DerbyConnectorRule();
private TaskQueue taskQueue;
private TaskStorage taskStorage;
private TestTaskRunner taskRunner;
private Closer closer;
@Before
public void setUp()
{
EmittingLogger.registerEmitter(new NoopServiceEmitter());
closer = Closer.create();
// Be as realistic as possible; use actual classes for storage rather than mocks.
taskStorage = new HeapMemoryTaskStorage(new TaskStorageConfig(Period.hours(1)));
taskRunner = new TestTaskRunner();
closer.register(taskRunner::stop);
final ObjectMapper jsonMapper = TestHelper.makeJsonMapper();
final IndexerSQLMetadataStorageCoordinator storageCoordinator = new IndexerSQLMetadataStorageCoordinator(
jsonMapper,
derbyConnectorRule.metadataTablesConfigSupplier().get(),
derbyConnectorRule.getConnector()
);
final TaskActionClientFactory unsupportedTaskActionFactory =
task -> new TaskActionClient()
{
@Override
public <RetType> RetType submit(TaskAction<RetType> taskAction)
{
throw new UnsupportedOperationException();
}
};
taskQueue = new TaskQueue(
new TaskLockConfig(),
new TaskQueueConfig(null, Period.millis(1), null, null),
new DefaultTaskConfig(),
taskStorage,
taskRunner,
unsupportedTaskActionFactory, // Not used for anything serious
new TaskLockbox(taskStorage, storageCoordinator),
new NoopServiceEmitter()
);
taskQueue.start();
closer.register(taskQueue::stop);
}
@After
public void tearDown() throws Exception
{
closer.close();
}
@Test(timeout = 60_000L) // more than enough time if the task queue is efficient
public void doMassLaunchAndExit() throws Exception
{
Assert.assertEquals("no tasks should be running", 0, taskRunner.getKnownTasks().size());
Assert.assertEquals("no tasks should be known", 0, taskQueue.getTasks().size());
Assert.assertEquals("no tasks should be running", 0, taskQueue.getRunningTaskCount().size());
// Add all tasks.
for (int i = 0; i < numTasks; i++) {
final TestTask testTask = new TestTask(i, 2000L /* runtime millis */);
taskQueue.add(testTask);
}
// in theory we can get a race here, since we fetch the counts at separate times
Assert.assertEquals("all tasks should be known", numTasks, taskQueue.getTasks().size());
long runningTasks = taskQueue.getRunningTaskCount().values().stream().mapToLong(Long::longValue).sum();
long pendingTasks = taskQueue.getPendingTaskCount().values().stream().mapToLong(Long::longValue).sum();
long waitingTasks = taskQueue.getWaitingTaskCount().values().stream().mapToLong(Long::longValue).sum();
Assert.assertEquals("all tasks should be known", numTasks, (runningTasks + pendingTasks + waitingTasks));
// Wait for all tasks to finish.
final TaskLookup.CompleteTaskLookup completeTaskLookup =
TaskLookup.CompleteTaskLookup.of(numTasks, Duration.standardHours(1));
while (taskStorage.getTaskInfos(completeTaskLookup, DATASOURCE).size() < numTasks) {
Thread.sleep(100);
}
Thread.sleep(100);
Assert.assertEquals("no tasks should be active", 0, taskStorage.getActiveTasks().size());
runningTasks = taskQueue.getRunningTaskCount().values().stream().mapToLong(Long::longValue).sum();
pendingTasks = taskQueue.getPendingTaskCount().values().stream().mapToLong(Long::longValue).sum();
waitingTasks = taskQueue.getWaitingTaskCount().values().stream().mapToLong(Long::longValue).sum();
Assert.assertEquals("no tasks should be running", 0, runningTasks);
Assert.assertEquals("no tasks should be pending", 0, pendingTasks);
Assert.assertEquals("no tasks should be waiting", 0, waitingTasks);
}
@Test(timeout = 60_000L) // more than enough time if the task queue is efficient
public void doMassLaunchAndShutdown() throws Exception
{
Assert.assertEquals("no tasks should be running", 0, taskRunner.getKnownTasks().size());
// Add all tasks.
final List<String> taskIds = new ArrayList<>();
for (int i = 0; i < numTasks; i++) {
final TestTask testTask = new TestTask(
i,
Duration.standardHours(1).getMillis() /* very long runtime millis, so we can do a shutdown */
);
taskQueue.add(testTask);
taskIds.add(testTask.getId());
}
// wait for all tasks to progress to running state
while (taskStorage.getActiveTasks().size() < numTasks) {
Thread.sleep(100);
}
Assert.assertEquals("all tasks should be running", numTasks, taskStorage.getActiveTasks().size());
// Shut down all tasks.
for (final String taskId : taskIds) {
taskQueue.shutdown(taskId, "test shutdown");
}
// Wait for all tasks to finish.
while (!taskStorage.getActiveTasks().isEmpty()) {
Thread.sleep(100);
}
Assert.assertEquals("no tasks should be running", 0, taskStorage.getActiveTasks().size());
int completed = taskStorage.getTaskInfos(
TaskLookup.CompleteTaskLookup.of(numTasks, Duration.standardHours(1)),
DATASOURCE
).size();
Assert.assertEquals("all tasks should have completed", numTasks, completed);
}
private static class TestTask extends NoopTask
{
private final int number;
private final long runtime;
public TestTask(int number, long runtime)
{
super(null, null, DATASOURCE, 0, 0, null, null, Collections.emptyMap());
this.number = number;
this.runtime = runtime;
}
public int getNumber()
{
return number;
}
public long getRuntimeMillis()
{
return runtime;
}
}
private static class TestTaskRunner implements TaskRunner
{
private static final Logger log = new Logger(TestTaskRunner.class);
private static final Duration T_PENDING_TO_RUNNING = Duration.standardSeconds(2);
private static final Duration T_SHUTDOWN_ACK = Duration.millis(8);
private static final Duration T_SHUTDOWN_COMPLETE = Duration.standardSeconds(2);
@GuardedBy("knownTasks")
private final Map<String, TestTaskRunnerWorkItem> knownTasks = new HashMap<>();
private final ScheduledExecutorService exec = ScheduledExecutors.fixed(8, "TaskQueueScaleTest-%s");
@Override
public void start()
{
throw new UnsupportedOperationException();
}
@Override
public ListenableFuture<TaskStatus> run(Task task)
{
// Production task runners generally do not take a long time to execute "run", but may take a long time to
// go from "pending" to "running".
synchronized (knownTasks) {
final TestTaskRunnerWorkItem item = knownTasks.computeIfAbsent(task.getId(), TestTaskRunnerWorkItem::new);
exec.schedule(
() -> {
try {
synchronized (knownTasks) {
final TestTaskRunnerWorkItem item2 = knownTasks.get(task.getId());
if (item2.getState() == RunnerTaskState.PENDING) {
knownTasks.put(task.getId(), item2.withState(RunnerTaskState.RUNNING));
}
}
exec.schedule(
() -> {
try {
final TestTaskRunnerWorkItem item2;
synchronized (knownTasks) {
item2 = knownTasks.get(task.getId());
knownTasks.put(task.getId(), item2.withState(RunnerTaskState.NONE));
}
if (item2 != null) {
item2.setResult(TaskStatus.success(task.getId()));
}
}
catch (Throwable e) {
log.error(e, "Error in scheduled executor");
}
},
((TestTask) task).getRuntimeMillis(),
TimeUnit.MILLISECONDS
);
}
catch (Throwable e) {
log.error(e, "Error in scheduled executor");
}
},
T_PENDING_TO_RUNNING.getMillis(),
TimeUnit.MILLISECONDS
);
return item.getResult();
}
}
@Override
public void shutdown(String taskid, String reason)
{
// Production task runners take a long time to execute "shutdown" if the task is currently running.
synchronized (knownTasks) {
if (!knownTasks.containsKey(taskid)) {
return;
}
}
threadSleep(T_SHUTDOWN_ACK);
final TestTaskRunnerWorkItem existingTask;
synchronized (knownTasks) {
existingTask = knownTasks.get(taskid);
}
if (!existingTask.getResult().isDone()) {
exec.schedule(() -> {
existingTask.setResult(TaskStatus.failure(taskid, "stopped"));
synchronized (knownTasks) {
knownTasks.remove(taskid);
}
}, T_SHUTDOWN_COMPLETE.getMillis(), TimeUnit.MILLISECONDS);
}
}
static void threadSleep(Duration duration)
{
try {
Thread.sleep(duration.getMillis());
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
}
@Override
public void registerListener(TaskRunnerListener listener, Executor executor)
{
throw new UnsupportedOperationException();
}
@Override
public void unregisterListener(String listenerId)
{
throw new UnsupportedOperationException();
}
@Override
public List<Pair<Task, ListenableFuture<TaskStatus>>> restore()
{
// Do nothing, and return null. (TaskQueue doesn't use the return value.)
return null;
}
@Override
public void stop()
{
exec.shutdownNow();
}
@Override
public Collection<? extends TaskRunnerWorkItem> getRunningTasks()
{
synchronized (knownTasks) {
return knownTasks.values()
.stream()
.filter(item -> item.getState() == RunnerTaskState.RUNNING)
.collect(Collectors.toList());
}
}
@Override
public Collection<? extends TaskRunnerWorkItem> getPendingTasks()
{
synchronized (knownTasks) {
return knownTasks.values()
.stream()
.filter(item -> item.getState() == RunnerTaskState.PENDING)
.collect(Collectors.toList());
}
}
@Override
public Collection<? extends TaskRunnerWorkItem> getKnownTasks()
{
synchronized (knownTasks) {
return ImmutableList.copyOf(knownTasks.values());
}
}
@Override
public Optional<ScalingStats> getScalingStats()
{
throw new UnsupportedOperationException();
}
@Override
public Map<String, Long> getTotalTaskSlotCount()
{
throw new UnsupportedOperationException();
}
@Override
public Map<String, Long> getIdleTaskSlotCount()
{
throw new UnsupportedOperationException();
}
@Override
public Map<String, Long> getUsedTaskSlotCount()
{
throw new UnsupportedOperationException();
}
@Override
public Map<String, Long> getLazyTaskSlotCount()
{
throw new UnsupportedOperationException();
}
@Override
public Map<String, Long> getBlacklistedTaskSlotCount()
{
throw new UnsupportedOperationException();
}
}
private static class TestTaskRunnerWorkItem extends TaskRunnerWorkItem
{
private final RunnerTaskState state;
public TestTaskRunnerWorkItem(final String taskId)
{
this(taskId, SettableFuture.create(), RunnerTaskState.PENDING);
}
private TestTaskRunnerWorkItem(
final String taskId,
final ListenableFuture<TaskStatus> result,
final RunnerTaskState state
)
{
super(taskId, result);
this.state = state;
}
public RunnerTaskState getState()
{
return state;
}
@Override
public TaskLocation getLocation()
{
return TaskLocation.unknown();
}
@Nullable
@Override
public String getTaskType()
{
throw new UnsupportedOperationException();
}
@Override
public String getDataSource()
{
throw new UnsupportedOperationException();
}
public void setResult(final TaskStatus result)
{
((SettableFuture<TaskStatus>) getResult()).set(result);
// possibly a parallel shutdown request was issued during the
// shutdown time; ignore it
}
public TestTaskRunnerWorkItem withState(final RunnerTaskState newState)
{
return new TestTaskRunnerWorkItem(getTaskId(), getResult(), newState);
}
}
}
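
To run the new test on its own, a standard Maven Surefire invocation should work; the
indexing-service module path is an assumption based on the package name, not something
stated in this commit:

mvn -pl indexing-service test -Dtest=TaskQueueScaleTest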