Support task resource tracking in OpenSearch (#2639)

* Add Task id in Thread Context

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Add resource tracking update support for tasks

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* List tasks action support for task resource refresh

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Handle task unregistration case on same thread

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Add lazy initialisation for RunnableTaskExecutionListener

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Segregate resource tracking logic to a separate service.

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Check for running threads during task unregister

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Moved thread context logic to resource tracking service

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* preserve task id in thread context even after stash

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Add null check for resource tracking service

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Tracking service tests and minor refactoring

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Preserve task id fix with test

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Minor test changes and Task tracking call update

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Fix Auto Queue executor method's signature

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Make task runnable task listener factory implement consumer

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Use reflection for ThreadMXBean

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Formatting

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Replace RunnableTaskExecutionListenerFactory with AtomicReference

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Revert "Use reflection for ThreadMXBean"

This reverts commit cbcf3c525bf516fb7164f0221491a7b25c1f96ec.

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Suppress Warning related to ThreadMXBean

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Add separate method for task resource tracking supported check

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Enabled setting by default

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Add debug logs for stale context id

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Remove hardcoded task overhead in tests

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Bump stale task id in thread context log level to warn

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

* Improve assertions and logging

Signed-off-by: Tushar Kharbanda <tushar.kharbanda72@gmail.com>

Co-authored-by: Tushar Kharbanda <tkharban@amazon.com>
This commit is contained in:
Tushar Kharbanda 2022-04-21 19:21:44 +05:30 committed by GitHub
parent e9ad90b9f6
commit 6517eeca50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 1426 additions and 66 deletions

View File

@ -470,6 +470,9 @@ public class TasksIT extends OpenSearchIntegTestCase {
@Override
public void waitForTaskCompletion(Task task) {}
@Override
public void taskExecutionStarted(Task task, Boolean closeableInvoked) {}
});
}
// Need to run the task in a separate thread because node client's .execute() is blocked by our task listener
@ -651,6 +654,9 @@ public class TasksIT extends OpenSearchIntegTestCase {
waitForWaitingToStart.countDown();
}
@Override
public void taskExecutionStarted(Task task, Boolean closeableInvoked) {}
@Override
public void onTaskRegistered(Task task) {}

View File

@ -42,6 +42,7 @@ import org.opensearch.common.inject.Inject;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskInfo;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
@ -60,8 +61,15 @@ public class TransportListTasksAction extends TransportTasksAction<Task, ListTas
private static final TimeValue DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT = timeValueSeconds(30);
private final TaskResourceTrackingService taskResourceTrackingService;
@Inject
public TransportListTasksAction(ClusterService clusterService, TransportService transportService, ActionFilters actionFilters) {
public TransportListTasksAction(
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters,
TaskResourceTrackingService taskResourceTrackingService
) {
super(
ListTasksAction.NAME,
clusterService,
@ -72,6 +80,7 @@ public class TransportListTasksAction extends TransportTasksAction<Task, ListTas
TaskInfo::new,
ThreadPool.Names.MANAGEMENT
);
this.taskResourceTrackingService = taskResourceTrackingService;
}
@Override
@ -101,6 +110,8 @@ public class TransportListTasksAction extends TransportTasksAction<Task, ListTas
}
taskManager.waitForTaskCompletion(task, timeoutNanos);
});
} else {
operation = operation.andThen(taskResourceTrackingService::refreshResourceStats);
}
super.processTasks(request, operation);
}

View File

@ -49,6 +49,11 @@ public class SearchShardTask extends CancellableTask {
super(id, type, action, description, parentTaskId, headers);
}
@Override
public boolean supportsResourceTracking() {
return true;
}
@Override
public boolean shouldCancelChildrenOnCancellation() {
return false;

View File

@ -78,6 +78,11 @@ public class SearchTask extends CancellableTask {
return descriptionSupplier.get();
}
@Override
public boolean supportsResourceTracking() {
return true;
}
/**
* Attach a {@link SearchProgressListener} to this task.
*/

View File

@ -40,6 +40,7 @@ import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.action.ActionResponse;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskCancelledException;
import org.opensearch.tasks.TaskId;
@ -88,31 +89,39 @@ public abstract class TransportAction<Request extends ActionRequest, Response ex
*/
final Releasable unregisterChildNode = registerChildNode(request.getParentTask());
final Task task;
try {
task = taskManager.register("transport", actionName, request);
} catch (TaskCancelledException e) {
unregisterChildNode.close();
throw e;
}
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(response);
}
}
@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(e);
ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task);
try {
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(response);
}
}
}
});
@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(e);
}
}
});
} finally {
storedContext.close();
}
return task;
}
@ -129,25 +138,30 @@ public abstract class TransportAction<Request extends ActionRequest, Response ex
unregisterChildNode.close();
throw e;
}
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(task, response);
ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task);
try {
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(task, response);
}
}
}
@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(task, e);
@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(task, e);
}
}
}
});
});
} finally {
storedContext.close();
}
return task;
}

View File

@ -94,6 +94,7 @@ import org.opensearch.plugins.ClusterPlugin;
import org.opensearch.script.ScriptMetadata;
import org.opensearch.snapshots.SnapshotsInfoService;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.tasks.TaskResultsService;
import java.util.ArrayList;
@ -394,6 +395,7 @@ public class ClusterModule extends AbstractModule {
bind(NodeMappingRefreshAction.class).asEagerSingleton();
bind(MappingUpdatedAction.class).asEagerSingleton();
bind(TaskResultsService.class).asEagerSingleton();
bind(TaskResourceTrackingService.class).asEagerSingleton();
bind(AllocationDeciders.class).toInstance(allocationDeciders);
bind(ShardsAllocator.class).toInstance(shardsAllocator);
}

View File

@ -40,6 +40,7 @@ import org.opensearch.index.IndexingPressure;
import org.opensearch.index.ShardIndexingPressureMemoryManager;
import org.opensearch.index.ShardIndexingPressureSettings;
import org.opensearch.index.ShardIndexingPressureStore;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.watcher.ResourceWatcherService;
import org.opensearch.action.admin.cluster.configuration.TransportAddVotingConfigExclusionsAction;
import org.opensearch.action.admin.indices.close.TransportCloseIndexAction;
@ -568,7 +569,8 @@ public final class ClusterSettings extends AbstractScopedSettings {
ShardIndexingPressureMemoryManager.THROUGHPUT_DEGRADATION_LIMITS,
ShardIndexingPressureMemoryManager.SUCCESSFUL_REQUEST_ELAPSED_TIMEOUT,
ShardIndexingPressureMemoryManager.MAX_OUTSTANDING_REQUESTS,
IndexingPressure.MAX_INDEXING_BYTES
IndexingPressure.MAX_INDEXING_BYTES,
TaskResourceTrackingService.TASK_RESOURCE_TRACKING_ENABLED
)
)
);

View File

