From 1cf3f4bebe048896547460143978d1786f776038 Mon Sep 17 00:00:00 2001 From: Hardik Bajaj <58038410+hardikbajaj@users.noreply.github.com> Date: Thu, 8 Aug 2024 08:37:27 +0530 Subject: [PATCH] Fix Concurrent Task Insertion in pendingCompletionTaskGroups (#16834) Fix streaming task failures that may arise due to concurrent task insertion in pendingCompletionTaskGroups --- .../supervisor/SeekableStreamSupervisor.java | 68 ++++--- .../SeekableStreamSupervisorStateTest.java | 181 ++++++++++++++++++ 2 files changed, 227 insertions(+), 22 deletions(-) diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/seekablestream/supervisor/SeekableStreamSupervisor.java b/indexing-service/src/main/java/org/apache/druid/indexing/seekablestream/supervisor/SeekableStreamSupervisor.java index ec4de45cac7..a99c782557b 100644 --- a/indexing-service/src/main/java/org/apache/druid/indexing/seekablestream/supervisor/SeekableStreamSupervisor.java +++ b/indexing-service/src/main/java/org/apache/druid/indexing/seekablestream/supervisor/SeekableStreamSupervisor.java @@ -2482,43 +2482,67 @@ public abstract class SeekableStreamSupervisor startingPartitions ) { - final CopyOnWriteArrayList taskGroupList = pendingCompletionTaskGroups.computeIfAbsent( + final CopyOnWriteArrayList taskGroupList = pendingCompletionTaskGroups.compute( groupId, - k -> new CopyOnWriteArrayList<>() + (k, val) -> { + // Creating new pending completion task groups while compute so that read and writes are locked. + // To ensure synchronisatoin across threads, we need to do updates in compute so that we get only one task group for all replica tasks + if (val == null) { + val = new CopyOnWriteArrayList<>(); + } + + boolean isTaskGroupPresent = false; + for (TaskGroup taskGroup : val) { + if (taskGroup.startingSequences.equals(startingPartitions)) { + isTaskGroupPresent = true; + break; + } + } + if (!isTaskGroupPresent) { + log.info("Creating new pending completion task group [%s] for discovered task [%s].", groupId, taskId); + + // reading the minimumMessageTime & maximumMessageTime from the publishing task and setting it here is not necessary as this task cannot + // change to a state where it will read any more events. + // This is a discovered task, so it would not have been assigned closed partitions initially. + TaskGroup newTaskGroup = new TaskGroup( + groupId, + ImmutableMap.copyOf(startingPartitions), + null, + Optional.absent(), + Optional.absent(), + null + ); + + newTaskGroup.tasks.put(taskId, new TaskData()); + newTaskGroup.completionTimeout = DateTimes.nowUtc().plus(ioConfig.getCompletionTimeout()); + + val.add(newTaskGroup); + } + return val; + } ); + for (TaskGroup taskGroup : taskGroupList) { if (taskGroup.startingSequences.equals(startingPartitions)) { if (taskGroup.tasks.putIfAbsent(taskId, new TaskData()) == null) { - log.info("Added discovered task [%s] to existing pending task group [%s]", taskId, groupId); + log.info("Added discovered task [%s] to existing pending completion task group [%s]. PendingCompletionTaskGroup: %s", taskId, groupId, taskGroup.taskIds()); } return; } } + } - log.info("Creating new pending completion task group [%s] for discovered task [%s]", groupId, taskId); - - // reading the minimumMessageTime & maximumMessageTime from the publishing task and setting it here is not necessary as this task cannot - // change to a state where it will read any more events. - // This is a discovered task, so it would not have been assigned closed partitions initially. - TaskGroup newTaskGroup = new TaskGroup( - groupId, - ImmutableMap.copyOf(startingPartitions), - null, - Optional.absent(), - Optional.absent(), - null - ); - - newTaskGroup.tasks.put(taskId, new TaskData()); - newTaskGroup.completionTimeout = DateTimes.nowUtc().plus(ioConfig.getCompletionTimeout()); - - taskGroupList.add(newTaskGroup); + @VisibleForTesting + protected CopyOnWriteArrayList getPendingCompletionTaskGroups(int groupId) + { + return pendingCompletionTaskGroups.get(groupId); } // Sanity check to ensure that tasks have the same sequence name as their task group diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/seekablestream/supervisor/SeekableStreamSupervisorStateTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/seekablestream/supervisor/SeekableStreamSupervisorStateTest.java index 00689cee040..cb395acf66b 100644 --- a/indexing-service/src/test/java/org/apache/druid/indexing/seekablestream/supervisor/SeekableStreamSupervisorStateTest.java +++ b/indexing-service/src/test/java/org/apache/druid/indexing/seekablestream/supervisor/SeekableStreamSupervisorStateTest.java @@ -76,6 +76,7 @@ import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.parsers.JSONPathSpec; import org.apache.druid.java.util.metrics.DruidMonitorSchedulerConfig; @@ -114,8 +115,13 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeMap; +import java.util.concurrent.Callable; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; @@ -281,6 +287,181 @@ public class SeekableStreamSupervisorStateTest extends EasyMockSupport verifyAll(); } + @Test + public void testAddDiscoveredTaskToPendingCompletionTaskGroups() throws Exception + { + EasyMock.expect(spec.isSuspended()).andReturn(false).anyTimes(); + EasyMock.expect(recordSupplier.getPartitionIds(STREAM)).andReturn(ImmutableSet.of(SHARD_ID)).anyTimes(); + EasyMock.expect(taskStorage.getActiveTasksByDatasource(DATASOURCE)).andReturn(ImmutableList.of()).anyTimes(); + EasyMock.expect(taskQueue.add(EasyMock.anyObject())).andReturn(true).anyTimes(); + + replayAll(); + ExecutorService threadExecutor = Execs.multiThreaded(3, "my-thread-pool-%d"); + + SeekableStreamSupervisor supervisor = new TestSeekableStreamSupervisor(); + Map startingPartitions = new HashMap<>(); + startingPartitions.put("partition", "offset"); + + // Test concurrent threads adding to same task group + Callable task1 = () -> { + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_1", startingPartitions); + return true; + }; + Callable task2 = () -> { + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_2", startingPartitions); + return true; + }; + Callable task3 = () -> { + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_3", startingPartitions); + return true; + }; + + // Create a list to hold the Callable tasks + List> tasks = new ArrayList<>(); + tasks.add(task1); + tasks.add(task2); + tasks.add(task3); + List> futures = threadExecutor.invokeAll(tasks); + // Wait for all tasks to complete + for (Future future : futures) { + try { + Boolean result = future.get(); + Assert.assertTrue(result); + } + catch (ExecutionException e) { + Assert.fail(); + } + } + CopyOnWriteArrayList taskGroups = supervisor.getPendingCompletionTaskGroups(0); + Assert.assertEquals(1, taskGroups.size()); + Assert.assertEquals(3, taskGroups.get(0).tasks.size()); + + // Test concurrent threads adding to different task groups + task1 = () -> { + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(1, "task_1", startingPartitions); + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(1, "task_1", startingPartitions); + return true; + }; + task2 = () -> { + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(2, "task_1", startingPartitions); + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(2, "task_1", startingPartitions); + return true; + }; + task3 = () -> { + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(1, "task_2", startingPartitions); + return true; + }; + Callable task4 = () -> { + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(2, "task_2", startingPartitions); + return true; + }; + Callable task5 = () -> { + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(1, "task_3", startingPartitions); + return true; + }; + Callable task6 = () -> { + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(1, "task_1", startingPartitions); + return true; + }; + + tasks = new ArrayList<>(); + tasks.add(task1); + tasks.add(task2); + tasks.add(task3); + tasks.add(task4); + tasks.add(task5); + tasks.add(task6); + futures = threadExecutor.invokeAll(tasks); + for (Future future : futures) { + try { + Boolean result = future.get(); + Assert.assertTrue(result); + } + catch (ExecutionException e) { + Assert.fail(); + } + } + + taskGroups = supervisor.getPendingCompletionTaskGroups(1); + Assert.assertEquals(1, taskGroups.size()); + Assert.assertEquals(3, taskGroups.get(0).tasks.size()); + + taskGroups = supervisor.getPendingCompletionTaskGroups(2); + Assert.assertEquals(1, taskGroups.size()); + Assert.assertEquals(2, taskGroups.get(0).tasks.size()); + } + + @Test + public void testAddDiscoveredTaskToPendingCompletionMultipleTaskGroups() throws Exception + { + EasyMock.expect(spec.isSuspended()).andReturn(false).anyTimes(); + EasyMock.expect(recordSupplier.getPartitionIds(STREAM)).andReturn(ImmutableSet.of(SHARD_ID)).anyTimes(); + EasyMock.expect(taskStorage.getActiveTasksByDatasource(DATASOURCE)).andReturn(ImmutableList.of()).anyTimes(); + EasyMock.expect(taskQueue.add(EasyMock.anyObject())).andReturn(true).anyTimes(); + + replayAll(); + + // Test adding tasks with same task group and different partition offsets. + SeekableStreamSupervisor supervisor = new TestSeekableStreamSupervisor(); + ExecutorService threadExecutor = Execs.multiThreaded(3, "my-thread-pool-%d"); + Map startingPartiions = new HashMap<>(); + startingPartiions.put("partition", "offset"); + + Map startingPartiions1 = new HashMap<>(); + startingPartiions.put("partition", "offset1"); + + Callable task1 = () -> { + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_1", startingPartiions); + return true; + }; + Callable task2 = () -> { + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_2", startingPartiions); + return true; + }; + Callable task3 = () -> { + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_3", startingPartiions); + return true; + }; + Callable task4 = () -> { + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_7", startingPartiions1); + return true; + }; + Callable task5 = () -> { + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_8", startingPartiions1); + return true; + }; + Callable task6 = () -> { + supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_9", startingPartiions1); + return true; + }; + + List> tasks = new ArrayList<>(); + tasks.add(task1); + tasks.add(task2); + tasks.add(task3); + tasks.add(task4); + tasks.add(task5); + tasks.add(task6); + + List> futures = threadExecutor.invokeAll(tasks); + + for (Future future : futures) { + try { + Boolean result = future.get(); + Assert.assertTrue(result); + } + catch (ExecutionException e) { + Assert.fail(); + } + } + + CopyOnWriteArrayList taskGroups = supervisor.getPendingCompletionTaskGroups(0); + + Assert.assertEquals(2, taskGroups.size()); + Assert.assertEquals(3, taskGroups.get(0).tasks.size()); + Assert.assertEquals(3, taskGroups.get(1).tasks.size()); + } + @Test public void testConnectingToStreamFail() {