Fix race condition in KubernetesTaskRunner when task is added to the map (#14643)

Changes:
- Fix race condition in KubernetesTaskRunner introduced by #14435 
- Perform addition and removal from map inside a synchronized block
- Update tests
This commit is contained in:
YongGang 2023-07-27 00:04:36 -07:00 committed by GitHub
parent 7634ac896e
commit 9b88b78ba4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 38 deletions

View File

@ -134,16 +134,18 @@ public class KubernetesTaskRunner implements TaskLogStreamer, TaskRunner
@Override @Override
public ListenableFuture<TaskStatus> run(Task task) public ListenableFuture<TaskStatus> run(Task task)
{ {
return tasks.computeIfAbsent( synchronized (tasks) {
task.getId(), k -> new KubernetesWorkItem(task, exec.submit(() -> runTask(task))) return tasks.computeIfAbsent(task.getId(), k -> new KubernetesWorkItem(task, exec.submit(() -> runTask(task))))
).getResult(); .getResult();
}
} }
protected ListenableFuture<TaskStatus> joinAsync(Task task) protected ListenableFuture<TaskStatus> joinAsync(Task task)
{ {
return tasks.computeIfAbsent( synchronized (tasks) {
task.getId(), k -> new KubernetesWorkItem(task, exec.submit(() -> joinTask(task))) return tasks.computeIfAbsent(task.getId(), k -> new KubernetesWorkItem(task, exec.submit(() -> joinTask(task))))
).getResult(); .getResult();
}
} }
private TaskStatus runTask(Task task) private TaskStatus runTask(Task task)
@ -159,8 +161,10 @@ public class KubernetesTaskRunner implements TaskLogStreamer, TaskRunner
@VisibleForTesting @VisibleForTesting
protected TaskStatus doTask(Task task, boolean run) protected TaskStatus doTask(Task task, boolean run)
{ {
try {
KubernetesPeonLifecycle peonLifecycle = peonLifecycleFactory.build(task); KubernetesPeonLifecycle peonLifecycle = peonLifecycleFactory.build(task);
synchronized (tasks) {
KubernetesWorkItem workItem = tasks.get(task.getId()); KubernetesWorkItem workItem = tasks.get(task.getId());
if (workItem == null) { if (workItem == null) {
@ -172,8 +176,8 @@ public class KubernetesTaskRunner implements TaskLogStreamer, TaskRunner
} }
workItem.setKubernetesPeonLifecycle(peonLifecycle); workItem.setKubernetesPeonLifecycle(peonLifecycle);
}
try {
TaskStatus taskStatus; TaskStatus taskStatus;
if (run) { if (run) {
taskStatus = peonLifecycle.run( taskStatus = peonLifecycle.run(
@ -191,16 +195,16 @@ public class KubernetesTaskRunner implements TaskLogStreamer, TaskRunner
return taskStatus; return taskStatus;
} }
catch (Exception e) { catch (Exception e) {
log.error(e, "Task [%s] execution caught an exception", task.getId()); log.error(e, "Task [%s] execution caught an exception", task.getId());
throw new RuntimeException(e); throw new RuntimeException(e);
} }
finally { finally {
synchronized (tasks) {
tasks.remove(task.getId()); tasks.remove(task.getId());
} }
} }
}
@Override @Override
public void updateStatus(Task task, TaskStatus status) public void updateStatus(Task task, TaskStatus status)
@ -269,17 +273,17 @@ public class KubernetesTaskRunner implements TaskLogStreamer, TaskRunner
@Override @Override
public List<Pair<Task, ListenableFuture<TaskStatus>>> restore() public List<Pair<Task, ListenableFuture<TaskStatus>>> restore()
{ {
List<Pair<Task, ListenableFuture<TaskStatus>>> tasks = new ArrayList<>(); List<Pair<Task, ListenableFuture<TaskStatus>>> restoredTasks = new ArrayList<>();
for (Job job : client.getPeonJobs()) { for (Job job : client.getPeonJobs()) {
try { try {
Task task = adapter.toTask(job); Task task = adapter.toTask(job);
tasks.add(Pair.of(task, joinAsync(task))); restoredTasks.add(Pair.of(task, joinAsync(task)));
} }
catch (IOException e) { catch (IOException e) {
log.error(e, "Error deserializing task from job [%s]", job.getMetadata().getName()); log.error(e, "Error deserializing task from job [%s]", job.getMetadata().getName());
} }
} }
return tasks; return restoredTasks;
} }
@Override @Override
@ -319,7 +323,6 @@ public class KubernetesTaskRunner implements TaskLogStreamer, TaskRunner
return Lists.newArrayList(tasks.values()); return Lists.newArrayList(tasks.values());
} }
@Override @Override
public Optional<ScalingStats> getScalingStats() public Optional<ScalingStats> getScalingStats()
{ {

View File

@ -32,7 +32,6 @@ import org.apache.druid.indexer.TaskStatus;
import org.apache.druid.indexing.common.task.NoopTask; import org.apache.druid.indexing.common.task.NoopTask;
import org.apache.druid.indexing.common.task.Task; import org.apache.druid.indexing.common.task.Task;
import org.apache.druid.indexing.overlord.TaskRunnerWorkItem; import org.apache.druid.indexing.overlord.TaskRunnerWorkItem;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.http.client.HttpClient; import org.apache.druid.java.util.http.client.HttpClient;
import org.apache.druid.java.util.http.client.Request; import org.apache.druid.java.util.http.client.Request;
@ -237,17 +236,17 @@ public class KubernetesTaskRunnerTest extends EasyMockSupport
} }
@Test @Test
public void test_doTask_withoutWorkItem_throwsISE() public void test_doTask_withoutWorkItem_throwsRuntimeException()
{ {
Assert.assertThrows( Assert.assertThrows(
"Task [id] disappeared", "Task [id] disappeared",
ISE.class, RuntimeException.class,
() -> runner.doTask(task, true) () -> runner.doTask(task, true)
); );
} }
@Test @Test
public void test_doTask_whenShutdownRequested_throwsISE() public void test_doTask_whenShutdownRequested_throwsRuntimeException()
{ {
KubernetesWorkItem workItem = new KubernetesWorkItem(task, null); KubernetesWorkItem workItem = new KubernetesWorkItem(task, null);
workItem.shutdown(); workItem.shutdown();
@ -256,7 +255,7 @@ public class KubernetesTaskRunnerTest extends EasyMockSupport
Assert.assertThrows( Assert.assertThrows(
"Task [id] has been shut down", "Task [id] has been shut down",
ISE.class, RuntimeException.class,
() -> runner.doTask(task, true) () -> runner.doTask(task, true)
); );
} }