@ -40,6 +40,8 @@ import org.opensearch.common.settings.Setting.Property;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.node.Node;
import org.opensearch.threadpool.RunnableTaskExecutionListener;
import org.opensearch.threadpool.TaskAwareRunnable;
import java.util.List;
import java.util.Optional;
@ -55,6 +57,7 @@ import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
public class OpenSearchExecutors {
@ -172,15 +175,6 @@ public class OpenSearchExecutors {
);
}
/**
* Return a new executor that will automatically adjust the queue size based on queue throughput.
*
* @param size number of fixed threads to use for executing tasks
* @param initialQueueCapacity initial size of the executor queue
* @param minQueueSize minimum queue size that the queue can be adjusted to
* @param maxQueueSize maximum queue size that the queue can be adjusted to
* @param frameSize number of tasks during which stats are collected before adjusting queue size
*/
public static OpenSearchThreadPoolExecutor newAutoQueueFixed(
String name,
int size,
@ -191,6 +185,41 @@ public class OpenSearchExecutors {
TimeValue targetedResponseTime,
ThreadFactory threadFactory,
ThreadContext contextHolder
) {
return newAutoQueueFixed(
name,
size,
initialQueueCapacity,
minQueueSize,
maxQueueSize,
frameSize,
targetedResponseTime,
threadFactory,
contextHolder,
null
);
}
/**
* Return a new executor that will automatically adjust the queue size based on queue throughput.
*
* @param size number of fixed threads to use for executing tasks
* @param initialQueueCapacity initial size of the executor queue
* @param minQueueSize minimum queue size that the queue can be adjusted to
* @param maxQueueSize maximum queue size that the queue can be adjusted to
* @param frameSize number of tasks during which stats are collected before adjusting queue size
*/
public static OpenSearchThreadPoolExecutor newAutoQueueFixed(
String name,
int size,
int initialQueueCapacity,
int minQueueSize,
int maxQueueSize,
int frameSize,
TimeValue targetedResponseTime,
ThreadFactory threadFactory,
ThreadContext contextHolder,
AtomicReference<RunnableTaskExecutionListener> runnableTaskListener
) {
if (initialQueueCapacity <= 0) {
throw new IllegalArgumentException(
@ -201,6 +230,17 @@ public class OpenSearchExecutors {
ConcurrentCollections.<Runnable>newBlockingQueue(),
initialQueueCapacity
);
Function<Runnable, WrappedRunnable> runnableWrapper;
if (runnableTaskListener != null) {
runnableWrapper = (runnable) -> {
TaskAwareRunnable taskAwareRunnable = new TaskAwareRunnable(contextHolder, runnable, runnableTaskListener);
return new TimedRunnable(taskAwareRunnable);
};
} else {
runnableWrapper = TimedRunnable::new;
}
return new QueueResizingOpenSearchThreadPoolExecutor(
name,
size,
@ -210,7 +250,7 @@ public class OpenSearchExecutors {
queue,
minQueueSize,
maxQueueSize,
TimedRunnable::new,
runnableWrapper,
frameSize,
targetedResponseTime,
threadFactory,

View File

@ -66,6 +66,7 @@ import java.util.stream.Stream;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_COUNT;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_SIZE;
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;
/**
* A ThreadContext is a map of string headers and a transient map of keyed objects that are associated with
@ -134,16 +135,23 @@ public final class ThreadContext implements Writeable {
* This is needed so the DeprecationLogger in another thread can see the value of X-Opaque-ID provided by a user.
* Otherwise when context is stash, it should be empty.
*/
ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT;
if (context.requestHeaders.containsKey(Task.X_OPAQUE_ID)) {
ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT.putHeaders(
threadContextStruct = threadContextStruct.putHeaders(
MapBuilder.<String, String>newMapBuilder()
.put(Task.X_OPAQUE_ID, context.requestHeaders.get(Task.X_OPAQUE_ID))
.immutableMap()
);
threadLocal.set(threadContextStruct);
} else {
threadLocal.set(DEFAULT_CONTEXT);
}
if (context.transientHeaders.containsKey(TASK_ID)) {
threadContextStruct = threadContextStruct.putTransient(TASK_ID, context.transientHeaders.get(TASK_ID));
}
threadLocal.set(threadContextStruct);
return () -> {
// If the node and thus the threadLocal get closed while this task
// is still executing, we don't want this runnable to fail with an

View File

@ -37,6 +37,8 @@ import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.SetOnce;
import org.opensearch.index.IndexingPressureService;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.threadpool.RunnableTaskExecutionListener;
import org.opensearch.watcher.ResourceWatcherService;
import org.opensearch.Assertions;
import org.opensearch.Build;
@ -213,6 +215,7 @@ import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
@ -324,6 +327,7 @@ public class Node implements Closeable {
private final LocalNodeFactory localNodeFactory;
private final NodeService nodeService;
final NamedWriteableRegistry namedWriteableRegistry;
private final AtomicReference<RunnableTaskExecutionListener> runnableTaskListener;
public Node(Environment environment) {
this(environment, Collections.emptyList(), true);
@ -433,7 +437,8 @@ public class Node implements Closeable {
final List<ExecutorBuilder<?>> executorBuilders = pluginsService.getExecutorBuilders(settings);
final ThreadPool threadPool = new ThreadPool(settings, executorBuilders.toArray(new ExecutorBuilder[0]));
runnableTaskListener = new AtomicReference<>();
final ThreadPool threadPool = new ThreadPool(settings, runnableTaskListener, executorBuilders.toArray(new ExecutorBuilder[0]));
resourcesToClose.add(() -> ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS));
final ResourceWatcherService resourceWatcherService = new ResourceWatcherService(settings, threadPool);
resourcesToClose.add(resourceWatcherService);
@ -1057,6 +1062,11 @@ public class Node implements Closeable {
TransportService transportService = injector.getInstance(TransportService.class);
transportService.getTaskManager().setTaskResultsService(injector.getInstance(TaskResultsService.class));
transportService.getTaskManager().setTaskCancellationService(new TaskCancellationService(transportService));
TaskResourceTrackingService taskResourceTrackingService = injector.getInstance(TaskResourceTrackingService.class);
transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService);
runnableTaskListener.set(taskResourceTrackingService);
transportService.start();
assert localNodeFactory.getNode() != null;
assert transportService.getLocalNode().equals(localNodeFactory.getNode())
@ -1490,4 +1500,5 @@ public class Node implements Closeable {
return localNode.get();
}
}
}

View File

@ -32,8 +32,6 @@
package org.opensearch.tasks;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionResponse;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.io.stream.NamedWriteable;
@ -53,8 +51,6 @@ import java.util.concurrent.ConcurrentHashMap;
*/
public class Task {
private static final Logger logger = LogManager.getLogger(Task.class);
/**
* The request header to mark tasks with specific ids
*/
@ -289,7 +285,7 @@ public class Task {
);
}
}
threadResourceInfoList.add(new ThreadResourceInfo(statsType, resourceUsageMetrics));
threadResourceInfoList.add(new ThreadResourceInfo(threadId, statsType, resourceUsageMetrics));
}
/**
@ -336,6 +332,17 @@ public class Task {
throw new IllegalStateException("cannot update final values if active thread resource entry is not present");
}
/**
* Individual tasks can override this if they want to support task resource tracking. We just need to make sure that
* the ThreadPool on which the task runs on have runnable wrapper similar to
* {@link org.opensearch.common.util.concurrent.OpenSearchExecutors#newAutoQueueFixed}
*
* @return true if resource tracking is supported by the task
*/
public boolean supportsResourceTracking() {
return false;
}
/**
* Report of the internal status of a task. These can vary wildly from task
* to task because each task is implemented differently but we should try

View File

@ -89,7 +89,9 @@ public class TaskManager implements ClusterStateApplier {
private static final TimeValue WAIT_FOR_COMPLETION_POLL = timeValueMillis(100);
/** Rest headers that are copied to the task */
/**
* Rest headers that are copied to the task
*/
private final List<String> taskHeaders;
private final ThreadPool threadPool;
@ -103,6 +105,7 @@ public class TaskManager implements ClusterStateApplier {
private final Map<TaskId, String> banedParents = new ConcurrentHashMap<>();
private TaskResultsService taskResultsService;
private final SetOnce<TaskResourceTrackingService> taskResourceTrackingService = new SetOnce<>();
private volatile DiscoveryNodes lastDiscoveryNodes = DiscoveryNodes.EMPTY_NODES;
@ -125,6 +128,10 @@ public class TaskManager implements ClusterStateApplier {
this.cancellationService.set(taskCancellationService);
}
public void setTaskResourceTrackingService(TaskResourceTrackingService taskResourceTrackingService) {
this.taskResourceTrackingService.set(taskResourceTrackingService);
}
/**
* Registers a task without parent task
*/
@ -202,6 +209,11 @@ public class TaskManager implements ClusterStateApplier {
*/
public Task unregister(Task task) {
logger.trace("unregister task for id: {}", task.getId());
if (taskResourceTrackingService.get() != null && task.supportsResourceTracking()) {
taskResourceTrackingService.get().stopTracking(task);
}
if (task instanceof CancellableTask) {
CancellableTaskHolder holder = cancellableTasks.remove(task.getId());
if (holder != null) {
@ -361,6 +373,7 @@ public class TaskManager implements ClusterStateApplier {
* Bans all tasks with the specified parent task from execution, cancels all tasks that are currently executing.
* <p>
* This method is called when a parent task that has children is cancelled.
*
* @return a list of pending cancellable child tasks
*/
public List<CancellableTask> setBan(TaskId parentTaskId, String reason) {
@ -448,6 +461,18 @@ public class TaskManager implements ClusterStateApplier {
throw new OpenSearchTimeoutException("Timed out waiting for completion of [{}]", task);
}
/**
* Takes actions when a task is registered and its execution starts
*
* @param task getting executed.
* @return AutoCloseable to free up resources (clean up thread context) when task execution block returns
*/
public ThreadContext.StoredContext taskExecutionStarted(Task task) {
if (taskResourceTrackingService.get() == null) return () -> {};
return taskResourceTrackingService.get().startTracking(task);
}
private static class CancellableTaskHolder {
private final CancellableTask task;
private boolean finished = false;

View File

@ -0,0 +1,255 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.tasks;
import com.sun.management.ThreadMXBean;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ConcurrentCollections;
import org.opensearch.common.util.concurrent.ConcurrentMapLong;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.threadpool.RunnableTaskExecutionListener;
import org.opensearch.threadpool.ThreadPool;
import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static org.opensearch.tasks.ResourceStatsType.WORKER_STATS;
/**
* Service that helps track resource usage of tasks running on a node.
*/
@SuppressForbidden(reason = "ThreadMXBean#getThreadAllocatedBytes")
public class TaskResourceTrackingService implements RunnableTaskExecutionListener {
private static final Logger logger = LogManager.getLogger(TaskManager.class);
public static final Setting<Boolean> TASK_RESOURCE_TRACKING_ENABLED = Setting.boolSetting(
"task_resource_tracking.enabled",
true,
Setting.Property.Dynamic,
Setting.Property.NodeScope
);
public static final String TASK_ID = "TASK_ID";
private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean();
private final ConcurrentMapLong<Task> resourceAwareTasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency();
private final ThreadPool threadPool;
private volatile boolean taskResourceTrackingEnabled;
@Inject
public TaskResourceTrackingService(Settings settings, ClusterSettings clusterSettings, ThreadPool threadPool) {
this.taskResourceTrackingEnabled = TASK_RESOURCE_TRACKING_ENABLED.get(settings);
this.threadPool = threadPool;
clusterSettings.addSettingsUpdateConsumer(TASK_RESOURCE_TRACKING_ENABLED, this::setTaskResourceTrackingEnabled);
}
public void setTaskResourceTrackingEnabled(boolean taskResourceTrackingEnabled) {
this.taskResourceTrackingEnabled = taskResourceTrackingEnabled;
}
public boolean isTaskResourceTrackingEnabled() {
return taskResourceTrackingEnabled;
}
public boolean isTaskResourceTrackingSupported() {
return threadMXBean.isThreadAllocatedMemorySupported() && threadMXBean.isThreadAllocatedMemoryEnabled();
}
/**
* Executes logic only if task supports resource tracking and resource tracking setting is enabled.
* <p>
* 1. Starts tracking the task in map of resourceAwareTasks.
* 2. Adds Task Id in thread context to make sure it's available while task is processed across multiple threads.
*
* @param task for which resources needs to be tracked
* @return Autocloseable stored context to restore ThreadContext to the state before this method changed it.
*/
public ThreadContext.StoredContext startTracking(Task task) {
if (task.supportsResourceTracking() == false
|| isTaskResourceTrackingEnabled() == false
|| isTaskResourceTrackingSupported() == false) {
return () -> {};
}
logger.debug("Starting resource tracking for task: {}", task.getId());
resourceAwareTasks.put(task.getId(), task);
return addTaskIdToThreadContext(task);
}
/**
* Stops tracking task registered earlier for tracking.
* <p>
* It doesn't have feature enabled check to avoid any issues if setting was disable while the task was in progress.
* <p>
* It's also responsible to stop tracking the current thread's resources against this task if not already done.
* This happens when the thread executing the request logic itself calls the unregister method. So in this case unregister
* happens before runnable finishes.
*
* @param task task which has finished and doesn't need resource tracking.
*/
public void stopTracking(Task task) {
logger.debug("Stopping resource tracking for task: {}", task.getId());
try {
if (isCurrentThreadWorkingOnTask(task)) {
taskExecutionFinishedOnThread(task.getId(), Thread.currentThread().getId());
}
List<Long> threadsWorkingOnTask = getThreadsWorkingOnTask(task);
if (threadsWorkingOnTask.size() > 0) {
logger.warn("No thread should be active when task finishes. Active threads: {}", threadsWorkingOnTask);
assert false : "No thread should be marked active when task finishes";
}
} catch (Exception e) {
logger.warn("Failed while trying to mark the task execution on current thread completed.", e);
assert false;
} finally {
resourceAwareTasks.remove(task.getId());
}
}
/**
* Refreshes the resource stats for the tasks provided by looking into which threads are actively working on these
* and how much resources these have consumed till now.
*
* @param tasks for which resource stats needs to be refreshed.
*/
public void refreshResourceStats(Task... tasks) {
if (isTaskResourceTrackingEnabled() == false || isTaskResourceTrackingSupported() == false) {
return;
}
for (Task task : tasks) {
if (task.supportsResourceTracking() && resourceAwareTasks.containsKey(task.getId())) {
refreshResourceStats(task);
}
}
}
private void refreshResourceStats(Task resourceAwareTask) {
try {
logger.debug("Refreshing resource stats for Task: {}", resourceAwareTask.getId());
List<Long> threadsWorkingOnTask = getThreadsWorkingOnTask(resourceAwareTask);
threadsWorkingOnTask.forEach(
threadId -> resourceAwareTask.updateThreadResourceStats(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId))
);
} catch (IllegalStateException e) {
logger.debug("Resource stats already updated.");
}
}
/**
* Called when a thread starts working on a task's runnable.
*
* @param taskId of the task for which runnable is starting
* @param threadId of the thread which will be executing the runnable and we need to check resource usage for this
* thread
*/
@Override
public void taskExecutionStartedOnThread(long taskId, long threadId) {
try {
if (resourceAwareTasks.containsKey(taskId)) {
logger.debug("Task execution started on thread. Task: {}, Thread: {}", taskId, threadId);
resourceAwareTasks.get(taskId)
.startThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId));
}
} catch (Exception e) {
logger.warn(new ParameterizedMessage("Failed to mark thread execution started for task: [{}]", taskId), e);
assert false;
}
}
/**
* Called when a thread finishes working on a task's runnable.
*
* @param taskId of the task for which runnable is complete
* @param threadId of the thread which executed the runnable and we need to check resource usage for this thread
*/
@Override
public void taskExecutionFinishedOnThread(long taskId, long threadId) {
try {
if (resourceAwareTasks.containsKey(taskId)) {
logger.debug("Task execution finished on thread. Task: {}, Thread: {}", taskId, threadId);
resourceAwareTasks.get(taskId)
.stopThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId));
}
} catch (Exception e) {
logger.warn(new ParameterizedMessage("Failed to mark thread execution finished for task: [{}]", taskId), e);
assert false;
}
}
public Map<Long, Task> getResourceAwareTasks() {
return Collections.unmodifiableMap(resourceAwareTasks);
}
private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) {
ResourceUsageMetric currentMemoryUsage = new ResourceUsageMetric(
ResourceStats.MEMORY,
threadMXBean.getThreadAllocatedBytes(threadId)
);
ResourceUsageMetric currentCPUUsage = new ResourceUsageMetric(ResourceStats.CPU, threadMXBean.getThreadCpuTime(threadId));
return new ResourceUsageMetric[] { currentMemoryUsage, currentCPUUsage };
}
private boolean isCurrentThreadWorkingOnTask(Task task) {
long threadId = Thread.currentThread().getId();
List<ThreadResourceInfo> threadResourceInfos = task.getResourceStats().getOrDefault(threadId, Collections.emptyList());
for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) {
if (threadResourceInfo.isActive()) {
return true;
}
}
return false;
}
private List<Long> getThreadsWorkingOnTask(Task task) {
List<Long> activeThreads = new ArrayList<>();
for (List<ThreadResourceInfo> threadResourceInfos : task.getResourceStats().values()) {
for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) {
if (threadResourceInfo.isActive()) {
activeThreads.add(threadResourceInfo.getThreadId());
}
}
}
return activeThreads;
}
/**
* Adds Task Id in the ThreadContext.
* <p>
* Stashes the existing ThreadContext and preserves all the existing ThreadContext's data in the new ThreadContext
* as well.
*
* @param task for which Task Id needs to be added in ThreadContext.
* @return StoredContext reference to restore the ThreadContext from which we created a new one.
* Caller can call context.restore() to get the existing ThreadContext back.
*/
private ThreadContext.StoredContext addTaskIdToThreadContext(Task task) {
ThreadContext threadContext = threadPool.getThreadContext();
ThreadContext.StoredContext storedContext = threadContext.newStoredContext(true, Collections.singletonList(TASK_ID));
threadContext.putTransient(TASK_ID, task.getId());
return storedContext;
}
}

View File

@ -15,11 +15,13 @@ package org.opensearch.tasks;
* for a specific stats type like worker_stats or response_stats etc.,
*/
public class ThreadResourceInfo {
private final long threadId;
private volatile boolean isActive = true;
private final ResourceStatsType statsType;
private final ResourceUsageInfo resourceUsageInfo;
public ThreadResourceInfo(ResourceStatsType statsType, ResourceUsageMetric... resourceUsageMetrics) {
public ThreadResourceInfo(long threadId, ResourceStatsType statsType, ResourceUsageMetric... resourceUsageMetrics) {
this.threadId = threadId;
this.statsType = statsType;
this.resourceUsageInfo = new ResourceUsageInfo(resourceUsageMetrics);
}
@ -43,12 +45,16 @@ public class ThreadResourceInfo {
return statsType;
}
public long getThreadId() {
return threadId;
}
public ResourceUsageInfo getResourceUsageInfo() {
return resourceUsageInfo;
}
@Override
public String toString() {
return resourceUsageInfo + ", stats_type=" + statsType + ", is_active=" + isActive;
return resourceUsageInfo + ", stats_type=" + statsType + ", is_active=" + isActive + ", threadId=" + threadId;
}
}

View File

@ -48,6 +48,7 @@ import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicReference;
/**
* A builder for executors that automatically adjust the queue length as needed, depending on
@ -61,6 +62,7 @@ public final class AutoQueueAdjustingExecutorBuilder extends ExecutorBuilder<Aut
private final Setting<Integer> maxQueueSizeSetting;
private final Setting<TimeValue> targetedResponseTimeSetting;
private final Setting<Integer> frameSizeSetting;
private final AtomicReference<RunnableTaskExecutionListener> runnableTaskListener;
AutoQueueAdjustingExecutorBuilder(
final Settings settings,
@ -70,6 +72,19 @@ public final class AutoQueueAdjustingExecutorBuilder extends ExecutorBuilder<Aut
final int minQueueSize,
final int maxQueueSize,
final int frameSize
) {
this(settings, name, size, initialQueueSize, minQueueSize, maxQueueSize, frameSize, null);
}
AutoQueueAdjustingExecutorBuilder(
final Settings settings,
final String name,
final int size,
final int initialQueueSize,
final int minQueueSize,
final int maxQueueSize,
final int frameSize,
final AtomicReference<RunnableTaskExecutionListener> runnableTaskListener
) {
super(name);
final String prefix = "thread_pool." + name;
@ -184,6 +199,7 @@ public final class AutoQueueAdjustingExecutorBuilder extends ExecutorBuilder<Aut
Setting.Property.Deprecated,
Setting.Property.Deprecated
);
this.runnableTaskListener = runnableTaskListener;
}
@Override
@ -230,7 +246,8 @@ public final class AutoQueueAdjustingExecutorBuilder extends ExecutorBuilder<Aut
frameSize,
targetedResponseTime,
threadFactory,
threadContext
threadContext,
runnableTaskListener
);
// TODO: in a subsequent change we hope to extend ThreadPool.Info to be more specific for the thread pool type
final ThreadPool.Info info = new ThreadPool.Info(

View File

@ -0,0 +1,33 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.threadpool;
/**
* Listener for events when a runnable execution starts or finishes on a thread and is aware of the task for which the
* runnable is associated to.
*/
public interface RunnableTaskExecutionListener {
/**
* Sends an update when ever a task's execution start on a thread
*
* @param taskId of task which has started
* @param threadId of thread which is executing the task
*/
void taskExecutionStartedOnThread(long taskId, long threadId);
/**
*
* Sends an update when task execution finishes on a thread
*
* @param taskId of task which has finished
* @param threadId of thread which executed the task
*/
void taskExecutionFinishedOnThread(long taskId, long threadId);
}

View File

@ -0,0 +1,90 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.threadpool;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.WrappedRunnable;
import org.opensearch.tasks.TaskManager;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import static java.lang.Thread.currentThread;
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;
/**
* Responsible for wrapping the original task's runnable and sending updates on when it starts and finishes to
* entities listening to the events.
*
* It's able to associate runnable with a task with the help of task Id available in thread context.
*/
public class TaskAwareRunnable extends AbstractRunnable implements WrappedRunnable {
private static final Logger logger = LogManager.getLogger(TaskManager.class);
private final Runnable original;
private final ThreadContext threadContext;
private final AtomicReference<RunnableTaskExecutionListener> runnableTaskListener;
public TaskAwareRunnable(
final ThreadContext threadContext,
final Runnable original,
final AtomicReference<RunnableTaskExecutionListener> runnableTaskListener
) {
this.original = original;
this.threadContext = threadContext;
this.runnableTaskListener = runnableTaskListener;
}
@Override
public void onFailure(Exception e) {
ExceptionsHelper.reThrowIfNotNull(e);
}
@Override
public boolean isForceExecution() {
return original instanceof AbstractRunnable && ((AbstractRunnable) original).isForceExecution();
}
@Override
public void onRejection(final Exception e) {
if (original instanceof AbstractRunnable) {
((AbstractRunnable) original).onRejection(e);
} else {
ExceptionsHelper.reThrowIfNotNull(e);
}
}
@Override
protected void doRun() throws Exception {
assert runnableTaskListener.get() != null : "Listener should be attached";
Long taskId = threadContext.getTransient(TASK_ID);
if (Objects.nonNull(taskId)) {
runnableTaskListener.get().taskExecutionStartedOnThread(taskId, currentThread().getId());
} else {
logger.debug("Task Id not available in thread context. Skipping update. Thread Info: {}", Thread.currentThread());
}
try {
original.run();
} finally {
if (Objects.nonNull(taskId)) {
runnableTaskListener.get().taskExecutionFinishedOnThread(taskId, currentThread().getId());
}
}
}
@Override
public Runnable unwrap() {
return original;
}
}

View File

@ -68,6 +68,7 @@ import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import static java.util.Collections.unmodifiableMap;
@ -184,6 +185,14 @@ public class ThreadPool implements ReportingService<ThreadPoolInfo>, Scheduler {
);
public ThreadPool(final Settings settings, final ExecutorBuilder<?>... customBuilders) {
this(settings, null, customBuilders);
}
public ThreadPool(
final Settings settings,
final AtomicReference<RunnableTaskExecutionListener> runnableTaskListener,
final ExecutorBuilder<?>... customBuilders
) {
assert Node.NODE_NAME_SETTING.exists(settings);
final Map<String, ExecutorBuilder> builders = new HashMap<>();
@ -197,11 +206,20 @@ public class ThreadPool implements ReportingService<ThreadPoolInfo>, Scheduler {
builders.put(Names.ANALYZE, new FixedExecutorBuilder(settings, Names.ANALYZE, 1, 16));
builders.put(
Names.SEARCH,
new AutoQueueAdjustingExecutorBuilder(settings, Names.SEARCH, searchThreadPoolSize(allocatedProcessors), 1000, 1000, 1000, 2000)
new AutoQueueAdjustingExecutorBuilder(
settings,
Names.SEARCH,
searchThreadPoolSize(allocatedProcessors),
1000,
1000,
1000,
2000,
runnableTaskListener
)
);
builders.put(
Names.SEARCH_THROTTLED,
new AutoQueueAdjustingExecutorBuilder(settings, Names.SEARCH_THROTTLED, 1, 100, 100, 100, 200)
new AutoQueueAdjustingExecutorBuilder(settings, Names.SEARCH_THROTTLED, 1, 100, 100, 100, 200, runnableTaskListener)
);
builders.put(Names.MANAGEMENT, new ScalingExecutorBuilder(Names.MANAGEMENT, 1, 5, TimeValue.timeValueMinutes(5)));
// no queue as this means clients will need to handle rejections on listener queue even if the operation succeeded

View File

@ -37,6 +37,7 @@ import org.opensearch.common.io.stream.Writeable;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskManager;
@ -81,6 +82,8 @@ public class RequestHandlerRegistry<Request extends TransportRequest> {
public void processMessageReceived(Request request, TransportChannel channel) throws Exception {
final Task task = taskManager.register(channel.getChannelType(), action, request);
ThreadContext.StoredContext contextToRestore = taskManager.taskExecutionStarted(task);
Releasable unregisterTask = () -> taskManager.unregister(task);
try {
if (channel instanceof TcpTransportChannel && task instanceof CancellableTask) {
@ -99,6 +102,7 @@ public class RequestHandlerRegistry<Request extends TransportRequest> {
unregisterTask = null;
} finally {
Releasables.close(unregisterTask);
contextToRestore.restore();
}
}

View File

@ -75,6 +75,9 @@ public class RecordingTaskManagerListener implements MockTaskManagerListener {
@Override
public void waitForTaskCompletion(Task task) {}
@Override
public void taskExecutionStarted(Task task, Boolean closeableInvoked) {}
public synchronized List<Tuple<Boolean, TaskInfo>> getEvents() {
return Collections.unmodifiableList(new ArrayList<>(events));
}

View File

@ -0,0 +1,633 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.action.admin.cluster.node.tasks;
import com.sun.management.ThreadMXBean;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.ActionListener;
import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksRequest;
import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
import org.opensearch.action.support.ActionTestUtils;
import org.opensearch.action.support.nodes.BaseNodeRequest;
import org.opensearch.action.support.nodes.BaseNodesRequest;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskCancelledException;
import org.opensearch.tasks.TaskId;
import org.opensearch.tasks.TaskInfo;
import org.opensearch.test.tasks.MockTaskManager;
import org.opensearch.test.tasks.MockTaskManagerListener;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import java.io.IOException;
import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;
@SuppressForbidden(reason = "ThreadMXBean#getThreadAllocatedBytes")
public class ResourceAwareTasksTests extends TaskManagerTestCase {
private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean();
public static class ResourceAwareNodeRequest extends BaseNodeRequest {
protected String requestName;
public ResourceAwareNodeRequest() {
super();
}
public ResourceAwareNodeRequest(StreamInput in) throws IOException {
super(in);
requestName = in.readString();
}
public ResourceAwareNodeRequest(NodesRequest request) {
requestName = request.requestName;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(requestName);
}
@Override
public String getDescription() {
return "ResourceAwareNodeRequest[" + requestName + "]";
}
@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) {
@Override
public boolean shouldCancelChildrenOnCancellation() {
return false;
}
@Override
public boolean supportsResourceTracking() {
return true;
}
};
}
}
public static class NodesRequest extends BaseNodesRequest<NodesRequest> {
private final String requestName;
private NodesRequest(StreamInput in) throws IOException {
super(in);
requestName = in.readString();
}
public NodesRequest(String requestName, String... nodesIds) {
super(nodesIds);
this.requestName = requestName;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(requestName);
}
@Override
public String getDescription() {
return "NodesRequest[" + requestName + "]";
}
@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) {
@Override
public boolean shouldCancelChildrenOnCancellation() {
return true;
}
};
}
}
/**
* Simulates a task which executes work on search executor.
*/
class ResourceAwareNodesAction extends AbstractTestNodesAction<NodesRequest, ResourceAwareNodeRequest> {
private final TaskTestContext taskTestContext;
private final boolean blockForCancellation;
ResourceAwareNodesAction(
String actionName,
ThreadPool threadPool,
ClusterService clusterService,
TransportService transportService,
boolean shouldBlock,
TaskTestContext taskTestContext
) {
super(actionName, threadPool, clusterService, transportService, NodesRequest::new, ResourceAwareNodeRequest::new);
this.taskTestContext = taskTestContext;
this.blockForCancellation = shouldBlock;
}
@Override
protected ResourceAwareNodeRequest newNodeRequest(NodesRequest request) {
return new ResourceAwareNodeRequest(request);
}
@Override
protected NodeResponse nodeOperation(ResourceAwareNodeRequest request, Task task) {
assert task.supportsResourceTracking();
AtomicLong threadId = new AtomicLong();
Future<?> result = threadPool.executor(ThreadPool.Names.SEARCH).submit(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
ExceptionsHelper.reThrowIfNotNull(e);
}
@Override
@SuppressForbidden(reason = "ThreadMXBean#getThreadAllocatedBytes")
protected void doRun() {
taskTestContext.memoryConsumptionWhenExecutionStarts = threadMXBean.getThreadAllocatedBytes(
Thread.currentThread().getId()
);
threadId.set(Thread.currentThread().getId());
if (taskTestContext.operationStartValidator != null) {
try {
taskTestContext.operationStartValidator.accept(threadId.get());
} catch (AssertionError error) {
throw new RuntimeException(error);
}
}
Object[] allocation1 = new Object[1000000]; // 4MB
if (blockForCancellation) {
// Simulate a job that takes forever to finish
// Using periodic checks method to identify that the task was cancelled
try {
boolean taskCancelled = waitUntil(((CancellableTask) task)::isCancelled);
if (taskCancelled) {
throw new TaskCancelledException("Task Cancelled");
} else {
fail("It should have thrown an exception");
}
} catch (InterruptedException ex) {
Thread.currentThread().interrupt();
}
}
Object[] allocation2 = new Object[1000000]; // 4MB
}
});
try {
result.get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e.getCause());
} finally {
if (taskTestContext.operationFinishedValidator != null) {
taskTestContext.operationFinishedValidator.accept(threadId.get());
}
}
return new NodeResponse(clusterService.localNode());
}
@Override
protected NodeResponse nodeOperation(ResourceAwareNodeRequest request) {
throw new UnsupportedOperationException("the task parameter is required");
}
}
private TaskTestContext startResourceAwareNodesAction(
TestNode node,
boolean blockForCancellation,
TaskTestContext taskTestContext,
ActionListener<NodesResponse> listener
) {
NodesRequest request = new NodesRequest("Test Request", node.getNodeId());
taskTestContext.requestCompleteLatch = new CountDownLatch(1);
ResourceAwareNodesAction action = new ResourceAwareNodesAction(
"internal:resourceAction",
threadPool,
node.clusterService,
node.transportService,
blockForCancellation,
taskTestContext
);
taskTestContext.mainTask = action.execute(request, listener);
return taskTestContext;
}
private static class TaskTestContext {
private Task mainTask;
private CountDownLatch requestCompleteLatch;
private Consumer<Long> operationStartValidator;
private Consumer<Long> operationFinishedValidator;
private long memoryConsumptionWhenExecutionStarts;
}
public void testBasicTaskResourceTracking() throws Exception {
setup(true, false);
final AtomicReference<Throwable> throwableReference = new AtomicReference<>();
final AtomicReference<NodesResponse> responseReference = new AtomicReference<>();
TaskTestContext taskTestContext = new TaskTestContext();
Map<Long, Task> resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks();
taskTestContext.operationStartValidator = threadId -> {
Task task = resourceTasks.values().stream().findAny().get();
// One thread is currently working on task but not finished
assertEquals(1, resourceTasks.size());
assertEquals(1, task.getResourceStats().size());
assertEquals(1, task.getResourceStats().get(threadId).size());
assertTrue(task.getResourceStats().get(threadId).get(0).isActive());
assertEquals(0, task.getTotalResourceStats().getCpuTimeInNanos());
assertEquals(0, task.getTotalResourceStats().getMemoryInBytes());
};
taskTestContext.operationFinishedValidator = threadId -> {
Task task = resourceTasks.values().stream().findAny().get();
// Thread has finished working on the task's runnable
assertEquals(1, resourceTasks.size());
assertEquals(1, task.getResourceStats().size());
assertEquals(1, task.getResourceStats().get(threadId).size());
assertFalse(task.getResourceStats().get(threadId).get(0).isActive());
long expectedArrayAllocationOverhead = 2 * 4012688; // Task's memory overhead due to array allocations
long actualTaskMemoryOverhead = task.getTotalResourceStats().getMemoryInBytes();
assertTrue(actualTaskMemoryOverhead - expectedArrayAllocationOverhead < taskTestContext.memoryConsumptionWhenExecutionStarts);
assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0);
};
startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
taskTestContext.requestCompleteLatch.countDown();
}
@Override
public void onFailure(Exception e) {
throwableReference.set(e);
taskTestContext.requestCompleteLatch.countDown();
}
});
// Waiting for whole request to complete and return successfully till client
taskTestContext.requestCompleteLatch.await();
assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get());
}
public void testTaskResourceTrackingDuringTaskCancellation() throws Exception {
setup(true, false);
final AtomicReference<Throwable> throwableReference = new AtomicReference<>();
final AtomicReference<NodesResponse> responseReference = new AtomicReference<>();
TaskTestContext taskTestContext = new TaskTestContext();
Map<Long, Task> resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks();
taskTestContext.operationStartValidator = threadId -> {
Task task = resourceTasks.values().stream().findAny().get();
// One thread is currently working on task but not finished
assertEquals(1, resourceTasks.size());
assertEquals(1, task.getResourceStats().size());
assertEquals(1, task.getResourceStats().get(threadId).size());
assertTrue(task.getResourceStats().get(threadId).get(0).isActive());
assertEquals(0, task.getTotalResourceStats().getCpuTimeInNanos());
assertEquals(0, task.getTotalResourceStats().getMemoryInBytes());
};
taskTestContext.operationFinishedValidator = threadId -> {
Task task = resourceTasks.values().stream().findAny().get();
// Thread has finished working on the task's runnable
assertEquals(1, resourceTasks.size());
assertEquals(1, task.getResourceStats().size());
assertEquals(1, task.getResourceStats().get(threadId).size());
assertFalse(task.getResourceStats().get(threadId).get(0).isActive());
// allocations are completed before the task is cancelled
long expectedArrayAllocationOverhead = 4012688; // Task's memory overhead due to array allocations
long taskCancellationOverhead = 30000; // Task cancellation overhead ~ 30Kb
long actualTaskMemoryOverhead = task.getTotalResourceStats().getMemoryInBytes();
long expectedOverhead = expectedArrayAllocationOverhead + taskCancellationOverhead;
assertTrue(actualTaskMemoryOverhead - expectedOverhead < taskTestContext.memoryConsumptionWhenExecutionStarts);
assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0);
};
startResourceAwareNodesAction(testNodes[0], true, taskTestContext, new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
taskTestContext.requestCompleteLatch.countDown();
}
@Override
public void onFailure(Exception e) {
throwableReference.set(e);
taskTestContext.requestCompleteLatch.countDown();
}
});
// Cancel main task
CancelTasksRequest request = new CancelTasksRequest();
request.setReason("Cancelling request to verify Task resource tracking behaviour");
request.setTaskId(new TaskId(testNodes[0].getNodeId(), taskTestContext.mainTask.getId()));
ActionTestUtils.executeBlocking(testNodes[0].transportCancelTasksAction, request);
// Waiting for whole request to complete and return successfully till client
taskTestContext.requestCompleteLatch.await();
assertEquals(0, resourceTasks.size());
assertNull(throwableReference.get());
assertNotNull(responseReference.get());
assertEquals(1, responseReference.get().failureCount());
assertEquals(TaskCancelledException.class, findActualException(responseReference.get().failures().get(0)).getClass());
}
public void testTaskResourceTrackingDisabled() throws Exception {
setup(false, false);
final AtomicReference<Throwable> throwableReference = new AtomicReference<>();
final AtomicReference<NodesResponse> responseReference = new AtomicReference<>();
TaskTestContext taskTestContext = new TaskTestContext();
Map<Long, Task> resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks();
taskTestContext.operationStartValidator = threadId -> { assertEquals(0, resourceTasks.size()); };
taskTestContext.operationFinishedValidator = threadId -> { assertEquals(0, resourceTasks.size()); };
startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
taskTestContext.requestCompleteLatch.countDown();
}
@Override
public void onFailure(Exception e) {
throwableReference.set(e);
taskTestContext.requestCompleteLatch.countDown();
}
});
// Waiting for whole request to complete and return successfully till client
taskTestContext.requestCompleteLatch.await();
assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get());
}
public void testTaskResourceTrackingDisabledWhileTaskInProgress() throws Exception {
setup(true, false);
final AtomicReference<Throwable> throwableReference = new AtomicReference<>();
final AtomicReference<NodesResponse> responseReference = new AtomicReference<>();
TaskTestContext taskTestContext = new TaskTestContext();
Map<Long, Task> resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks();
taskTestContext.operationStartValidator = threadId -> {
Task task = resourceTasks.values().stream().findAny().get();
// One thread is currently working on task but not finished
assertEquals(1, resourceTasks.size());
assertEquals(1, task.getResourceStats().size());
assertEquals(1, task.getResourceStats().get(threadId).size());
assertTrue(task.getResourceStats().get(threadId).get(0).isActive());
assertEquals(0, task.getTotalResourceStats().getCpuTimeInNanos());
assertEquals(0, task.getTotalResourceStats().getMemoryInBytes());
testNodes[0].taskResourceTrackingService.setTaskResourceTrackingEnabled(false);
};
taskTestContext.operationFinishedValidator = threadId -> {
Task task = resourceTasks.values().stream().findAny().get();
// Thread has finished working on the task's runnable
assertEquals(1, resourceTasks.size());
assertEquals(1, task.getResourceStats().size());
assertEquals(1, task.getResourceStats().get(threadId).size());
assertFalse(task.getResourceStats().get(threadId).get(0).isActive());
long expectedArrayAllocationOverhead = 2 * 4012688; // Task's memory overhead due to array allocations
long actualTaskMemoryOverhead = task.getTotalResourceStats().getMemoryInBytes();
assertTrue(actualTaskMemoryOverhead - expectedArrayAllocationOverhead < taskTestContext.memoryConsumptionWhenExecutionStarts);
assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0);
};
startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
taskTestContext.requestCompleteLatch.countDown();
}
@Override
public void onFailure(Exception e) {
throwableReference.set(e);
taskTestContext.requestCompleteLatch.countDown();
}
});
// Waiting for whole request to complete and return successfully till client
taskTestContext.requestCompleteLatch.await();
assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get());
}
public void testTaskResourceTrackingEnabledWhileTaskInProgress() throws Exception {
setup(false, false);
final AtomicReference<Throwable> throwableReference = new AtomicReference<>();
final AtomicReference<NodesResponse> responseReference = new AtomicReference<>();
TaskTestContext taskTestContext = new TaskTestContext();
Map<Long, Task> resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks();
taskTestContext.operationStartValidator = threadId -> {
assertEquals(0, resourceTasks.size());
testNodes[0].taskResourceTrackingService.setTaskResourceTrackingEnabled(true);
};
taskTestContext.operationFinishedValidator = threadId -> { assertEquals(0, resourceTasks.size()); };
startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
taskTestContext.requestCompleteLatch.countDown();
}
@Override
public void onFailure(Exception e) {
throwableReference.set(e);
taskTestContext.requestCompleteLatch.countDown();
}
});
// Waiting for whole request to complete and return successfully till client
taskTestContext.requestCompleteLatch.await();
assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get());
}
public void testOnDemandRefreshWhileFetchingTasks() throws InterruptedException {
setup(true, false);
final AtomicReference<Throwable> throwableReference = new AtomicReference<>();
final AtomicReference<NodesResponse> responseReference = new AtomicReference<>();
TaskTestContext taskTestContext = new TaskTestContext();
Map<Long, Task> resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks();
taskTestContext.operationStartValidator = threadId -> {
ListTasksResponse listTasksResponse = ActionTestUtils.executeBlocking(
testNodes[0].transportListTasksAction,
new ListTasksRequest().setActions("internal:resourceAction*").setDetailed(true)
);
TaskInfo taskInfo = listTasksResponse.getTasks().get(1);
assertNotNull(taskInfo.getResourceStats());
assertNotNull(taskInfo.getResourceStats().getResourceUsageInfo());
assertTrue(taskInfo.getResourceStats().getResourceUsageInfo().get("total").getCpuTimeInNanos() > 0);
assertTrue(taskInfo.getResourceStats().getResourceUsageInfo().get("total").getMemoryInBytes() > 0);
};
startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
taskTestContext.requestCompleteLatch.countDown();
}
@Override
public void onFailure(Exception e) {
throwableReference.set(e);
taskTestContext.requestCompleteLatch.countDown();
}
});
// Waiting for whole request to complete and return successfully till client
taskTestContext.requestCompleteLatch.await();
assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get());
}
public void testTaskIdPersistsInThreadContext() throws InterruptedException {
setup(true, true);
final List<Long> taskIdsAddedToThreadContext = new ArrayList<>();
final List<Long> taskIdsRemovedFromThreadContext = new ArrayList<>();
AtomicLong actualTaskIdInThreadContext = new AtomicLong(-1);
AtomicLong expectedTaskIdInThreadContext = new AtomicLong(-2);
((MockTaskManager) testNodes[0].transportService.getTaskManager()).addListener(new MockTaskManagerListener() {
@Override
public void waitForTaskCompletion(Task task) {}
@Override
public void taskExecutionStarted(Task task, Boolean closeableInvoked) {
if (closeableInvoked) {
taskIdsRemovedFromThreadContext.add(task.getId());
} else {
taskIdsAddedToThreadContext.add(task.getId());
}
}
@Override
public void onTaskRegistered(Task task) {}
@Override
public void onTaskUnregistered(Task task) {
if (task.getAction().equals("internal:resourceAction[n]")) {
expectedTaskIdInThreadContext.set(task.getId());
actualTaskIdInThreadContext.set(threadPool.getThreadContext().getTransient(TASK_ID));
}
}
});
TaskTestContext taskTestContext = new TaskTestContext();
startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
taskTestContext.requestCompleteLatch.countDown();
}
@Override
public void onFailure(Exception e) {
taskTestContext.requestCompleteLatch.countDown();
}
});
taskTestContext.requestCompleteLatch.await();
assertEquals(expectedTaskIdInThreadContext.get(), actualTaskIdInThreadContext.get());
assertThat(taskIdsAddedToThreadContext, containsInAnyOrder(taskIdsRemovedFromThreadContext.toArray()));
}
private void setup(boolean resourceTrackingEnabled, boolean useMockTaskManager) {
Settings settings = Settings.builder()
.put("task_resource_tracking.enabled", resourceTrackingEnabled)
.put(MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.getKey(), useMockTaskManager)
.build();
setupTestNodes(settings);
connectNodes(testNodes[0]);
runnableTaskListener.set(testNodes[0].taskResourceTrackingService);
}
private Throwable findActualException(Exception e) {
Throwable throwable = e.getCause();
while (throwable.getCause() != null) {
throwable = throwable.getCause();
}
return throwable;
}
private void assertTasksRequestFinishedSuccessfully(int activeResourceTasks, NodesResponse nodesResponse, Throwable throwable) {
assertEquals(0, activeResourceTasks);
assertNull(throwable);
assertNotNull(nodesResponse);
assertEquals(0, nodesResponse.failureCount());
}
}

