Fix Concurrent Task Insertion in pendingCompletionTaskGroups (#16834)

Fix streaming task failures that may arise due to concurrent task insertion in pendingCompletionTaskGroups
This commit is contained in:
Hardik Bajaj 2024-08-08 08:37:27 +05:30 committed by GitHub
parent ceed4a0634
commit 1cf3f4bebe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 227 additions and 22 deletions

View File

@ -2482,43 +2482,67 @@ public abstract class SeekableStreamSupervisor<PartitionIdType, SequenceOffsetTy
);
}
private void addDiscoveredTaskToPendingCompletionTaskGroups(
@VisibleForTesting
protected void addDiscoveredTaskToPendingCompletionTaskGroups(
int groupId,
String taskId,
Map<PartitionIdType, SequenceOffsetType> startingPartitions
)
{
final CopyOnWriteArrayList<TaskGroup> taskGroupList = pendingCompletionTaskGroups.computeIfAbsent(
final CopyOnWriteArrayList<TaskGroup> taskGroupList = pendingCompletionTaskGroups.compute(
groupId,
k -> new CopyOnWriteArrayList<>()
(k, val) -> {
// Creating new pending completion task groups while compute so that read and writes are locked.
// To ensure synchronisatoin across threads, we need to do updates in compute so that we get only one task group for all replica tasks
if (val == null) {
val = new CopyOnWriteArrayList<>();
}
boolean isTaskGroupPresent = false;
for (TaskGroup taskGroup : val) {
if (taskGroup.startingSequences.equals(startingPartitions)) {
isTaskGroupPresent = true;
break;
}
}
if (!isTaskGroupPresent) {
log.info("Creating new pending completion task group [%s] for discovered task [%s].", groupId, taskId);
// reading the minimumMessageTime & maximumMessageTime from the publishing task and setting it here is not necessary as this task cannot
// change to a state where it will read any more events.
// This is a discovered task, so it would not have been assigned closed partitions initially.
TaskGroup newTaskGroup = new TaskGroup(
groupId,
ImmutableMap.copyOf(startingPartitions),
null,
Optional.absent(),
Optional.absent(),
null
);
newTaskGroup.tasks.put(taskId, new TaskData());
newTaskGroup.completionTimeout = DateTimes.nowUtc().plus(ioConfig.getCompletionTimeout());
val.add(newTaskGroup);
}
return val;
}
);
for (TaskGroup taskGroup : taskGroupList) {
if (taskGroup.startingSequences.equals(startingPartitions)) {
if (taskGroup.tasks.putIfAbsent(taskId, new TaskData()) == null) {
log.info("Added discovered task [%s] to existing pending task group [%s]", taskId, groupId);
log.info("Added discovered task [%s] to existing pending completion task group [%s]. PendingCompletionTaskGroup: %s", taskId, groupId, taskGroup.taskIds());
}
return;
}
}
}
log.info("Creating new pending completion task group [%s] for discovered task [%s]", groupId, taskId);
// reading the minimumMessageTime & maximumMessageTime from the publishing task and setting it here is not necessary as this task cannot
// change to a state where it will read any more events.
// This is a discovered task, so it would not have been assigned closed partitions initially.
TaskGroup newTaskGroup = new TaskGroup(
groupId,
ImmutableMap.copyOf(startingPartitions),
null,
Optional.absent(),
Optional.absent(),
null
);
newTaskGroup.tasks.put(taskId, new TaskData());
newTaskGroup.completionTimeout = DateTimes.nowUtc().plus(ioConfig.getCompletionTimeout());
taskGroupList.add(newTaskGroup);
@VisibleForTesting
protected CopyOnWriteArrayList<TaskGroup> getPendingCompletionTaskGroups(int groupId)
{
return pendingCompletionTaskGroups.get(groupId);
}
// Sanity check to ensure that tasks have the same sequence name as their task group

View File

@ -76,6 +76,7 @@ import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.parsers.JSONPathSpec;
import org.apache.druid.java.util.metrics.DruidMonitorSchedulerConfig;
@ -114,8 +115,13 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
@ -281,6 +287,181 @@ public class SeekableStreamSupervisorStateTest extends EasyMockSupport
verifyAll();
}
@Test
public void testAddDiscoveredTaskToPendingCompletionTaskGroups() throws Exception
{
EasyMock.expect(spec.isSuspended()).andReturn(false).anyTimes();
EasyMock.expect(recordSupplier.getPartitionIds(STREAM)).andReturn(ImmutableSet.of(SHARD_ID)).anyTimes();
EasyMock.expect(taskStorage.getActiveTasksByDatasource(DATASOURCE)).andReturn(ImmutableList.of()).anyTimes();
EasyMock.expect(taskQueue.add(EasyMock.anyObject())).andReturn(true).anyTimes();
replayAll();
ExecutorService threadExecutor = Execs.multiThreaded(3, "my-thread-pool-%d");
SeekableStreamSupervisor supervisor = new TestSeekableStreamSupervisor();
Map<String, String> startingPartitions = new HashMap<>();
startingPartitions.put("partition", "offset");
// Test concurrent threads adding to same task group
Callable<Boolean> task1 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_1", startingPartitions);
return true;
};
Callable<Boolean> task2 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_2", startingPartitions);
return true;
};
Callable<Boolean> task3 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_3", startingPartitions);
return true;
};
// Create a list to hold the Callable tasks
List<Callable<Boolean>> tasks = new ArrayList<>();
tasks.add(task1);
tasks.add(task2);
tasks.add(task3);
List<Future<Boolean>> futures = threadExecutor.invokeAll(tasks);
// Wait for all tasks to complete
for (Future<Boolean> future : futures) {
try {
Boolean result = future.get();
Assert.assertTrue(result);
}
catch (ExecutionException e) {
Assert.fail();
}
}
CopyOnWriteArrayList<SeekableStreamSupervisor.TaskGroup> taskGroups = supervisor.getPendingCompletionTaskGroups(0);
Assert.assertEquals(1, taskGroups.size());
Assert.assertEquals(3, taskGroups.get(0).tasks.size());
// Test concurrent threads adding to different task groups
task1 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(1, "task_1", startingPartitions);
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(1, "task_1", startingPartitions);
return true;
};
task2 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(2, "task_1", startingPartitions);
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(2, "task_1", startingPartitions);
return true;
};
task3 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(1, "task_2", startingPartitions);
return true;
};
Callable<Boolean> task4 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(2, "task_2", startingPartitions);
return true;
};
Callable<Boolean> task5 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(1, "task_3", startingPartitions);
return true;
};
Callable<Boolean> task6 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(1, "task_1", startingPartitions);
return true;
};
tasks = new ArrayList<>();
tasks.add(task1);
tasks.add(task2);
tasks.add(task3);
tasks.add(task4);
tasks.add(task5);
tasks.add(task6);
futures = threadExecutor.invokeAll(tasks);
for (Future<Boolean> future : futures) {
try {
Boolean result = future.get();
Assert.assertTrue(result);
}
catch (ExecutionException e) {
Assert.fail();
}
}
taskGroups = supervisor.getPendingCompletionTaskGroups(1);
Assert.assertEquals(1, taskGroups.size());
Assert.assertEquals(3, taskGroups.get(0).tasks.size());
taskGroups = supervisor.getPendingCompletionTaskGroups(2);
Assert.assertEquals(1, taskGroups.size());
Assert.assertEquals(2, taskGroups.get(0).tasks.size());
}
@Test
public void testAddDiscoveredTaskToPendingCompletionMultipleTaskGroups() throws Exception
{
EasyMock.expect(spec.isSuspended()).andReturn(false).anyTimes();
EasyMock.expect(recordSupplier.getPartitionIds(STREAM)).andReturn(ImmutableSet.of(SHARD_ID)).anyTimes();
EasyMock.expect(taskStorage.getActiveTasksByDatasource(DATASOURCE)).andReturn(ImmutableList.of()).anyTimes();
EasyMock.expect(taskQueue.add(EasyMock.anyObject())).andReturn(true).anyTimes();
replayAll();
// Test adding tasks with same task group and different partition offsets.
SeekableStreamSupervisor supervisor = new TestSeekableStreamSupervisor();
ExecutorService threadExecutor = Execs.multiThreaded(3, "my-thread-pool-%d");
Map<String, String> startingPartiions = new HashMap<>();
startingPartiions.put("partition", "offset");
Map<String, String> startingPartiions1 = new HashMap<>();
startingPartiions.put("partition", "offset1");
Callable<Boolean> task1 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_1", startingPartiions);
return true;
};
Callable<Boolean> task2 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_2", startingPartiions);
return true;
};
Callable<Boolean> task3 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_3", startingPartiions);
return true;
};
Callable<Boolean> task4 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_7", startingPartiions1);
return true;
};
Callable<Boolean> task5 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_8", startingPartiions1);
return true;
};
Callable<Boolean> task6 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_9", startingPartiions1);
return true;
};
List<Callable<Boolean>> tasks = new ArrayList<>();
tasks.add(task1);
tasks.add(task2);
tasks.add(task3);
tasks.add(task4);
tasks.add(task5);
tasks.add(task6);
List<Future<Boolean>> futures = threadExecutor.invokeAll(tasks);
for (Future<Boolean> future : futures) {
try {
Boolean result = future.get();
Assert.assertTrue(result);
}
catch (ExecutionException e) {
Assert.fail();
}
}
CopyOnWriteArrayList<SeekableStreamSupervisor.TaskGroup> taskGroups = supervisor.getPendingCompletionTaskGroups(0);
Assert.assertEquals(2, taskGroups.size());
Assert.assertEquals(3, taskGroups.get(0).tasks.size());
Assert.assertEquals(3, taskGroups.get(1).tasks.size());
}
@Test
public void testConnectingToStreamFail()
{