diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StripedReadablePartitions.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StripedReadablePartitions.java index bf9fdaa152e..887970efdd3 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StripedReadablePartitions.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StripedReadablePartitions.java @@ -41,7 +41,11 @@ public class StripedReadablePartitions implements ReadablePartitions private final int numWorkers; private final IntSortedSet partitionNumbers; - StripedReadablePartitions(final int stageNumber, final int numWorkers, final IntSortedSet partitionNumbers) + /** + * Constructor. Most callers should use {@link ReadablePartitions#striped(int, int, int)} instead, which takes + * a partition count rather than a set of partition numbers. + */ + public StripedReadablePartitions(final int stageNumber, final int numWorkers, final IntSortedSet partitionNumbers) { this.stageNumber = stageNumber; this.numWorkers = numWorkers; diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkerAssignmentStrategy.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkerAssignmentStrategy.java index 8d1832be17f..18f1f821d9a 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkerAssignmentStrategy.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkerAssignmentStrategy.java @@ -90,7 +90,7 @@ public enum WorkerAssignmentStrategy final IntSet inputStages = stageDef.getInputStageNumbers(); final OptionalInt maxInputStageWorkerCount = inputStages.intStream().map(stageWorkerCountMap).max(); - final int workerCount = maxInputStageWorkerCount.orElse(1); + final int workerCount = Math.min(stageDef.getMaxWorkerCount(), 
maxInputStageWorkerCount.orElse(1)); return slicer.sliceStatic(inputSpec, workerCount); } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java index bfb4c8ca675..00ccfdee6c1 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java @@ -20,7 +20,10 @@ package org.apache.druid.msq.kernel.controller; import com.google.common.collect.ImmutableMap; +import it.unimi.dsi.fastutil.ints.Int2IntAVLTreeMap; import it.unimi.dsi.fastutil.ints.Int2IntMaps; +import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap; +import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; import it.unimi.dsi.fastutil.ints.IntSet; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongList; @@ -31,6 +34,11 @@ import org.apache.druid.msq.input.InputSpec; import org.apache.druid.msq.input.InputSpecSlicer; import org.apache.druid.msq.input.NilInputSlice; import org.apache.druid.msq.input.SlicerUtils; +import org.apache.druid.msq.input.stage.ReadablePartitions; +import org.apache.druid.msq.input.stage.StageInputSlice; +import org.apache.druid.msq.input.stage.StageInputSpec; +import org.apache.druid.msq.input.stage.StageInputSpecSlicer; +import org.apache.druid.msq.input.stage.StripedReadablePartitions; import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.WorkerAssignmentStrategy; import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory; @@ -216,6 +224,87 @@ public class WorkerInputsTest ); } + @Test + public void test_auto_oneInputStageWithThreePartitionsAndTwoWorkers_fourWorkerMax() + { + final StageDefinition stageDef = + StageDefinition.builder(1) + .inputs(new 
StageInputSpec(0)) + .maxWorkerCount(4) + .processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L)) + .build(QUERY_ID); + + final WorkerInputs inputs = WorkerInputs.create( + stageDef, + new Int2IntAVLTreeMap(ImmutableMap.of(0, 2)), + new StageInputSpecSlicer( + new Int2ObjectAVLTreeMap<>(ImmutableMap.of(0, ReadablePartitions.striped(0, 2, 3))) + ), + WorkerAssignmentStrategy.AUTO, + Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER + ); + + Assert.assertEquals( + ImmutableMap.<Integer, List<InputSlice>>builder() + .put( + 0, + Collections.singletonList( + new StageInputSlice( + 0, + new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{0, 2})) + ) + ) + ) + .put( + 1, + Collections.singletonList( + new StageInputSlice( + 0, + new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{1})) + ) + ) + ) + .build(), + inputs.assignmentsMap() + ); + } + + @Test + public void test_auto_oneInputStageWithThreePartitionsAndTwoWorkers_oneWorkerMax() + { + final StageDefinition stageDef = + StageDefinition.builder(1) + .inputs(new StageInputSpec(0)) + .maxWorkerCount(1) + .processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L)) + .build(QUERY_ID); + + final WorkerInputs inputs = WorkerInputs.create( + stageDef, + new Int2IntAVLTreeMap(ImmutableMap.of(0, 2)), + new StageInputSpecSlicer( + new Int2ObjectAVLTreeMap<>(ImmutableMap.of(0, ReadablePartitions.striped(0, 2, 3))) + ), + WorkerAssignmentStrategy.AUTO, + Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER + ); + + Assert.assertEquals( + ImmutableMap.<Integer, List<InputSlice>>builder() + .put( + 0, + Collections.singletonList( + new StageInputSlice( + 0, + new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{0, 1, 2})) + ) + ) + ) + .build(), + inputs.assignmentsMap() + ); + } + @Test public void test_auto_threeBigInputs_fourWorkers() {