View File

@ -59,8 +59,10 @@ import org.opensearch.common.util.PageCacheRecycler;
import org.opensearch.indices.breaker.NoneCircuitBreakerService;
import org.opensearch.tasks.TaskCancellationService;
import org.opensearch.tasks.TaskManager;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.tasks.MockTaskManager;
import org.opensearch.threadpool.RunnableTaskExecutionListener;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
@ -74,6 +76,7 @@ import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import static java.util.Collections.emptyMap;
@ -89,10 +92,12 @@ public abstract class TaskManagerTestCase extends OpenSearchTestCase {
protected ThreadPool threadPool;
protected TestNode[] testNodes;
protected int nodesCount;
protected AtomicReference<RunnableTaskExecutionListener> runnableTaskListener;
@Before
public void setupThreadPool() {
threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName());
runnableTaskListener = new AtomicReference<>();
threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName(), runnableTaskListener);
}
public void setupTestNodes(Settings settings) {
@ -225,14 +230,22 @@ public abstract class TaskManagerTestCase extends OpenSearchTestCase {
transportService.start();
clusterService = createClusterService(threadPool, discoveryNode.get());
clusterService.addStateApplier(transportService.getTaskManager());
taskResourceTrackingService = new TaskResourceTrackingService(settings, clusterService.getClusterSettings(), threadPool);
transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService);
ActionFilters actionFilters = new ActionFilters(emptySet());
transportListTasksAction = new TransportListTasksAction(clusterService, transportService, actionFilters);
transportListTasksAction = new TransportListTasksAction(
clusterService,
transportService,
actionFilters,
taskResourceTrackingService
);
transportCancelTasksAction = new TransportCancelTasksAction(clusterService, transportService, actionFilters);
transportService.acceptIncomingRequests();
}
public final ClusterService clusterService;
public final TransportService transportService;
public final TaskResourceTrackingService taskResourceTrackingService;
private final SetOnce<DiscoveryNode> discoveryNode = new SetOnce<>();
public final TransportListTasksAction transportListTasksAction;
public final TransportCancelTasksAction transportCancelTasksAction;

