Fix race condition in KubernetesTaskRunner when task is added to the map (#14643)

Changes:
- Fix race condition in KubernetesTaskRunner introduced by #14435 
- Perform addition and removal from map inside a synchronized block
- Update tests
This commit is contained in:
YongGang 2023-07-27 00:04:36 -07:00 committed by GitHub
parent 7634ac896e
commit 9b88b78ba4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 38 deletions

View File

@ -134,16 +134,18 @@ public class KubernetesTaskRunner implements TaskLogStreamer, TaskRunner
@Override
public ListenableFuture<TaskStatus> run(Task task)
{
return tasks.computeIfAbsent(
task.getId(), k -> new KubernetesWorkItem(task, exec.submit(() -> runTask(task)))
).getResult();
synchronized (tasks) {
return tasks.computeIfAbsent(task.getId(), k -> new KubernetesWorkItem(task, exec.submit(() -> runTask(task))))
.getResult();
}
}
protected ListenableFuture<TaskStatus> joinAsync(Task task)
{
return tasks.computeIfAbsent(
task.getId(), k -> new KubernetesWorkItem(task, exec.submit(() -> joinTask(task)))
).getResult();
synchronized (tasks) {
return tasks.computeIfAbsent(task.getId(), k -> new KubernetesWorkItem(task, exec.submit(() -> joinTask(task))))
.getResult();
}
}
private TaskStatus runTask(Task task)
@ -159,8 +161,10 @@ public class KubernetesTaskRunner implements TaskLogStreamer, TaskRunner
@VisibleForTesting
protected TaskStatus doTask(Task task, boolean run)
{
try {
KubernetesPeonLifecycle peonLifecycle = peonLifecycleFactory.build(task);
synchronized (tasks) {
KubernetesWorkItem workItem = tasks.get(task.getId());
if (workItem == null) {
@ -172,8 +176,8 @@ public class KubernetesTaskRunner implements TaskLogStreamer, TaskRunner
}
workItem.setKubernetesPeonLifecycle(peonLifecycle);
}
try {
TaskStatus taskStatus;
if (run) {
taskStatus = peonLifecycle.run(
@ -191,16 +195,16 @@ public class KubernetesTaskRunner implements TaskLogStreamer, TaskRunner
return taskStatus;
}
catch (Exception e) {
log.error(e, "Task [%s] execution caught an exception", task.getId());
throw new RuntimeException(e);
}
finally {
synchronized (tasks) {
tasks.remove(task.getId());
}
}
}
@Override
public void updateStatus(Task task, TaskStatus status)
@ -269,17 +273,17 @@ public class KubernetesTaskRunner implements TaskLogStreamer, TaskRunner
@Override
public List<Pair<Task, ListenableFuture<TaskStatus>>> restore()
{
List<Pair<Task, ListenableFuture<TaskStatus>>> tasks = new ArrayList<>();
List<Pair<Task, ListenableFuture<TaskStatus>>> restoredTasks = new ArrayList<>();
for (Job job : client.getPeonJobs()) {
try {
Task task = adapter.toTask(job);
tasks.add(Pair.of(task, joinAsync(task)));
restoredTasks.add(Pair.of(task, joinAsync(task)));
}
catch (IOException e) {
log.error(e, "Error deserializing task from job [%s]", job.getMetadata().getName());
}
}
return tasks;
return restoredTasks;
}
@Override
@ -319,7 +323,6 @@ public class KubernetesTaskRunner implements TaskLogStreamer, TaskRunner
return Lists.newArrayList(tasks.values());
}
@Override
public Optional<ScalingStats> getScalingStats()
{

View File

@ -32,7 +32,6 @@ import org.apache.druid.indexer.TaskStatus;
import org.apache.druid.indexing.common.task.NoopTask;
import org.apache.druid.indexing.common.task.Task;
import org.apache.druid.indexing.overlord.TaskRunnerWorkItem;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.http.client.HttpClient;
import org.apache.druid.java.util.http.client.Request;
@ -237,17 +236,17 @@ public class KubernetesTaskRunnerTest extends EasyMockSupport
}
@Test
public void test_doTask_withoutWorkItem_throwsISE()
public void test_doTask_withoutWorkItem_throwsRuntimeException()
{
Assert.assertThrows(
"Task [id] disappeared",
ISE.class,
RuntimeException.class,
() -> runner.doTask(task, true)
);
}
@Test
public void test_doTask_whenShutdownRequested_throwsISE()
public void test_doTask_whenShutdownRequested_throwsRuntimeException()
{
KubernetesWorkItem workItem = new KubernetesWorkItem(task, null);
workItem.shutdown();
@ -256,7 +255,7 @@ public class KubernetesTaskRunnerTest extends EasyMockSupport
Assert.assertThrows(
"Task [id] has been shut down",
ISE.class,
RuntimeException.class,
() -> runner.doTask(task, true)
);
}