MSQ: Allow for worker gaps. (#17277)

In a Dart query, all Historicals are given worker IDs, but not all of them
will actually be started or receive work orders. This can create gaps in the
set of workers. For example, workers 1 and 3 could have work assigned while
workers 0 and 2 do not.

This patch updates ControllerStageTracker and WorkerInputs to handle such
gaps by using the set of actual worker numbers, rather than 0..workerCount,
in various places.
Gian Merlino, 2024-10-08 02:37:57 -07:00 (committed by GitHub)
commit 06bbdb38ce, parent 4fbb129027
10 changed files with 439 additions and 59 deletions
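To illustrate the gap scenario in code (a hedged sketch, not part of the diff below; the stage and partition numbers, and the WorkerGapSketch class, are made up for illustration): the new striped factory takes the actual worker set rather than a contiguous count.

import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;
import org.apache.druid.msq.input.stage.ReadablePartitions;

class WorkerGapSketch
{
  public static void main(String[] args)
  {
    // Workers 1 and 3 have work assigned; workers 0 and 2 do not.
    final IntSortedSet workers = new IntAVLTreeSet(new int[]{1, 3});

    // Stripe 4 partitions across the actual worker set. Because the set has
    // gaps, this returns a SparseStripedReadablePartitions (see below).
    System.out.println(ReadablePartitions.striped(0, workers, 4));
  }
}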

ReadablePartition.java

@@ -59,6 +59,18 @@ public class ReadablePartition
return new ReadablePartition(stageNumber, workerNumbers, partitionNumber);
}
/**
* Returns an output partition that is striped across a set of {@code workerNumbers}.
*/
public static ReadablePartition striped(
final int stageNumber,
final IntSortedSet workerNumbers,
final int partitionNumber
)
{
return new ReadablePartition(stageNumber, workerNumbers, partitionNumber);
}
/**
* Returns an output partition that has been collected onto a single worker.
*/

ReadablePartitions.java

@@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo;
import it.unimi.dsi.fastutil.ints.Int2IntAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2IntSortedMap;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;
import java.util.Collections;
import java.util.List;
@@ -39,6 +40,7 @@ import java.util.Map;
@JsonSubTypes(value = {
@JsonSubTypes.Type(name = "collected", value = CollectedReadablePartitions.class),
@JsonSubTypes.Type(name = "striped", value = StripedReadablePartitions.class),
@JsonSubTypes.Type(name = "sparseStriped", value = SparseStripedReadablePartitions.class),
@JsonSubTypes.Type(name = "combined", value = CombinedReadablePartitions.class)
})
public interface ReadablePartitions extends Iterable<ReadablePartition>
@@ -59,7 +61,7 @@ public interface ReadablePartitions extends Iterable<ReadablePartition>
/**
* Combines various sets of partitions into a single set.
*/
static CombinedReadablePartitions combine(List<ReadablePartitions> readablePartitions)
static ReadablePartitions combine(List<ReadablePartitions> readablePartitions)
{
return new CombinedReadablePartitions(readablePartitions);
}
@@ -68,7 +70,7 @@ public interface ReadablePartitions extends Iterable<ReadablePartition>
* Returns a set of {@code numPartitions} partitions striped across {@code numWorkers} workers: each worker contains
* a "stripe" of each partition.
*/
static StripedReadablePartitions striped(
static ReadablePartitions striped(
final int stageNumber,
final int numWorkers,
final int numPartitions
@@ -82,11 +84,36 @@ public interface ReadablePartitions extends Iterable<ReadablePartition>
return new StripedReadablePartitions(stageNumber, numWorkers, partitionNumbers);
}
/**
* Returns a set of {@code numPartitions} partitions striped across {@code workers}: each worker contains
* a "stripe" of each partition.
*/
static ReadablePartitions striped(
final int stageNumber,
final IntSortedSet workers,
final int numPartitions
)
{
final IntAVLTreeSet partitionNumbers = new IntAVLTreeSet();
for (int i = 0; i < numPartitions; i++) {
partitionNumbers.add(i);
}
if (workers.lastInt() == workers.size() - 1) {
// Dense worker set. Use StripedReadablePartitions for compactness (send a single number rather than the
// entire worker set) and for backwards compatibility (older workers cannot understand
// SparseStripedReadablePartitions).
return new StripedReadablePartitions(stageNumber, workers.size(), partitionNumbers);
} else {
return new SparseStripedReadablePartitions(stageNumber, workers, partitionNumbers);
}
}
/**
* Returns a set of partitions that have been collected onto specific workers: each partition is on exactly
* one worker.
*/
static CollectedReadablePartitions collected(
static ReadablePartitions collected(
final int stageNumber,
final Map<Integer, Integer> partitionToWorkerMap
)
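A usage note on the dense/sparse dispatch above (a sketch mirroring testFromDenseSet in StripedReadablePartitionsTest further down; variable names are made up):

final IntSortedSet dense = new IntAVLTreeSet(new int[]{0, 1});
final IntSortedSet sparse = new IntAVLTreeSet(new int[]{1, 3});

// Dense set {0, 1}: lastInt() == size() - 1, so the compact, backwards-compatible
// form is chosen; the result equals ReadablePartitions.striped(1, 2, 3).
ReadablePartitions.striped(1, dense, 3);   // StripedReadablePartitions

// Sparse set {1, 3}: the worker set has a gap, so it must be sent in full.
ReadablePartitions.striped(1, sparse, 3);  // SparseStripedReadablePartitions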

SparseStripedReadablePartitions.java (new file)

@@ -0,0 +1,142 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.input.stage;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.Iterators;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;
import org.apache.druid.msq.input.SlicerUtils;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
/**
* Set of partitions striped across a sparse set of {@code workers}. Each worker contains a "stripe" of each partition.
*
* @see StripedReadablePartitions the dense version, where all workers in [0..N) are used.
*/
public class SparseStripedReadablePartitions implements ReadablePartitions
{
private final int stageNumber;
private final IntSortedSet workers;
private final IntSortedSet partitionNumbers;
/**
* Constructor. Most callers should use {@link ReadablePartitions#striped(int, IntSortedSet, int)} instead, which takes
* a partition count rather than a set of partition numbers.
*/
public SparseStripedReadablePartitions(
final int stageNumber,
final IntSortedSet workers,
final IntSortedSet partitionNumbers
)
{
this.stageNumber = stageNumber;
this.workers = workers;
this.partitionNumbers = partitionNumbers;
}
@JsonCreator
private SparseStripedReadablePartitions(
@JsonProperty("stageNumber") final int stageNumber,
@JsonProperty("workers") final Set<Integer> workers,
@JsonProperty("partitionNumbers") final Set<Integer> partitionNumbers
)
{
this(stageNumber, new IntAVLTreeSet(workers), new IntAVLTreeSet(partitionNumbers));
}
@Override
public Iterator<ReadablePartition> iterator()
{
return Iterators.transform(
partitionNumbers.iterator(),
partitionNumber -> ReadablePartition.striped(stageNumber, workers, partitionNumber)
);
}
@Override
public List<ReadablePartitions> split(final int maxNumSplits)
{
final List<ReadablePartitions> retVal = new ArrayList<>();
for (List<Integer> entries : SlicerUtils.makeSlicesStatic(partitionNumbers.iterator(), maxNumSplits)) {
if (!entries.isEmpty()) {
retVal.add(new SparseStripedReadablePartitions(stageNumber, workers, new IntAVLTreeSet(entries)));
}
}
return retVal;
}
@JsonProperty
int getStageNumber()
{
return stageNumber;
}
@JsonProperty
IntSortedSet getWorkers()
{
return workers;
}
@JsonProperty
IntSortedSet getPartitionNumbers()
{
return partitionNumbers;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SparseStripedReadablePartitions that = (SparseStripedReadablePartitions) o;
return stageNumber == that.stageNumber
&& Objects.equals(workers, that.workers)
&& Objects.equals(partitionNumbers, that.partitionNumbers);
}
@Override
public int hashCode()
{
return Objects.hash(stageNumber, workers, partitionNumbers);
}
@Override
public String toString()
{
return "StripedReadablePartitions{" +
"stageNumber=" + stageNumber +
", workers=" + workers +
", partitionNumbers=" + partitionNumbers +
'}';
}
}
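For reference, the wire format implied by the @JsonProperty accessors above and the "sparseStriped" registration in ReadablePartitions (an inference from the annotations, not asserted anywhere in this commit) would presumably look like:

{"type": "sparseStriped", "stageNumber": 1, "workers": [1, 3], "partitionNumbers": [0, 1, 2]}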

ControllerStageTracker.java

@@ -403,7 +403,7 @@ class ControllerStageTracker
throw new ISE("Stage does not gather result key statistics");
}
if (workerNumber < 0 || workerNumber >= workerCount) {
if (!workerInputs.workers().contains(workerNumber)) {
throw new IAE("Invalid workerNumber [%s]", workerNumber);
}
@@ -522,7 +522,7 @@ class ControllerStageTracker
throw new ISE("Stage does not gather result key statistics");
}
if (workerNumber < 0 || workerNumber >= workerCount) {
if (!workerInputs.workers().contains(workerNumber)) {
throw new IAE("Invalid workerNumber [%s]", workerNumber);
}
@@ -656,7 +656,7 @@ class ControllerStageTracker
throw new ISE("Stage does not gather result key statistics");
}
if (workerNumber < 0 || workerNumber >= workerCount) {
if (!workerInputs.workers().contains(workerNumber)) {
throw new IAE("Invalid workerNumber [%s]", workerNumber);
}
@@ -763,7 +763,7 @@ class ControllerStageTracker
this.resultPartitionBoundaries = clusterByPartitions;
this.resultPartitions = ReadablePartitions.striped(
stageDef.getStageNumber(),
workerCount,
workerInputs.workers(),
clusterByPartitions.size()
);
@@ -788,7 +788,7 @@ class ControllerStageTracker
throw DruidException.defensive("Cannot setDoneReadingInput for stage[%s], it is not sorting", stageDef.getId());
}
if (workerNumber < 0 || workerNumber >= workerCount) {
if (!workerInputs.workers().contains(workerNumber)) {
throw new IAE("Invalid workerNumber[%s] for stage[%s]", workerNumber, stageDef.getId());
}
@@ -830,7 +830,7 @@ class ControllerStageTracker
@SuppressWarnings("unchecked")
boolean setResultsCompleteForWorker(final int workerNumber, final Object resultObject)
{
if (workerNumber < 0 || workerNumber >= workerCount) {
if (!workerInputs.workers().contains(workerNumber)) {
throw new IAE("Invalid workerNumber [%s]", workerNumber);
}
@@ -947,14 +947,18 @@ class ControllerStageTracker
resultPartitionBoundaries = maybeResultPartitionBoundaries.valueOrThrow();
resultPartitions = ReadablePartitions.striped(
stageNumber,
workerCount,
workerInputs.workers(),
resultPartitionBoundaries.size()
);
} else if (shuffleSpec.kind() == ShuffleKind.MIX) {
resultPartitionBoundaries = ClusterByPartitions.oneUniversalPartition();
resultPartitions = ReadablePartitions.striped(stageNumber, workerCount, shuffleSpec.partitionCount());
} else {
resultPartitions = ReadablePartitions.striped(stageNumber, workerCount, shuffleSpec.partitionCount());
if (shuffleSpec.kind() == ShuffleKind.MIX) {
resultPartitionBoundaries = ClusterByPartitions.oneUniversalPartition();
}
resultPartitions = ReadablePartitions.striped(
stageNumber,
workerInputs.workers(),
shuffleSpec.partitionCount()
);
}
} else {
// No reshuffling: retain partitioning from nonbroadcast inputs.

WorkerInputs.java

@@ -24,7 +24,9 @@ import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.ints.Int2ObjectSortedMap;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import org.apache.druid.msq.input.InputSlice;
import org.apache.druid.msq.input.InputSpec;
@@ -45,9 +47,9 @@ import java.util.stream.IntStream;
public class WorkerInputs
{
// Worker number -> input number -> input slice.
private final Int2ObjectMap<List<InputSlice>> assignmentsMap;
private final Int2ObjectSortedMap<List<InputSlice>> assignmentsMap;
private WorkerInputs(final Int2ObjectMap<List<InputSlice>> assignmentsMap)
private WorkerInputs(final Int2ObjectSortedMap<List<InputSlice>> assignmentsMap)
{
this.assignmentsMap = assignmentsMap;
}
@@ -64,7 +66,7 @@
)
{
// Split each inputSpec and assign to workers. This list maps worker number -> input number -> input slice.
final Int2ObjectMap<List<InputSlice>> assignmentsMap = new Int2ObjectAVLTreeMap<>();
final Int2ObjectSortedMap<List<InputSlice>> assignmentsMap = new Int2ObjectAVLTreeMap<>();
final int numInputs = stageDef.getInputSpecs().size();
if (numInputs == 0) {
@@ -117,8 +119,8 @@
final ObjectIterator<Int2ObjectMap.Entry<List<InputSlice>>> assignmentsIterator =
assignmentsMap.int2ObjectEntrySet().iterator();
final IntSortedSet nilWorkers = new IntAVLTreeSet();
boolean first = true;
while (assignmentsIterator.hasNext()) {
final Int2ObjectMap.Entry<List<InputSlice>> entry = assignmentsIterator.next();
final List<InputSlice> slices = entry.getValue();
@@ -130,20 +132,29 @@
}
}
// Eliminate workers that have no non-nil, non-broadcast inputs. (Except the first one, because if all input
// is nil, *some* worker has to do *something*.)
final boolean hasNonNilNonBroadcastInput =
// Identify nil workers (workers whose non-broadcast inputs are all nil).
final boolean isNilWorker =
IntStream.range(0, numInputs)
.anyMatch(i ->
!slices.get(i).equals(NilInputSlice.INSTANCE) // Non-nil
&& !stageDef.getBroadcastInputNumbers().contains(i) // Non-broadcast
.allMatch(i ->
slices.get(i).equals(NilInputSlice.INSTANCE) // Nil regular input
|| stageDef.getBroadcastInputNumbers().contains(i) // Broadcast
);
if (!first && !hasNonNilNonBroadcastInput) {
assignmentsIterator.remove();
if (isNilWorker) {
nilWorkers.add(entry.getIntKey());
}
}
first = false;
if (nilWorkers.size() == assignmentsMap.size()) {
// All workers have nil regular inputs. Remove all workers except the first (*some* worker has to do *something*).
final List<InputSlice> firstSlices = assignmentsMap.get(nilWorkers.firstInt());
assignmentsMap.clear();
assignmentsMap.put(nilWorkers.firstInt(), firstSlices);
} else {
// Remove all nil workers.
for (final int nilWorker : nilWorkers) {
assignmentsMap.remove(nilWorker);
}
}
return new WorkerInputs(assignmentsMap);
@@ -154,7 +165,7 @@
return Preconditions.checkNotNull(assignmentsMap.get(workerNumber), "worker [%s]", workerNumber);
}
public IntSet workers()
public IntSortedSet workers()
{
return assignmentsMap.keySet();
}
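To make the pruning above concrete (a hypothetical walk-through, not taken from the patch):

// Hypothetical assignments after slicing (worker -> slices):
// {0: [nil], 1: [values 1..3], 2: [nil], 3: [nil]}
// nilWorkers = {0, 2, 3}; nilWorkers.size() (3) != assignmentsMap.size() (4),
// so workers 0, 2, and 3 are removed and workers() returns {1}.
// If all four were nil, only assignmentsMap.get(nilWorkers.firstInt()) would
// survive, because *some* worker still has to produce the stage's output.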

CollectedReadablePartitionsTest.java

@@ -33,21 +33,24 @@ public class CollectedReadablePartitionsTest
@Test
public void testPartitionToWorkerMap()
{
final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
final CollectedReadablePartitions partitions =
(CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
Assert.assertEquals(ImmutableMap.of(0, 1, 1, 2, 2, 1), partitions.getPartitionToWorkerMap());
}
@Test
public void testStageNumber()
{
final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
final CollectedReadablePartitions partitions =
(CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
Assert.assertEquals(1, partitions.getStageNumber());
}
@Test
public void testSplit()
{
final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
final CollectedReadablePartitions partitions =
(CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
Assert.assertEquals(
ImmutableList.of(
@@ -64,7 +67,8 @@ public class CollectedReadablePartitionsTest
final ObjectMapper mapper = TestHelper.makeJsonMapper()
.registerModules(new MSQIndexingModule().getJacksonModules());
final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
final CollectedReadablePartitions partitions =
(CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
Assert.assertEquals(
partitions,

CombinedReadablePartitionsTest.java

@@ -31,7 +31,7 @@ import org.junit.Test;
public class CombinedReadablePartitionsTest
{
private static final CombinedReadablePartitions PARTITIONS = ReadablePartitions.combine(
private static final ReadablePartitions PARTITIONS = ReadablePartitions.combine(
ImmutableList.of(
ReadablePartitions.striped(0, 2, 2),
ReadablePartitions.striped(1, 2, 4)

SparseStripedReadablePartitionsTest.java (new file)

@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.input.stage;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.msq.guice.MSQIndexingModule;
import org.apache.druid.segment.TestHelper;
import org.junit.Assert;
import org.junit.Test;
public class SparseStripedReadablePartitionsTest
{
@Test
public void testPartitionNumbers()
{
final SparseStripedReadablePartitions partitions =
(SparseStripedReadablePartitions) ReadablePartitions.striped(1, new IntAVLTreeSet(new int[]{1, 3}), 3);
Assert.assertEquals(ImmutableSet.of(0, 1, 2), partitions.getPartitionNumbers());
}
@Test
public void testWorkers()
{
final SparseStripedReadablePartitions partitions =
(SparseStripedReadablePartitions) ReadablePartitions.striped(1, new IntAVLTreeSet(new int[]{1, 3}), 3);
Assert.assertEquals(IntSet.of(1, 3), partitions.getWorkers());
}
@Test
public void testStageNumber()
{
final SparseStripedReadablePartitions partitions =
(SparseStripedReadablePartitions) ReadablePartitions.striped(1, new IntAVLTreeSet(new int[]{1, 3}), 3);
Assert.assertEquals(1, partitions.getStageNumber());
}
@Test
public void testSplit()
{
final IntAVLTreeSet workers = new IntAVLTreeSet(new int[]{1, 3});
final SparseStripedReadablePartitions partitions =
(SparseStripedReadablePartitions) ReadablePartitions.striped(1, workers, 3);
Assert.assertEquals(
ImmutableList.of(
new SparseStripedReadablePartitions(1, workers, new IntAVLTreeSet(new int[]{0, 2})),
new SparseStripedReadablePartitions(1, workers, new IntAVLTreeSet(new int[]{1}))
),
partitions.split(2)
);
}
@Test
public void testSerde() throws Exception
{
final ObjectMapper mapper = TestHelper.makeJsonMapper()
.registerModules(new MSQIndexingModule().getJacksonModules());
final IntAVLTreeSet workers = new IntAVLTreeSet(new int[]{1, 3});
final ReadablePartitions partitions = ReadablePartitions.striped(1, workers, 3);
Assert.assertEquals(
partitions,
mapper.readValue(
mapper.writeValueAsString(partitions),
ReadablePartitions.class
)
);
}
@Test
public void testEquals()
{
EqualsVerifier.forClass(SparseStripedReadablePartitions.class).usingGetClass().verify();
}
}

StripedReadablePartitionsTest.java

@@ -26,36 +26,60 @@ import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.msq.guice.MSQIndexingModule;
import org.apache.druid.segment.TestHelper;
import org.hamcrest.CoreMatchers;
import org.hamcrest.MatcherAssert;
import org.junit.Assert;
import org.junit.Test;
public class StripedReadablePartitionsTest
{
@Test
public void testFromDenseSet()
{
// Tests that when ReadablePartitions.striped is called with a dense set, we get StripedReadablePartitions.
final IntAVLTreeSet workers = new IntAVLTreeSet();
workers.add(0);
workers.add(1);
final ReadablePartitions readablePartitionsFromSet = ReadablePartitions.striped(1, workers, 3);
MatcherAssert.assertThat(
readablePartitionsFromSet,
CoreMatchers.instanceOf(StripedReadablePartitions.class)
);
Assert.assertEquals(
ReadablePartitions.striped(1, 2, 3),
readablePartitionsFromSet
);
}
@Test
public void testPartitionNumbers()
{
final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
final StripedReadablePartitions partitions = (StripedReadablePartitions) ReadablePartitions.striped(1, 2, 3);
Assert.assertEquals(ImmutableSet.of(0, 1, 2), partitions.getPartitionNumbers());
}
@Test
public void testNumWorkers()
{
final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
final StripedReadablePartitions partitions = (StripedReadablePartitions) ReadablePartitions.striped(1, 2, 3);
Assert.assertEquals(2, partitions.getNumWorkers());
}
@Test
public void testStageNumber()
{
final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
final StripedReadablePartitions partitions = (StripedReadablePartitions) ReadablePartitions.striped(1, 2, 3);
Assert.assertEquals(1, partitions.getStageNumber());
}
@Test
public void testSplit()
{
final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
final ReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
Assert.assertEquals(
ImmutableList.of(
@@ -72,7 +96,7 @@ public class StripedReadablePartitionsTest
final ObjectMapper mapper = TestHelper.makeJsonMapper()
.registerModules(new MSQIndexingModule().getJacksonModules());
final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
final ReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
Assert.assertEquals(
partitions,

WorkerInputsTest.java

@@ -25,9 +25,11 @@ import it.unimi.dsi.fastutil.ints.Int2IntMaps;
import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongList;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.error.DruidException;
import org.apache.druid.msq.exec.Limits;
import org.apache.druid.msq.exec.OutputChannelMode;
import org.apache.druid.msq.input.InputSlice;
@@ -75,7 +77,7 @@ public class WorkerInputsTest
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(4), true),
WorkerAssignmentStrategy.MAX,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@@ -91,6 +93,35 @@
);
}
@Test
public void test_max_threeInputs_fourWorkers_withGaps()
{
final StageDefinition stageDef =
StageDefinition.builder(0)
.inputs(new TestInputSpec(1, 2, 3))
.maxWorkerCount(4)
.processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L))
.build(QUERY_ID);
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(new IntAVLTreeSet(new int[]{1, 3, 4, 5}), true),
WorkerAssignmentStrategy.MAX,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
Assert.assertEquals(
ImmutableMap.<Integer, List<InputSlice>>builder()
.put(1, Collections.singletonList(new TestInputSlice(1)))
.put(3, Collections.singletonList(new TestInputSlice(2)))
.put(4, Collections.singletonList(new TestInputSlice(3)))
.put(5, Collections.singletonList(new TestInputSlice()))
.build(),
inputs.assignmentsMap()
);
}
@Test
public void test_max_zeroInputs_fourWorkers()
{
@@ -104,7 +135,7 @@
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(4), true),
WorkerAssignmentStrategy.MAX,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@@ -133,7 +164,7 @@
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(4), true),
WorkerAssignmentStrategy.AUTO,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@@ -159,7 +190,7 @@
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(4), true),
WorkerAssignmentStrategy.AUTO,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@@ -186,7 +217,7 @@
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(4), true),
WorkerAssignmentStrategy.AUTO,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@@ -212,7 +243,7 @@
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(4), true),
WorkerAssignmentStrategy.AUTO,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@@ -324,7 +355,7 @@
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(4), true),
WorkerAssignmentStrategy.AUTO,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@@ -351,7 +382,7 @@
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(2), true),
WorkerAssignmentStrategy.AUTO,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@@ -384,7 +415,7 @@
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(1), true),
WorkerAssignmentStrategy.AUTO,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@@ -411,7 +442,7 @@
.processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L))
.build(QUERY_ID);
TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(true));
TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(denseWorkers(3), true));
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
@@ -455,7 +486,7 @@
.processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L))
.build(QUERY_ID);
TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(true));
TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(denseWorkers(3), true));
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
@@ -498,7 +529,7 @@
.processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L))
.build(QUERY_ID);
TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(true));
TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(denseWorkers(3), true));
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
@@ -585,11 +616,23 @@
private static class TestInputSpecSlicer implements InputSpecSlicer
{
private final IntSortedSet workers;
private final boolean canSliceDynamic;
public TestInputSpecSlicer(boolean canSliceDynamic)
/**
* Create a test slicer.
*
* @param workers Set of workers to consider assigning work to.
* @param canSliceDynamic Whether this slicer can slice dynamically.
*/
public TestInputSpecSlicer(final IntSortedSet workers, final boolean canSliceDynamic)
{
this.workers = workers;
this.canSliceDynamic = canSliceDynamic;
if (workers.isEmpty()) {
throw DruidException.defensive("Need at least one worker in workers[%s]", workers);
}
}
@Override
@@ -606,9 +649,9 @@
SlicerUtils.makeSlicesStatic(
testInputSpec.values.iterator(),
i -> i,
maxNumSlices
Math.min(maxNumSlices, workers.size())
);
return makeSlices(assignments);
return makeSlices(workers, assignments);
}
@Override
@@ -624,24 +667,39 @@
SlicerUtils.makeSlicesDynamic(
testInputSpec.values.iterator(),
i -> i,
maxNumSlices,
Math.min(maxNumSlices, workers.size()),
maxFilesPerSlice,
maxBytesPerSlice
);
return makeSlices(assignments);
return makeSlices(workers, assignments);
}
private static List<InputSlice> makeSlices(
final IntSortedSet workers,
final List<List<Long>> assignments
)
{
final List<InputSlice> retVal = new ArrayList<>(assignments.size());
for (final List<Long> assignment : assignments) {
retVal.add(new TestInputSlice(new LongArrayList(assignment)));
for (int assignment = 0, workerNumber = 0;
workerNumber <= workers.lastInt() && assignment < assignments.size();
workerNumber++) {
if (workers.contains(workerNumber)) {
retVal.add(new TestInputSlice(new LongArrayList(assignments.get(assignment++))));
} else {
retVal.add(NilInputSlice.INSTANCE);
}
}
return retVal;
}
}
private static IntSortedSet denseWorkers(final int numWorkers)
{
final IntAVLTreeSet workers = new IntAVLTreeSet();
for (int i = 0; i < numWorkers; i++) {
workers.add(i);
}
return workers;
}
}