View File

@ -91,6 +91,7 @@ import java.util.function.BiConsumer;
import static java.util.Collections.emptyMap;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.sameInstance;
import static org.mockito.Answers.RETURNS_MOCKS;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyString;
@ -224,7 +225,7 @@ public class TransportBulkActionIngestTests extends OpenSearchTestCase {
remoteResponseHandler = ArgumentCaptor.forClass(TransportResponseHandler.class);
// setup services that will be called by action
transportService = mock(TransportService.class);
transportService = mock(TransportService.class, RETURNS_MOCKS);
clusterService = mock(ClusterService.class);
localIngest = true;
// setup nodes for local and remote

View File

@ -48,6 +48,7 @@ import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.sameInstance;
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;
public class ThreadContextTests extends OpenSearchTestCase {
@ -154,6 +155,15 @@ public class ThreadContextTests extends OpenSearchTestCase {
assertEquals(1, threadContext.getResponseHeaders().get("baz").size());
}
public void testStashContextWithPreservedTransients() {
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
threadContext.putTransient("foo", "bar");
threadContext.putTransient(TASK_ID, 1);
threadContext.stashContext();
assertNull(threadContext.getTransient("foo"));
assertEquals(1, (int) threadContext.getTransient(TASK_ID));
}
public void testStashWithOrigin() {
final String origin = randomAlphaOfLengthBetween(4, 16);
final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);

View File

@ -198,6 +198,7 @@ import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.FetchPhase;
import org.opensearch.search.query.QueryPhase;
import org.opensearch.snapshots.mockstore.MockEventuallyConsistentRepository;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.disruption.DisruptableMockTransport;
import org.opensearch.threadpool.ThreadPool;
@ -1738,6 +1739,8 @@ public class SnapshotResiliencyTests extends OpenSearchTestCase {
final IndexNameExpressionResolver indexNameExpressionResolver = new IndexNameExpressionResolver(
new ThreadContext(Settings.EMPTY)
);
transportService.getTaskManager()
.setTaskResourceTrackingService(new TaskResourceTrackingService(settings, clusterSettings, threadPool));
repositoriesService = new RepositoriesService(
settings,
clusterService,

View File

@ -40,6 +40,7 @@ import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.ConcurrentCollections;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.RunnableTaskExecutionListener;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.FakeTcpChannel;
@ -59,6 +60,7 @@ import java.util.Map;
import java.util.Set;
import java.util.concurrent.Phaser;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.everyItem;
@ -67,10 +69,12 @@ import static org.mockito.Mockito.mock;
public class TaskManagerTests extends OpenSearchTestCase {
private ThreadPool threadPool;
private AtomicReference<RunnableTaskExecutionListener> runnableTaskListener;
@Before
public void setupThreadPool() {
threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName());
runnableTaskListener = new AtomicReference<>();
threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName(), runnableTaskListener);
}
@After

View File

@ -0,0 +1,97 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.tasks;
import org.junit.After;
import org.junit.Before;
import org.opensearch.action.admin.cluster.node.tasks.TransportTasksActionTests;
import org.opensearch.action.search.SearchTask;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import java.util.HashMap;
import java.util.concurrent.atomic.AtomicReference;
import static org.opensearch.tasks.ResourceStats.MEMORY;
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;
public class TaskResourceTrackingServiceTests extends OpenSearchTestCase {
private ThreadPool threadPool;
private TaskResourceTrackingService taskResourceTrackingService;
@Before
public void setup() {
threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName(), new AtomicReference<>());
taskResourceTrackingService = new TaskResourceTrackingService(
Settings.EMPTY,
new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS),
threadPool
);
}
@After
public void terminateThreadPool() {
terminate(threadPool);
}
public void testThreadContextUpdateOnTrackingStart() {
taskResourceTrackingService.setTaskResourceTrackingEnabled(true);
Task task = new SearchTask(1, "test", "test", () -> "Test", TaskId.EMPTY_TASK_ID, new HashMap<>());
String key = "KEY";
String value = "VALUE";
// Prepare thread context
threadPool.getThreadContext().putHeader(key, value);
threadPool.getThreadContext().putTransient(key, value);
threadPool.getThreadContext().addResponseHeader(key, value);
ThreadContext.StoredContext storedContext = taskResourceTrackingService.startTracking(task);
// All headers should be preserved and Task Id should also be included in thread context
verifyThreadContextFixedHeaders(key, value);
assertEquals((long) threadPool.getThreadContext().getTransient(TASK_ID), task.getId());
storedContext.restore();
// Post restore only task id should be removed from the thread context
verifyThreadContextFixedHeaders(key, value);
assertNull(threadPool.getThreadContext().getTransient(TASK_ID));
}
public void testStopTrackingHandlesCurrentActiveThread() {
taskResourceTrackingService.setTaskResourceTrackingEnabled(true);
Task task = new SearchTask(1, "test", "test", () -> "Test", TaskId.EMPTY_TASK_ID, new HashMap<>());
ThreadContext.StoredContext storedContext = taskResourceTrackingService.startTracking(task);
long threadId = Thread.currentThread().getId();
taskResourceTrackingService.taskExecutionStartedOnThread(task.getId(), threadId);
assertTrue(task.getResourceStats().get(threadId).get(0).isActive());
assertEquals(0, task.getResourceStats().get(threadId).get(0).getResourceUsageInfo().getStatsInfo().get(MEMORY).getTotalValue());
taskResourceTrackingService.stopTracking(task);
// Makes sure stop tracking marks the current active thread inactive and refreshes the resource stats before returning.
assertFalse(task.getResourceStats().get(threadId).get(0).isActive());
assertTrue(task.getResourceStats().get(threadId).get(0).getResourceUsageInfo().getStatsInfo().get(MEMORY).getTotalValue() > 0);
}
private void verifyThreadContextFixedHeaders(String key, String value) {
assertEquals(threadPool.getThreadContext().getHeader(key), value);
assertEquals(threadPool.getThreadContext().getTransient(key), value);
assertEquals(threadPool.getThreadContext().getResponseHeaders().get(key).get(0), value);
}
}

View File

@ -39,6 +39,7 @@ import org.apache.logging.log4j.util.Supplier;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Setting.Property;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskAwareRequest;
import org.opensearch.tasks.TaskManager;
@ -127,6 +128,21 @@ public class MockTaskManager extends TaskManager {
super.waitForTaskCompletion(task, untilInNanos);
}
@Override
public ThreadContext.StoredContext taskExecutionStarted(Task task) {
for (MockTaskManagerListener listener : listeners) {
listener.taskExecutionStarted(task, false);
}
ThreadContext.StoredContext storedContext = super.taskExecutionStarted(task);
return () -> {
for (MockTaskManagerListener listener : listeners) {
listener.taskExecutionStarted(task, true);
}
storedContext.restore();
};
}
public void addListener(MockTaskManagerListener listener) {
listeners.add(listener);
}

View File

@ -43,4 +43,7 @@ public interface MockTaskManagerListener {
void onTaskUnregistered(Task task);
void waitForTaskCompletion(Task task);
void taskExecutionStarted(Task task, Boolean closeableInvoked);
}

View File

@ -40,6 +40,7 @@ import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicReference;
public class TestThreadPool extends ThreadPool {
@ -47,12 +48,29 @@ public class TestThreadPool extends ThreadPool {
private volatile boolean returnRejectingExecutor = false;
private volatile ThreadPoolExecutor rejectingExecutor;
public TestThreadPool(
String name,
AtomicReference<RunnableTaskExecutionListener> runnableTaskListener,
ExecutorBuilder<?>... customBuilders
) {
this(name, Settings.EMPTY, runnableTaskListener, customBuilders);
}
public TestThreadPool(String name, ExecutorBuilder<?>... customBuilders) {
this(name, Settings.EMPTY, customBuilders);
}
public TestThreadPool(String name, Settings settings, ExecutorBuilder<?>... customBuilders) {
super(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), name).put(settings).build(), customBuilders);
this(name, settings, null, customBuilders);
}
public TestThreadPool(
String name,
Settings settings,
AtomicReference<RunnableTaskExecutionListener> runnableTaskListener,
ExecutorBuilder<?>... customBuilders
) {
super(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), name).put(settings).build(), runnableTaskListener, customBuilders);
}
@Override