diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java index 2884efe1f0b..f06d9697764 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java @@ -215,7 +215,7 @@ public class WorkerMemoryParameters final long minimumBundleFreeMemory = computeMinimumBundleFreeMemory(frameSize, numFramesPerOutputChannel); if (bundleFreeMemory < minimumBundleFreeMemory) { - final long requiredTaskMemory = bundleMemory - bundleFreeMemory + minimumBundleFreeMemory; + final long requiredTaskMemory = (bundleMemory - bundleFreeMemory + minimumBundleFreeMemory) * maxConcurrentStages; throw new MSQException( new NotEnoughMemoryFault( memoryIntrospector.computeJvmMemoryRequiredForTaskMemory(requiredTaskMemory), diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java index 990610af99e..8e467f0ed69 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java @@ -281,6 +281,55 @@ public class WorkerMemoryParametersTest ); } + @Test + public void test_1WorkerInJvm_alone_40Threads_2ConcurrentStages() + { + final int numThreads = 40; + final int frameSize = WorkerMemoryParameters.DEFAULT_FRAME_SIZE; + + final MemoryIntrospectorImpl memoryIntrospector = createMemoryIntrospector(1_250_000_000, 1, numThreads); + final List slices = makeInputSlices(ReadablePartitions.striped(0, 1, numThreads)); + final IntSet broadcastInputs = IntSets.emptySet(); + final ShuffleSpec shuffleSpec = makeSortShuffleSpec(); + + final MSQException e = Assert.assertThrows( + MSQException.class, + () -> WorkerMemoryParameters.createInstance( + memoryIntrospector, + frameSize, + slices, + broadcastInputs, + shuffleSpec, + 2, + 1 + ) + ); + + Assert.assertEquals( + new NotEnoughMemoryFault(2_732_500_000L, 1_250_000_000, 1_000_000_000, 1, 40, 1, 2), + e.getFault() + ); + } + + @Test + public void test_1WorkerInJvm_alone_40Threads_2ConcurrentStages_memoryFromError() + { + // Test with the amount of memory recommended in the error message from + // test_1WorkerInJvm_alone_40Threads_2ConcurrentStages. + final int numThreads = 40; + final int frameSize = WorkerMemoryParameters.DEFAULT_FRAME_SIZE; + + final MemoryIntrospectorImpl memoryIntrospector = createMemoryIntrospector(2_732_500_000L, 1, numThreads); + final List slices = makeInputSlices(ReadablePartitions.striped(0, 1, numThreads)); + final IntSet broadcastInputs = IntSets.emptySet(); + final ShuffleSpec shuffleSpec = makeSortShuffleSpec(); + + Assert.assertEquals( + new WorkerMemoryParameters(13_000_000, frameSize, 1, 2, 250_000, 0), + WorkerMemoryParameters.createInstance(memoryIntrospector, frameSize, slices, broadcastInputs, shuffleSpec, 2, 1) + ); + } + @Test public void test_1WorkerInJvm_200WorkersInCluster_4Threads() {