Assign partitionIds in the same order as bucketIds (#12236)

When `ParallelIndexSupervisorTask` converts `BucketNumberedShardSpec`s
to the corresponding `BuildingShardSpec`s, the bucketId order gets lost.
In particular, for range partitioning, this results in partitionIds that do not
follow the order of the increasing partition boundaries.

Changes
- Refactor `ParallelIndexSupervisorTask.groupGenericPartitionLocationsPerPartition()` into `getPartitionToLocations()`, which sorts partitions by (interval, bucketId) before assigning packed partitionIds (a sketch of the idea follows below)
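For intuition, here is a minimal, self-contained sketch of the invariant the fix enforces (illustrative only; the class and variable names are hypothetical, not part of this patch): within each interval, packed partitionIds must increase with bucketId, because for range partitioning the bucketIds follow the order of the range boundaries.

import java.util.Map;
import java.util.TreeMap;

class PartitionIdOrderingSketch
{
  public static void main(String[] args)
  {
    // bucketIds reported by subtasks, in arbitrary encounter order
    int[] reportedBucketIds = {6, 1, 4, 0};

    // Collecting the bucketIds into a sorted map and numbering them in
    // iteration order yields packed partitionIds that preserve bucket order.
    Map<Integer, Integer> bucketToPartitionId = new TreeMap<>();
    for (int bucketId : reportedBucketIds) {
      bucketToPartitionId.put(bucketId, -1);
    }
    int nextPartitionId = 0;
    for (Map.Entry<Integer, Integer> entry : bucketToPartitionId.entrySet()) {
      entry.setValue(nextPartitionId++);
    }
    System.out.println(bucketToPartitionId); // {0=0, 1=1, 4=2, 6=3}
  }
}

Assigning ids in encounter order instead (6 -> 0, 1 -> 1, 4 -> 2, 0 -> 3) would break the correspondence between partitionId order and range-boundary order, which is exactly the bug described above.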
Author: Kashif Faraz, 2022-02-10 11:08:39 +05:30, committed by GitHub
parent 3ee66bb492
commit 95b388d2d1
2 changed files with 233 additions and 52 deletions

ParallelIndexSupervisorTask.java

@@ -28,8 +28,6 @@ import com.google.common.base.Throwables;
 import com.google.common.collect.ArrayListMultimap;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Multimap;
-import it.unimi.dsi.fastutil.objects.Object2IntMap;
-import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
 import org.apache.datasketches.hll.HllSketch;
 import org.apache.datasketches.hll.Union;
 import org.apache.datasketches.memory.Memory;
@@ -106,14 +104,17 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.Objects;
 import java.util.Set;
+import java.util.TreeMap;
 import java.util.concurrent.ThreadLocalRandom;
-import java.util.function.BiFunction;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -728,8 +729,8 @@ public class ParallelIndexSupervisorTask extends AbstractBatchIndexTask implemen
     // 2. Partial segment merge phase
     // partition (interval, partitionId) -> partition locations
-    Map<Pair<Interval, Integer>, List<PartitionLocation>> partitionToLocations =
-        groupGenericPartitionLocationsPerPartition(indexingRunner.getReports());
+    Map<Partition, List<PartitionLocation>> partitionToLocations =
+        getPartitionToLocations(indexingRunner.getReports());
     final List<PartialSegmentMergeIOConfig> ioConfigs = createGenericMergeIOConfigs(
         ingestionSchema.getTuningConfig().getTotalNumMergeTasks(),
         partitionToLocations
@@ -814,8 +815,8 @@ public class ParallelIndexSupervisorTask extends AbstractBatchIndexTask implemen
     }
     // partition (interval, partitionId) -> partition locations
-    Map<Pair<Interval, Integer>, List<PartitionLocation>> partitionToLocations =
-        groupGenericPartitionLocationsPerPartition(indexingRunner.getReports());
+    Map<Partition, List<PartitionLocation>> partitionToLocations =
+        getPartitionToLocations(indexingRunner.getReports());
     final List<PartialSegmentMergeIOConfig> ioConfigs = createGenericMergeIOConfigs(
         ingestionSchema.getTuningConfig().getTotalNumMergeTasks(),
         partitionToLocations
@@ -923,50 +924,58 @@ public class ParallelIndexSupervisorTask extends AbstractBatchIndexTask implemen
     return partitions;
   }
 
-  private static Map<Pair<Interval, Integer>, List<PartitionLocation>> groupGenericPartitionLocationsPerPartition(
+  /**
+   * Creates a map from partition (interval + bucketId) to the corresponding
+   * PartitionLocations. Note that the bucketId may be different from the final
+   * partitionId (refer to {@link BuildingShardSpec} for more details).
+   */
+  static Map<Partition, List<PartitionLocation>> getPartitionToLocations(
      Map<String, GeneratedPartitionsReport> subTaskIdToReport
  )
  {
-    final Map<Pair<Interval, Integer>, BuildingShardSpec<?>> intervalAndIntegerToShardSpec = new HashMap<>();
-    final Object2IntMap<Interval> intervalToNextPartitionId = new Object2IntOpenHashMap<>();
-    final BiFunction<String, PartitionStat, PartitionLocation> createPartitionLocationFunction =
-        (subtaskId, partitionStat) -> {
-          final BuildingShardSpec<?> shardSpec = intervalAndIntegerToShardSpec.computeIfAbsent(
-              Pair.of(partitionStat.getInterval(), partitionStat.getBucketId()),
-              key -> {
-                // Lazily determine the partitionId to create packed partitionIds for the core partitions.
-                // See the Javadoc of BucketNumberedShardSpec for details.
-                final int partitionId = intervalToNextPartitionId.computeInt(
-                    partitionStat.getInterval(),
-                    ((interval, nextPartitionId) -> nextPartitionId == null ? 0 : nextPartitionId + 1)
-                );
-                return partitionStat.getSecondaryPartition().convert(partitionId);
-              }
-          );
-          return partitionStat.toPartitionLocation(subtaskId, shardSpec);
-        };
+    // Create a map from partition to list of reports (PartitionStat and subTaskId)
+    final Map<Partition, List<PartitionReport>> partitionToReports = new TreeMap<>(
+        // Sort by (interval, bucketId) to maintain order of partitionIds within interval
+        Comparator
+            .comparingLong((Partition partition) -> partition.getInterval().getStartMillis())
+            .thenComparingLong(partition -> partition.getInterval().getEndMillis())
+            .thenComparingInt(Partition::getBucketId)
+    );
+    subTaskIdToReport.forEach(
+        (subTaskId, report) -> report.getPartitionStats().forEach(
+            partitionStat -> partitionToReports
+                .computeIfAbsent(Partition.fromStat(partitionStat), p -> new ArrayList<>())
+                .add(new PartitionReport(subTaskId, partitionStat))
+        )
+    );
 
-    return groupPartitionLocationsPerPartition(subTaskIdToReport, createPartitionLocationFunction);
-  }
+    final Map<Partition, List<PartitionLocation>> partitionToLocations = new HashMap<>();
 
-  private static <L extends PartitionLocation>
-      Map<Pair<Interval, Integer>, List<L>> groupPartitionLocationsPerPartition(
-      Map<String, ? extends GeneratedPartitionsReport> subTaskIdToReport,
-      BiFunction<String, PartitionStat, L> createPartitionLocationFunction
-  )
-  {
-    // partition (interval, partitionId) -> partition locations
-    final Map<Pair<Interval, Integer>, List<L>> partitionToLocations = new HashMap<>();
-    for (Entry<String, ? extends GeneratedPartitionsReport> entry : subTaskIdToReport.entrySet()) {
-      final String subTaskId = entry.getKey();
-      final GeneratedPartitionsReport report = entry.getValue();
-      for (PartitionStat partitionStat : report.getPartitionStats()) {
-        final List<L> locationsOfSamePartition = partitionToLocations.computeIfAbsent(
-            Pair.of(partitionStat.getInterval(), partitionStat.getBucketId()),
-            k -> new ArrayList<>()
-        );
-        locationsOfSamePartition.add(createPartitionLocationFunction.apply(subTaskId, partitionStat));
-      }
-    }
+    Interval prevInterval = null;
+    final AtomicInteger partitionId = new AtomicInteger(0);
+    for (Entry<Partition, List<PartitionReport>> entry : partitionToReports.entrySet()) {
+      final Partition partition = entry.getKey();
+      // Reset the partitionId if this is a new interval
+      Interval interval = partition.getInterval();
+      if (!interval.equals(prevInterval)) {
+        partitionId.set(0);
+        prevInterval = interval;
+      }
+
+      // Use any PartitionStat of this partition to create a shard spec
+      final List<PartitionReport> reportsOfPartition = entry.getValue();
+      final BuildingShardSpec<?> shardSpec = reportsOfPartition
+          .get(0).getPartitionStat().getSecondaryPartition()
+          .convert(partitionId.getAndIncrement());
+
+      // Create a PartitionLocation for each PartitionStat
+      List<PartitionLocation> locationsOfPartition = reportsOfPartition
+          .stream()
+          .map(report -> report.getPartitionStat().toPartitionLocation(report.getSubTaskId(), shardSpec))
+          .collect(Collectors.toList());
+      partitionToLocations.put(partition, locationsOfPartition);
+    }
     return partitionToLocations;
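The new method boils down to two steps: collect partitions into a TreeMap ordered by (interval start, interval end, bucketId), then walk the map once, resetting the id counter at every interval boundary. A stripped-down sketch of the same pattern, with a hypothetical Key record (Java 16+) standing in for the Druid (Interval, bucketId) pair:

import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

class PackedPartitionIdSketch
{
  // Hypothetical stand-in for (Interval, bucketId)
  record Key(long start, long end, int bucketId) {}

  public static void main(String[] args)
  {
    // Same composite ordering as the TreeMap above: interval start, then end, then bucketId
    Map<Key, Integer> keyToPartitionId = new TreeMap<>(
        Comparator.comparingLong((Key k) -> k.start())
                  .thenComparingLong(Key::end)
                  .thenComparingInt(Key::bucketId)
    );
    // Buckets reported in arbitrary order, spanning two intervals
    for (Key key : List.of(new Key(2, 3, 7), new Key(1, 2, 4), new Key(1, 2, 0), new Key(2, 3, 1))) {
      keyToPartitionId.put(key, -1);
    }

    // Walk the keys in sorted order; reset the counter whenever the interval changes
    long prevStart = -1;
    long prevEnd = -1;
    int nextId = 0;
    for (Map.Entry<Key, Integer> entry : keyToPartitionId.entrySet()) {
      Key key = entry.getKey();
      if (key.start() != prevStart || key.end() != prevEnd) {
        nextId = 0;
        prevStart = key.start();
        prevEnd = key.end();
      }
      entry.setValue(nextId++);
    }
    // Ids come out 0, 1 for interval [1, 2) and 0, 1 for interval [2, 3), in bucketId order
    System.out.println(keyToPartitionId);
  }
}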
@@ -974,7 +983,7 @@ public class ParallelIndexSupervisorTask extends AbstractBatchIndexTask implemen
   private static List<PartialSegmentMergeIOConfig> createGenericMergeIOConfigs(
       int totalNumMergeTasks,
-      Map<Pair<Interval, Integer>, List<PartitionLocation>> partitionToLocations
+      Map<Partition, List<PartitionLocation>> partitionToLocations
   )
   {
     return createMergeIOConfigs(
@@ -987,7 +996,7 @@ public class ParallelIndexSupervisorTask extends AbstractBatchIndexTask implemen
   @VisibleForTesting
   static <M extends PartialSegmentMergeIOConfig, L extends PartitionLocation> List<M> createMergeIOConfigs(
       int totalNumMergeTasks,
-      Map<Pair<Interval, Integer>, List<L>> partitionToLocations,
+      Map<Partition, List<L>> partitionToLocations,
       Function<List<L>, M> createPartialSegmentMergeIOConfig
   )
   {
@@ -1001,7 +1010,7 @@ public class ParallelIndexSupervisorTask extends AbstractBatchIndexTask implemen
     // Randomly shuffle partitionIds to evenly distribute partitions of potentially different sizes
     // This will be improved once we collect partition stats properly.
     // See PartitionStat in GeneratedPartitionsReport.
-    final List<Pair<Interval, Integer>> partitions = new ArrayList<>(partitionToLocations.keySet());
+    final List<Partition> partitions = new ArrayList<>(partitionToLocations.keySet());
     Collections.shuffle(partitions, ThreadLocalRandom.current());
     final List<M> assignedPartitionLocations = new ArrayList<>(numMergeTasks);
@@ -1010,7 +1019,7 @@ public class ParallelIndexSupervisorTask extends AbstractBatchIndexTask implemen
       final List<L> assignedToSameTask = partitions
           .subList(partitionBoundaries.lhs, partitionBoundaries.rhs)
           .stream()
-          .flatMap(intervalAndPartitionId -> partitionToLocations.get(intervalAndPartitionId).stream())
+          .flatMap(partition -> partitionToLocations.get(partition).stream())
           .collect(Collectors.toList());
       assignedPartitionLocations.add(createPartialSegmentMergeIOConfig.apply(assignedToSameTask));
     }
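For context on the surrounding code: once partitions carry their final ids, merge work is distributed by shuffling the partition keys and slicing the list into contiguous sublists, one per merge task. A rough sketch of that distribution, assuming a simple even split (the actual boundary computation, `partitionBoundaries` above, is produced elsewhere in this class and is not part of the sketch):

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

class MergeTaskAssignmentSketch
{
  // Shuffle, then split into `numTasks` contiguous slices so that partitions
  // of potentially different sizes are spread roughly evenly across tasks.
  static <T> List<List<T>> assign(List<T> partitions, int numTasks)
  {
    List<T> shuffled = new ArrayList<>(partitions);
    Collections.shuffle(shuffled, ThreadLocalRandom.current());

    List<List<T>> assignments = new ArrayList<>(numTasks);
    int base = shuffled.size() / numTasks;
    int remainder = shuffled.size() % numTasks;
    int from = 0;
    for (int i = 0; i < numTasks; i++) {
      int to = from + base + (i < remainder ? 1 : 0);
      assignments.add(new ArrayList<>(shuffled.subList(from, to)));
      from = to;
    }
    return assignments;
  }

  public static void main(String[] args)
  {
    // Five partitions across two merge tasks: slices of size 3 and 2
    System.out.println(assign(List.of("p0", "p1", "p2", "p3", "p4"), 2));
  }
}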
@@ -1631,4 +1640,91 @@ public class ParallelIndexSupervisorTask extends AbstractBatchIndexTask implemen
     return Response.ok(doGetLiveReports(full)).build();
   }
 
+  /**
+   * Represents a partition uniquely identified by an Interval and a bucketId.
+   *
+   * @see org.apache.druid.timeline.partition.BucketNumberedShardSpec
+   */
+  static class Partition
+  {
+    final Interval interval;
+    final int bucketId;
+
+    private static Partition fromStat(PartitionStat partitionStat)
+    {
+      return new Partition(partitionStat.getInterval(), partitionStat.getBucketId());
+    }
+
+    Partition(Interval interval, int bucketId)
+    {
+      this.interval = interval;
+      this.bucketId = bucketId;
+    }
+
+    public int getBucketId()
+    {
+      return bucketId;
+    }
+
+    public Interval getInterval()
+    {
+      return interval;
+    }
+
+    @Override
+    public boolean equals(Object o)
+    {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      Partition that = (Partition) o;
+      return getBucketId() == that.getBucketId()
+             && Objects.equals(getInterval(), that.getInterval());
+    }
+
+    @Override
+    public int hashCode()
+    {
+      return Objects.hash(getInterval(), getBucketId());
+    }
+
+    @Override
+    public String toString()
+    {
+      return "Partition{" +
+             "interval=" + interval +
+             ", bucketId=" + bucketId +
+             '}';
+    }
+  }
+
+  /**
+   * Encapsulates a {@link PartitionStat} and the subTaskId that generated it.
+   */
+  private static class PartitionReport
+  {
+    private final PartitionStat partitionStat;
+    private final String subTaskId;
+
+    PartitionReport(String subTaskId, PartitionStat partitionStat)
+    {
+      this.subTaskId = subTaskId;
+      this.partitionStat = partitionStat;
+    }
+
+    String getSubTaskId()
+    {
+      return subTaskId;
+    }
+
+    PartitionStat getPartitionStat()
+    {
+      return partitionStat;
+    }
+  }
 }
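Partition implements equals and hashCode but deliberately not Comparable: value equality is all the HashMap lookups need, while the (interval, bucketId) ordering is supplied as an external Comparator only where it matters (the TreeMap above). A tiny, self-contained sketch of the value-key behavior this relies on, with a String standing in for the Joda Interval (hypothetical names, not part of the patch):

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

class ValueKeySketch
{
  // Stand-in for Partition: equal fields imply equal, interchangeable keys
  static final class Key
  {
    final String interval;
    final int bucketId;

    Key(String interval, int bucketId)
    {
      this.interval = interval;
      this.bucketId = bucketId;
    }

    @Override
    public boolean equals(Object o)
    {
      if (this == o) {
        return true;
      }
      if (o == null || getClass() != o.getClass()) {
        return false;
      }
      Key that = (Key) o;
      return bucketId == that.bucketId && interval.equals(that.interval);
    }

    @Override
    public int hashCode()
    {
      return Objects.hash(interval, bucketId);
    }
  }

  public static void main(String[] args)
  {
    Map<Key, String> map = new HashMap<>();
    map.put(new Key("2022-01-01/2022-01-02", 3), "locations-of-bucket-3");
    // A distinct but equal instance hits the same entry
    System.out.println(map.get(new Key("2022-01-01/2022-01-02", 3)));
  }
}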

ParallelIndexSupervisorTaskTest.java

@@ -31,7 +31,6 @@ import org.apache.druid.indexer.partitions.HashedPartitionsSpec;
 import org.apache.druid.indexer.partitions.PartitionsSpec;
 import org.apache.druid.indexer.partitions.SingleDimensionPartitionsSpec;
 import org.apache.druid.java.util.common.Intervals;
-import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.segment.IndexSpec;
 import org.apache.druid.segment.data.CompressionFactory.LongEncodingStrategy;
 import org.apache.druid.segment.data.CompressionStrategy;
@@ -39,6 +38,7 @@ import org.apache.druid.segment.data.RoaringBitmapSerdeFactory;
 import org.apache.druid.segment.indexing.DataSchema;
 import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
 import org.apache.druid.timeline.partition.BuildingHashBasedNumberedShardSpec;
+import org.apache.druid.timeline.partition.DimensionRangeBucketShardSpec;
 import org.apache.druid.timeline.partition.HashPartitionFunction;
 import org.easymock.EasyMock;
 import org.hamcrest.Matchers;
@@ -54,8 +54,11 @@ import org.junit.runners.Parameterized;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
@@ -131,14 +134,14 @@ public class ParallelIndexSupervisorTaskTest
     );
   }
 
-  private static Map<Pair<Interval, Integer>, List<PartitionLocation>> createPartitionToLocations(
+  private static Map<ParallelIndexSupervisorTask.Partition, List<PartitionLocation>> createPartitionToLocations(
       int count,
       String partitionLocationType
   )
   {
     return IntStream.range(0, count).boxed().collect(
         Collectors.toMap(
-            i -> Pair.of(createInterval(i), i),
+            i -> new ParallelIndexSupervisorTask.Partition(createInterval(i), i),
             i -> Collections.singletonList(createPartitionLocation(i, partitionLocationType))
         )
     );
@@ -335,6 +338,88 @@ public class ParallelIndexSupervisorTaskTest
     Assert.assertFalse(ParallelIndexSupervisorTask.isParallelMode(inputSource, tuningConfig));
   }
 
+  @Test
+  public void test_getPartitionToLocations_ordersPartitionsCorrectly()
+  {
+    final Interval day1 = Intervals.of("2022-01-01/2022-01-02");
+    final Interval day2 = Intervals.of("2022-01-02/2022-01-03");
+
+    final String task1 = "task1";
+    final String task2 = "task2";
+
+    // Create task reports
+    Map<String, GeneratedPartitionsReport> taskIdToReport = new HashMap<>();
+    taskIdToReport.put(task1, new GeneratedPartitionsReport(task1, Arrays.asList(
+        createRangePartitionStat(day1, 1),
+        createRangePartitionStat(day2, 7),
+        createRangePartitionStat(day1, 0),
+        createRangePartitionStat(day2, 1)
+    )));
+    taskIdToReport.put(task2, new GeneratedPartitionsReport(task2, Arrays.asList(
+        createRangePartitionStat(day1, 4),
+        createRangePartitionStat(day1, 6),
+        createRangePartitionStat(day2, 1),
+        createRangePartitionStat(day1, 1)
+    )));
+
+    Map<ParallelIndexSupervisorTask.Partition, List<PartitionLocation>> partitionToLocations
+        = ParallelIndexSupervisorTask.getPartitionToLocations(taskIdToReport);
+    Assert.assertEquals(6, partitionToLocations.size());
+
+    // Verify that partitionIds are packed and in the same order as bucketIds
+    verifyPartitionIdAndLocations(day1, 0, partitionToLocations, 0, task1);
+    verifyPartitionIdAndLocations(day1, 1, partitionToLocations, 1, task1, task2);
+    verifyPartitionIdAndLocations(day1, 4, partitionToLocations, 2, task2);
+    verifyPartitionIdAndLocations(day1, 6, partitionToLocations, 3, task2);
+    verifyPartitionIdAndLocations(day2, 1, partitionToLocations, 0, task1, task2);
+    verifyPartitionIdAndLocations(day2, 7, partitionToLocations, 1, task1);
+  }
+
+  private PartitionStat createRangePartitionStat(Interval interval, int bucketId)
+  {
+    return new DeepStoragePartitionStat(
+        interval,
+        new DimensionRangeBucketShardSpec(bucketId, Arrays.asList("dim1", "dim2"), null, null),
+        new HashMap<>()
+    );
+  }
+
+  private void verifyPartitionIdAndLocations(
+      Interval interval,
+      int bucketId,
+      Map<ParallelIndexSupervisorTask.Partition, List<PartitionLocation>> partitionToLocations,
+      int expectedPartitionId,
+      String... expectedTaskIds
+  )
+  {
+    final ParallelIndexSupervisorTask.Partition partition
+        = new ParallelIndexSupervisorTask.Partition(interval, bucketId);
+    List<PartitionLocation> locations = partitionToLocations.get(partition);
+    Assert.assertEquals(expectedTaskIds.length, locations.size());
+
+    final Set<String> observedTaskIds = new HashSet<>();
+    for (PartitionLocation location : locations) {
+      Assert.assertEquals(bucketId, location.getBucketId());
+      Assert.assertEquals(interval, location.getInterval());
+      Assert.assertEquals(expectedPartitionId, location.getShardSpec().getPartitionNum());
+      observedTaskIds.add(location.getSubTaskId());
+    }
+
+    // Verify the taskIds of the locations
+    Assert.assertEquals(
+        new HashSet<>(Arrays.asList(expectedTaskIds)),
+        observedTaskIds
+    );
+  }
 }
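For reference, the property the new test pins down, that partitionIds within each interval are packed (0..n-1) and follow increasing bucketId, can be expressed as a small standalone check. A sketch under those assumptions (hypothetical class, not part of the patch):

import java.util.List;

class PackedIdCheckSketch
{
  // Asserts that the ids assigned to buckets (listed in increasing bucketId
  // order within one interval) are exactly 0..n-1, i.e. packed and order-preserving.
  static void assertPackedAndOrdered(List<Integer> partitionIdsByBucketOrder)
  {
    for (int i = 0; i < partitionIdsByBucketOrder.size(); i++) {
      if (partitionIdsByBucketOrder.get(i) != i) {
        throw new AssertionError("partitionIds not packed in bucket order: " + partitionIdsByBucketOrder);
      }
    }
  }

  public static void main(String[] args)
  {
    // day1 buckets 0, 1, 4, 6 -> ids 0, 1, 2, 3 (as in the test above)
    assertPackedAndOrdered(List.of(0, 1, 2, 3));
    // day2 buckets 1, 7 -> ids 0, 1
    assertPackedAndOrdered(List.of(0, 1));
    System.out.println("packed-id invariant holds");
  }
}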