Kill all running tasks when the supervisor task is killed (#7041)

* Kill all running tasks when the supervisor task is killed

* add some docs and simplify

* address comment
Jihoon Son, 2019-03-01 11:28:03 -08:00 (committed by GitHub)
parent 45f12de9ad
commit 06c8229c08
5 changed files with 518 additions and 36 deletions
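
For orientation before the per-file diffs, here is a minimal, self-contained sketch of the shutdown path this commit wires up: the supervisor task's new stopGracefully(TaskConfig) forwards to the runner's stopGracefully(), which flips a volatile flag and stops the task monitor, and the monitor in turn asks the indexing service to kill every sub task it is still tracking. The class names below (GracefulShutdownSketch, KillClient, SubTaskMonitor, Runner) are hypothetical stand-ins for illustration only; the method and field names mirror the diff.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

public class GracefulShutdownSketch
{
  // Stand-in for IndexingServiceClient#killTask.
  interface KillClient
  {
    void killTask(String taskId);
  }

  // Stand-in for TaskMonitor: tracks running sub tasks and kills them all on stop().
  static class SubTaskMonitor
  {
    private final Object startStopLock = new Object();
    private final ConcurrentHashMap<String, String> runningTasks = new ConcurrentHashMap<>();
    private final KillClient killClient;
    private volatile boolean running = true;

    SubTaskMonitor(KillClient killClient)
    {
      this.killClient = killClient;
    }

    void submit(String taskId)
    {
      synchronized (startStopLock) {
        if (!running) {
          throw new IllegalStateException("Monitor is stopped; no new sub tasks are accepted");
        }
        runningTasks.put(taskId, taskId);
      }
    }

    // Mirrors the new TaskMonitor#stop(): stop monitoring AND kill all running sub tasks.
    void stop()
    {
      synchronized (startStopLock) {
        running = false;
        runningTasks.keySet().forEach(killClient::killTask);
        runningTasks.clear();
      }
    }
  }

  // Stand-in for SinglePhaseParallelIndexTaskRunner.
  static class Runner
  {
    private final SubTaskMonitor taskMonitor;
    // run() (not sketched here) checks this flag and stops scheduling new sub tasks once it is set.
    private volatile boolean subTaskScheduleAndMonitorStopped;

    Runner(SubTaskMonitor taskMonitor)
    {
      this.taskMonitor = taskMonitor;
    }

    // Mirrors ParallelIndexTaskRunner#stopGracefully(), called when the supervisor task is killed.
    void stopGracefully()
    {
      subTaskScheduleAndMonitorStopped = true;
      taskMonitor.stop();
    }
  }

  public static void main(String[] args)
  {
    final List<String> killedTaskIds = new ArrayList<>();
    final SubTaskMonitor monitor = new SubTaskMonitor(killedTaskIds::add);
    final Runner runner = new Runner(monitor);
    monitor.submit("sub-task-0");
    monitor.submit("sub-task-1");
    runner.stopGracefully(); // what ParallelIndexSupervisorTask#stopGracefully(taskConfig) delegates to
    System.out.println("Killed sub tasks: " + killedTaskIds); // both tracked sub tasks are killed
  }
}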

ParallelIndexSupervisorTask.java

@@ -38,6 +38,7 @@ import org.apache.druid.indexing.common.TaskToolbox;
import org.apache.druid.indexing.common.actions.LockListAction;
import org.apache.druid.indexing.common.actions.LockTryAcquireAction;
import org.apache.druid.indexing.common.actions.TaskActionClient;
import org.apache.druid.indexing.common.config.TaskConfig;
import org.apache.druid.indexing.common.stats.RowIngestionMetersFactory;
import org.apache.druid.indexing.common.task.AbstractTask;
import org.apache.druid.indexing.common.task.IndexTask;
@@ -235,6 +236,14 @@ public class ParallelIndexSupervisorTask extends AbstractTask implements ChatHan
return true;
}
@Override
public void stopGracefully(TaskConfig taskConfig)
{
if (runner != null) {
runner.stopGracefully();
}
}
@Override
public TaskStatus run(TaskToolbox toolbox) throws Exception
{

ParallelIndexTaskRunner.java

@@ -41,6 +41,12 @@ public interface ParallelIndexTaskRunner<T extends Task>
*/
TaskState run() throws Exception;
/**
* Stop this runner gracefully. This method is called when the task is killed.
* See {@link org.apache.druid.indexing.overlord.SingleTaskBackgroundRunner#stop}.
*/
void stopGracefully();
/**
* {@link PushedSegmentsReport} is the report sent by {@link ParallelIndexSubTask}s. The subTasks call this method to
* send their reports after pushing generated segments to deep storage.

SinglePhaseParallelIndexTaskRunner.java

@@ -84,7 +84,7 @@ public class SinglePhaseParallelIndexTaskRunner implements ParallelIndexTaskRunn
/** subTaskId -> report */
private final ConcurrentHashMap<String, PushedSegmentsReport> segmentsMap = new ConcurrentHashMap<>();
private volatile boolean stopped;
private volatile boolean subTaskScheduleAndMonitorStopped;
private volatile TaskMonitor<ParallelIndexSubTask> taskMonitor;
private int nextSpecId = 0;
@@ -111,6 +111,11 @@ public class SinglePhaseParallelIndexTaskRunner implements ParallelIndexTaskRunn
@Override
public TaskState run() throws Exception
{
if (baseFirehoseFactory.getNumSplits() == 0) {
log.warn("There's no input split to process");
return TaskState.SUCCESS;
}
final Iterator<ParallelIndexSubTaskSpec> subTaskSpecIterator = subTaskSpecIterator().iterator();
final long taskStatusCheckingPeriod = ingestionSchema.getTuningConfig().getTaskStatusCheckPeriodMs();
@@ -153,7 +158,7 @@ public class SinglePhaseParallelIndexTaskRunner implements ParallelIndexTaskRunn
if (!subTaskSpecIterator.hasNext()) {
// We have no more subTasks to run
if (taskMonitor.getNumRunningTasks() == 0 && taskCompleteEvents.size() == 0) {
stopped = true;
subTaskScheduleAndMonitorStopped = true;
if (taskMonitor.isSucceeded()) {
// Publishing all segments reported so far
publish(toolbox);
@@ -182,7 +187,7 @@ public class SinglePhaseParallelIndexTaskRunner implements ParallelIndexTaskRunn
case FAILED:
// TaskMonitor already tried everything it can do for failed tasks. We failed.
state = TaskState.FAILED;
stopped = true;
subTaskScheduleAndMonitorStopped = true;
final TaskStatusPlus lastStatus = taskCompleteEvent.getLastStatus();
if (lastStatus != null) {
log.error("Failed because of the failed sub task[%s]", lastStatus.getId());
@@ -202,30 +207,39 @@ public class SinglePhaseParallelIndexTaskRunner implements ParallelIndexTaskRunn
}
}
finally {
log.info("Cleaning up resources");
// Cleanup resources
taskCompleteEvents.clear();
taskMonitor.stop();
if (state != TaskState.SUCCESS) {
log.info(
"This task is finished with [%s] state. Killing [%d] remaining subtasks.",
state,
taskMonitor.getNumRunningTasks()
);
// if this fails, kill all sub tasks
// Note: this doesn't work when this task is killed by users. We need a way for gracefully shutting down tasks
// for resource cleanup.
taskMonitor.killAll();
stopInternal();
if (!state.isComplete()) {
state = TaskState.FAILED;
}
}
return state;
}
@Override
public void stopGracefully()
{
subTaskScheduleAndMonitorStopped = true;
stopInternal();
}
/**
* Stop task scheduling and monitoring, and kill all running tasks.
* This method is thread-safe.
*/
private void stopInternal()
{
log.info("Cleaning up resources");
taskCompleteEvents.clear();
if (taskMonitor != null) {
taskMonitor.stop();
}
}
private boolean isRunning()
{
return !stopped && !Thread.currentThread().isInterrupted();
return !subTaskScheduleAndMonitorStopped && !Thread.currentThread().isInterrupted();
}
@VisibleForTesting
@@ -240,6 +254,13 @@ public class SinglePhaseParallelIndexTaskRunner implements ParallelIndexTaskRunn
return ingestionSchema;
}
@VisibleForTesting
@Nullable
TaskMonitor<ParallelIndexSubTask> getTaskMonitor()
{
return taskMonitor;
}
@Override
public void collectReport(PushedSegmentsReport report)
{

TaskMonitor.java

@@ -19,6 +19,7 @@
package org.apache.druid.indexing.common.task.batch.parallel;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
@@ -84,6 +85,11 @@ public class TaskMonitor<T extends Task>
private int numRunningTasks;
private int numSucceededTasks;
private int numFailedTasks;
// This metric is used only for unit tests because the current taskStatus system doesn't track the killed task status.
// Currently, this metric only represents # of killed tasks by ParallelIndexTaskRunner.
// See killAllRunningTasks(), SinglePhaseParallelIndexTaskRunner.run(), and
// SinglePhaseParallelIndexTaskRunner.stopGracefully()
private int numKilledTasks;
private boolean running = false;
@@ -169,11 +175,35 @@
}
}
/**
* Stop task monitoring and kill all running tasks.
*/
public void stop()
{
synchronized (startStopLock) {
running = false;
taskStatusChecker.shutdownNow();
if (numRunningTasks > 0) {
final Iterator<MonitorEntry> iterator = runningTasks.values().iterator();
while (iterator.hasNext()) {
final MonitorEntry entry = iterator.next();
iterator.remove();
final String taskId = entry.runningTask.getId();
log.info("Request to kill subtask[%s]", taskId);
indexingServiceClient.killTask(taskId);
numRunningTasks--;
numKilledTasks++;
}
if (numRunningTasks > 0) {
log.warn(
"Inconsistent state: numRunningTasks[%d] is still not zero after trying to kill all running tasks.",
numRunningTasks
);
}
}
log.info("Stopped taskMonitor");
}
}
@@ -225,27 +255,14 @@
}
}
/**
* This method should be called after {@link #stop()} to make sure no additional tasks are submitted.
*/
void killAll()
{
runningTasks.values().forEach(entry -> {
final String taskId = entry.runningTask.getId();
log.info("Request to kill subtask[%s]", taskId);
indexingServiceClient.killTask(taskId);
});
runningTasks.clear();
}
void incrementNumRunningTasks()
private void incrementNumRunningTasks()
{
synchronized (taskCountLock) {
numRunningTasks++;
}
}
void incrementNumSucceededTasks()
private void incrementNumSucceededTasks()
{
synchronized (taskCountLock) {
numRunningTasks--;
@@ -254,7 +271,7 @@ public class TaskMonitor<T extends Task>
}
}
void incrementNumFailedTasks()
private void incrementNumFailedTasks()
{
synchronized (taskCountLock) {
numRunningTasks--;
@@ -276,6 +293,12 @@
}
}
@VisibleForTesting
int getNumKilledTasks()
{
return numKilledTasks;
}
SinglePhaseParallelIndexingProgress getProgress()
{
synchronized (taskCountLock) {
@@ -336,7 +359,7 @@
@Nullable
private volatile TaskStatusPlus runningStatus;
MonitorEntry(
private MonitorEntry(
SubTaskSpec<T> spec,
T runningTask,
@Nullable TaskStatusPlus runningStatus,

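One design note worth calling out: before this change, the runner's finally block had to pair taskMonitor.stop() with a separate taskMonitor.killAll() on failure (the removed killAll() javadoc above required exactly that ordering so no new sub tasks could be submitted after the kill requests went out), whereas now stop() does both. A simplified caller-side contrast, with the surrounding code elided:

// Before this change (old finally block in SinglePhaseParallelIndexTaskRunner#run, simplified):
//   taskCompleteEvents.clear();
//   taskMonitor.stop();          // stop monitoring; further submissions are rejected
//   if (state != TaskState.SUCCESS) {
//     taskMonitor.killAll();     // then explicitly kill any sub tasks still running
//   }

// After this change (new stopInternal(), reached from both run() and stopGracefully()):
//   taskCompleteEvents.clear();
//   if (taskMonitor != null) {
//     taskMonitor.stop();        // stops monitoring and kills all running sub tasks in one step
//   }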
ParallelIndexSupervisorTaskKillTest.java (new file)

@@ -0,0 +1,423 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.indexing.common.task.batch.parallel;
import com.google.common.collect.Iterables;
import org.apache.druid.client.indexing.IndexingServiceClient;
import org.apache.druid.data.input.FiniteFirehoseFactory;
import org.apache.druid.data.input.InputSplit;
import org.apache.druid.data.input.impl.StringInputRowParser;
import org.apache.druid.indexer.TaskState;
import org.apache.druid.indexer.TaskStatus;
import org.apache.druid.indexer.TaskStatusPlus;
import org.apache.druid.indexing.common.TaskToolbox;
import org.apache.druid.indexing.common.actions.TaskActionClient;
import org.apache.druid.indexing.common.task.IndexTaskClientFactory;
import org.apache.druid.indexing.common.task.TaskResource;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.segment.indexing.DataSchema;
import org.apache.druid.segment.indexing.granularity.UniformGranularitySpec;
import org.joda.time.Interval;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import javax.annotation.Nullable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.Stream;
public class ParallelIndexSupervisorTaskKillTest extends AbstractParallelIndexSupervisorTaskTest
{
private ExecutorService service;
@Before
public void setup() throws IOException
{
indexingServiceClient = new LocalIndexingServiceClient();
localDeepStorage = temporaryFolder.newFolder("localStorage");
service = Execs.singleThreaded("ParallelIndexSupervisorTaskKillTest-%d");
}
@After
public void teardown()
{
indexingServiceClient.shutdown();
temporaryFolder.delete();
service.shutdownNow();
}
@Test(timeout = 5000L)
public void testStopGracefully() throws Exception
{
final ParallelIndexSupervisorTask task = newTask(
Intervals.of("2017/2018"),
new ParallelIndexIOConfig(
// Sub tasks would run forever
new TestFirehoseFactory(Pair.of(new TestInput(Integer.MAX_VALUE, TaskState.SUCCESS), 4)),
false
)
);
actionClient = createActionClient(task);
toolbox = createTaskToolbox(task);
prepareTaskForLocking(task);
Assert.assertTrue(task.isReady(actionClient));
final Future<TaskState> future = service.submit(() -> task.run(toolbox).getStatusCode());
while (task.getRunner() == null) {
Thread.sleep(100);
}
task.stopGracefully(null);
Assert.assertEquals(TaskState.FAILED, future.get());
final TestParallelIndexTaskRunner runner = (TestParallelIndexTaskRunner) task.getRunner();
Assert.assertTrue(runner.getRunningTaskIds().isEmpty());
// completeSubTaskSpecs should be empty because no task has reported its status to TaskMonitor
Assert.assertTrue(runner.getCompleteSubTaskSpecs().isEmpty());
Assert.assertEquals(4, runner.getTaskMonitor().getNumKilledTasks());
}
@Test(timeout = 5000L)
public void testSubTaskFail() throws Exception
{
final ParallelIndexSupervisorTask task = newTask(
Intervals.of("2017/2018"),
new ParallelIndexIOConfig(
new TestFirehoseFactory(
Pair.of(new TestInput(10L, TaskState.FAILED), 1),
Pair.of(new TestInput(Integer.MAX_VALUE, TaskState.FAILED), 3)
),
false
)
);
actionClient = createActionClient(task);
toolbox = createTaskToolbox(task);
prepareTaskForLocking(task);
Assert.assertTrue(task.isReady(actionClient));
final TaskState state = task.run(toolbox).getStatusCode();
Assert.assertEquals(TaskState.FAILED, state);
final TestParallelIndexTaskRunner runner = (TestParallelIndexTaskRunner) task.getRunner();
Assert.assertTrue(runner.getRunningTaskIds().isEmpty());
final List<SubTaskSpec<ParallelIndexSubTask>> completeSubTaskSpecs = runner.getCompleteSubTaskSpecs();
Assert.assertEquals(1, completeSubTaskSpecs.size());
final TaskHistory<ParallelIndexSubTask> history = runner.getCompleteSubTaskSpecAttemptHistory(
completeSubTaskSpecs.get(0).getId()
);
Assert.assertNotNull(history);
Assert.assertEquals(3, history.getAttemptHistory().size());
for (TaskStatusPlus status : history.getAttemptHistory()) {
Assert.assertEquals(TaskState.FAILED, status.getStatusCode());
}
Assert.assertEquals(3, runner.getTaskMonitor().getNumKilledTasks());
}
private ParallelIndexSupervisorTask newTask(
Interval interval,
ParallelIndexIOConfig ioConfig
)
{
final TestFirehoseFactory firehoseFactory = (TestFirehoseFactory) ioConfig.getFirehoseFactory();
final int numTotalSubTasks = firehoseFactory.getNumSplits();
// set up ingestion spec
final ParallelIndexIngestionSpec ingestionSpec = new ParallelIndexIngestionSpec(
new DataSchema(
"dataSource",
getObjectMapper().convertValue(
new StringInputRowParser(
DEFAULT_PARSE_SPEC,
null
),
Map.class
),
new AggregatorFactory[]{
new LongSumAggregatorFactory("val", "val")
},
new UniformGranularitySpec(
Granularities.DAY,
Granularities.MINUTE,
interval == null ? null : Collections.singletonList(interval)
),
null,
getObjectMapper()
),
ioConfig,
new ParallelIndexTuningConfig(
null,
null,
null,
null,
null,
null,
null,
null,
null,
null,
null,
null,
null,
numTotalSubTasks,
null,
null,
null,
null,
null,
null,
null
)
);
// set up test tools
return new TestSupervisorTask(
ingestionSpec,
Collections.emptyMap(),
indexingServiceClient
);
}
private static class TestInput
{
private final long runTime;
private final TaskState finalState;
private TestInput(long runTime, TaskState finalState)
{
this.runTime = runTime;
this.finalState = finalState;
}
}
private static class TestFirehoseFactory implements FiniteFirehoseFactory<StringInputRowParser, TestInput>
{
private final List<InputSplit<TestInput>> splits;
@SafeVarargs
private TestFirehoseFactory(Pair<TestInput, Integer>... inputSpecs)
{
splits = new ArrayList<>();
for (Pair<TestInput, Integer> inputSpec : inputSpecs) {
final int numInputs = inputSpec.rhs;
for (int i = 0; i < numInputs; i++) {
splits.add(new InputSplit<>(new TestInput(inputSpec.lhs.runTime, inputSpec.lhs.finalState)));
}
}
}
private TestFirehoseFactory(InputSplit<TestInput> split)
{
this.splits = Collections.singletonList(split);
}
@Override
public Stream<InputSplit<TestInput>> getSplits()
{
return splits.stream();
}
@Override
public int getNumSplits()
{
return splits.size();
}
@Override
public FiniteFirehoseFactory<StringInputRowParser, TestInput> withSplit(InputSplit<TestInput> split)
{
return new TestFirehoseFactory(split);
}
}
private static class TestSupervisorTask extends TestParallelIndexSupervisorTask
{
private final IndexingServiceClient indexingServiceClient;
private TestSupervisorTask(
ParallelIndexIngestionSpec ingestionSchema,
Map<String, Object> context,
IndexingServiceClient indexingServiceClient
)
{
super(
null,
null,
ingestionSchema,
context,
indexingServiceClient
);
this.indexingServiceClient = indexingServiceClient;
}
@Override
public TaskStatus run(TaskToolbox toolbox) throws Exception
{
setToolbox(toolbox);
setRunner(
new TestRunner(
toolbox,
this,
indexingServiceClient
)
);
return TaskStatus.fromCode(
getId(),
getRunner().run()
);
}
}
private static class TestRunner extends TestParallelIndexTaskRunner
{
private final ParallelIndexSupervisorTask supervisorTask;
private TestRunner(
TaskToolbox toolbox,
ParallelIndexSupervisorTask supervisorTask,
IndexingServiceClient indexingServiceClient
)
{
super(
toolbox,
supervisorTask.getId(),
supervisorTask.getGroupId(),
supervisorTask.getIngestionSchema(),
supervisorTask.getContext(),
indexingServiceClient
);
this.supervisorTask = supervisorTask;
}
@Override
ParallelIndexSubTaskSpec newTaskSpec(InputSplit split)
{
final FiniteFirehoseFactory baseFirehoseFactory = (FiniteFirehoseFactory) getIngestionSchema()
.getIOConfig()
.getFirehoseFactory();
return new TestParallelIndexSubTaskSpec(
supervisorTask.getId() + "_" + getAndIncrementNextSpecId(),
supervisorTask.getGroupId(),
supervisorTask,
new ParallelIndexIngestionSpec(
getIngestionSchema().getDataSchema(),
new ParallelIndexIOConfig(
baseFirehoseFactory.withSplit(split),
getIngestionSchema().getIOConfig().isAppendToExisting()
),
getIngestionSchema().getTuningConfig()
),
supervisorTask.getContext(),
split
);
}
}
private static class TestParallelIndexSubTaskSpec extends ParallelIndexSubTaskSpec
{
private final ParallelIndexSupervisorTask supervisorTask;
private TestParallelIndexSubTaskSpec(
String id,
String groupId,
ParallelIndexSupervisorTask supervisorTask,
ParallelIndexIngestionSpec ingestionSpec,
Map<String, Object> context,
InputSplit inputSplit
)
{
super(id, groupId, supervisorTask.getId(), ingestionSpec, context, inputSplit);
this.supervisorTask = supervisorTask;
}
@Override
public ParallelIndexSubTask newSubTask(int numAttempts)
{
return new TestParallelIndexSubTask(
null,
getGroupId(),
null,
getSupervisorTaskId(),
numAttempts,
getIngestionSpec(),
getContext(),
null,
new LocalParallelIndexTaskClientFactory(supervisorTask)
);
}
}
private static class TestParallelIndexSubTask extends ParallelIndexSubTask
{
private TestParallelIndexSubTask(
@Nullable String id,
String groupId,
TaskResource taskResource,
String supervisorTaskId,
int numAttempts,
ParallelIndexIngestionSpec ingestionSchema,
Map<String, Object> context,
IndexingServiceClient indexingServiceClient,
IndexTaskClientFactory<ParallelIndexTaskClient> taskClientFactory
)
{
super(
id,
groupId,
taskResource,
supervisorTaskId,
numAttempts,
ingestionSchema,
context,
indexingServiceClient,
taskClientFactory
);
}
@Override
public boolean isReady(TaskActionClient taskActionClient)
{
return true;
}
@Override
public TaskStatus run(final TaskToolbox toolbox) throws Exception
{
final TestFirehoseFactory firehoseFactory = (TestFirehoseFactory) getIngestionSchema().getIOConfig()
.getFirehoseFactory();
final TestInput testInput = Iterables.getOnlyElement(firehoseFactory.splits).get();
Thread.sleep(testInput.runTime);
return TaskStatus.fromCode(getId(), testInput.finalState);
}
}
}