Fix superbatch merge last partition boundaries (#9448)

* Fix superbatch merge last partition boundaries

A bug in computing the boundaries of the last parallel merge partition could
cause an IndexOutOfBoundsException, or a precondition failure when the last
partition ended up empty.

* Improve comments and tests
Chi Cao Minh 2020-03-04 10:35:21 -08:00 committed by GitHub
parent 9466ac7c9b
commit 4ed83f6af6
2 changed files with 174 additions and 12 deletions
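
To make the failure modes concrete, here is a minimal, hypothetical sketch of the pre-fix arithmetic (the class name and driver are illustrative, not Druid code): each of the first N - 1 merge tasks was assigned round(total / N) partitions and the last task took whatever remained, so rounding up makes a sublist boundary overshoot the partition list.

// Hypothetical sketch (not Druid code) of the pre-fix assignment arithmetic.
public class OldMergePartitioningSketch
{
  public static void main(String[] args)
  {
    int numMergeTasks = 10;
    for (int total : new int[]{20, 24, 25, 27}) {
      int perTask = (int) Math.round(total / (double) numMergeTasks);
      int lastStart = (numMergeTasks - 1) * perTask;
      // subList end indexes must not exceed the list size: lastStart > total (e.g. total = 25)
      // leads to an IndexOutOfBoundsException, while lastStart == total (e.g. total = 27) leaves
      // the last task with an empty partition list, tripping a non-empty precondition.
      System.out.printf("total=%d perTask=%d lastStart=%d%n", total, perTask, lastStart);
    }
  }
}

Running it shows lastStart of 27 for both 25 and 27 total partitions, matching the two failure modes the fix and the new test below cover.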

ParallelIndexSupervisorTask.java

@@ -742,7 +742,8 @@ public class ParallelIndexSupervisorTask extends AbstractBatchIndexTask implemen
     );
   }
 
-  private static <M extends PartialSegmentMergeIOConfig, L extends PartitionLocation> List<M> createMergeIOConfigs(
+  @VisibleForTesting
+  static <M extends PartialSegmentMergeIOConfig, L extends PartitionLocation> List<M> createMergeIOConfigs(
       int totalNumMergeTasks,
       Map<Pair<Interval, Integer>, List<L>> partitionToLocations,
       Function<List<L>, M> createPartialSegmentMergeIOConfig
@@ -760,29 +761,43 @@ public class ParallelIndexSupervisorTask extends AbstractBatchIndexTask implemen
     // See PartitionStat in GeneratedPartitionsReport.
     final List<Pair<Interval, Integer>> partitions = new ArrayList<>(partitionToLocations.keySet());
     Collections.shuffle(partitions, ThreadLocalRandom.current());
 
-    final int numPartitionsPerTask = (int) Math.round(partitions.size() / (double) numMergeTasks);
     final List<M> assignedPartitionLocations = new ArrayList<>(numMergeTasks);
-    for (int i = 0; i < numMergeTasks - 1; i++) {
+    for (int i = 0; i < numMergeTasks; i++) {
+      Pair<Integer, Integer> partitionBoundaries = getPartitionBoundaries(i, partitions.size(), numMergeTasks);
       final List<L> assignedToSameTask = partitions
-          .subList(i * numPartitionsPerTask, (i + 1) * numPartitionsPerTask)
+          .subList(partitionBoundaries.lhs, partitionBoundaries.rhs)
           .stream()
           .flatMap(intervalAndPartitionId -> partitionToLocations.get(intervalAndPartitionId).stream())
           .collect(Collectors.toList());
       assignedPartitionLocations.add(createPartialSegmentMergeIOConfig.apply(assignedToSameTask));
     }
 
-    // The last task is assigned all remaining partitions.
-    final List<L> assignedToSameTask = partitions
-        .subList((numMergeTasks - 1) * numPartitionsPerTask, partitions.size())
-        .stream()
-        .flatMap(intervalAndPartitionId -> partitionToLocations.get(intervalAndPartitionId).stream())
-        .collect(Collectors.toList());
-    assignedPartitionLocations.add(createPartialSegmentMergeIOConfig.apply(assignedToSameTask));
-
     return assignedPartitionLocations;
   }
 
+  /**
+   * Partition items into as evenly-sized splits as possible.
+   *
+   * @param index  index of partition
+   * @param total  number of items to partition
+   * @param splits number of desired partitions
+   *
+   * @return partition range: [lhs, rhs)
+   */
+  private static Pair<Integer, Integer> getPartitionBoundaries(int index, int total, int splits)
+  {
+    int chunk = total / splits;
+    int remainder = total % splits;
+
+    // Distribute the remainder across the first few partitions. For example total=8 and splits=5, will give partitions
+    // of sizes (starting from i=0): 2, 2, 2, 1, 1
+    int start = index * chunk + (index < remainder ? index : remainder);
+    int stop = start + chunk + (index < remainder ? 1 : 0);
+
+    return Pair.of(start, stop);
+  }
+
   private static void publishSegments(TaskToolbox toolbox, Map<String, PushedSegmentsReport> reportsMap)
       throws IOException
   {
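
For reference, a minimal sketch of the boundary math added above, re-implemented outside of Druid (the class name, the int[] return type, and the main driver are illustrative assumptions):

// A minimal sketch of the new boundary math, not the Druid class itself.
public class PartitionBoundariesSketch
{
  // Returns [start, stop) for split `index` of `total` items divided into `splits` chunks;
  // the remainder is spread over the first (total % splits) splits.
  static int[] getPartitionBoundaries(int index, int total, int splits)
  {
    int chunk = total / splits;
    int remainder = total % splits;
    int start = index * chunk + (index < remainder ? index : remainder);
    int stop = start + chunk + (index < remainder ? 1 : 0);
    return new int[]{start, stop};
  }

  public static void main(String[] args)
  {
    // The Javadoc example: total=8, splits=5 gives sizes 2, 2, 2, 1, 1.
    for (int i = 0; i < 5; i++) {
      int[] b = getPartitionBoundaries(i, 8, 5);
      System.out.printf("split %d -> [%d, %d)%n", i, b[0], b[1]);
    }
    // A previously failing case: 25 partitions across 10 tasks. The first five ranges have
    // size 3, the rest size 2, and the final range ends exactly at 25.
    for (int i = 0; i < 10; i++) {
      int[] b = getPartitionBoundaries(i, 25, 10);
      System.out.printf("task %d -> [%d, %d)%n", i, b[0], b[1]);
    }
  }
}

Because the remainder only ever adds one extra item to the earliest splits, the sizes differ by at most one and the final stop always equals total, which is what the new test's sizesPartitionsEvenly and handlesLastPartitionCorrectly cases check.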

ParallelIndexSupervisorTaskTest.java (new file)

@@ -0,0 +1,147 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.indexing.common.task.batch.parallel;

import com.google.common.collect.Ordering;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.Pair;
import org.hamcrest.Matchers;
import org.joda.time.Interval;
import org.junit.Assert;
import org.junit.Test;
import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

@RunWith(Enclosed.class)
public class ParallelIndexSupervisorTaskTest
{
  @RunWith(Parameterized.class)
  public static class CreateMergeIoConfigsTest
  {
    private static final int TOTAL_NUM_MERGE_TASKS = 10;
    private static final Function<List<HashPartitionLocation>, PartialHashSegmentMergeIOConfig>
        CREATE_PARTIAL_SEGMENT_MERGE_IO_CONFIG = PartialHashSegmentMergeIOConfig::new;

    @Parameterized.Parameters(name = "count = {0}")
    public static Iterable<? extends Object> data()
    {
      // different scenarios for last (index = 10 - 1 = 9) partition:
      return Arrays.asList(
          20, // even partitions per task: round(20 / 10) * (10 - 1) = 2 * 9 = 18 < 20
          24, // round down: round(24 / 10) * (10 - 1) = 2 * 9 = 18 < 24
          25, // round up to greater: round(25 / 10) * (10 - 1) = 3 * 9 = 27 > 25 (index out of bounds)
          27  // round up to equal: round(27 / 10) * (10 - 1) = 3 * 9 = 27 == 27 (empty partition)
      );
    }

    @Parameterized.Parameter
    public int count;

    @Test
    public void handlesLastPartitionCorrectly()
    {
      List<PartialHashSegmentMergeIOConfig> assignedPartitionLocation = createMergeIOConfigs();
      assertNoMissingPartitions(count, assignedPartitionLocation);
    }

    @Test
    public void sizesPartitionsEvenly()
    {
      List<PartialHashSegmentMergeIOConfig> assignedPartitionLocation = createMergeIOConfigs();
      List<Integer> actualPartitionSizes = assignedPartitionLocation.stream()
          .map(i -> i.getPartitionLocations().size())
          .collect(Collectors.toList());
      List<Integer> sortedPartitionSizes = Ordering.natural().sortedCopy(actualPartitionSizes);
      int minPartitionSize = sortedPartitionSizes.get(0);
      int maxPartitionSize = sortedPartitionSizes.get(sortedPartitionSizes.size() - 1);
      int partitionSizeRange = maxPartitionSize - minPartitionSize;

      Assert.assertThat(
          "partition sizes = " + actualPartitionSizes,
          partitionSizeRange,
          Matchers.is(Matchers.both(Matchers.greaterThanOrEqualTo(0)).and(Matchers.lessThanOrEqualTo(1)))
      );
    }

    private List<PartialHashSegmentMergeIOConfig> createMergeIOConfigs()
    {
      return ParallelIndexSupervisorTask.createMergeIOConfigs(
          TOTAL_NUM_MERGE_TASKS,
          createPartitionToLocations(count),
          CREATE_PARTIAL_SEGMENT_MERGE_IO_CONFIG
      );
    }

    private static Map<Pair<Interval, Integer>, List<HashPartitionLocation>> createPartitionToLocations(int count)
    {
      return IntStream.range(0, count).boxed().collect(
          Collectors.toMap(
              i -> Pair.of(createInterval(i), i),
              i -> Collections.singletonList(createPartitionLocation(i))
          )
      );
    }

    private static HashPartitionLocation createPartitionLocation(int id)
    {
      return new HashPartitionLocation(
          "host",
          0,
          false,
          "subTaskId",
          createInterval(id),
          id
      );
    }

    private static Interval createInterval(int id)
    {
      return Intervals.utc(id, id + 1);
    }

    private static void assertNoMissingPartitions(
        int count,
        List<PartialHashSegmentMergeIOConfig> assignedPartitionLocation
    )
    {
      List<Integer> expectedIds = IntStream.range(0, count).boxed().collect(Collectors.toList());
      List<Integer> actualIds = assignedPartitionLocation.stream()
          .flatMap(
              i -> i.getPartitionLocations()
                    .stream()
                    .map(HashPartitionLocation::getPartitionId)
          )
          .sorted()
          .collect(Collectors.toList());

      Assert.assertEquals(expectedIds, actualIds);
    }
  }
}