diff --git a/server/src/test/java/org/apache/druid/server/coordinator/balancer/ReservoirSegmentSamplerTest.java b/server/src/test/java/org/apache/druid/server/coordinator/balancer/ReservoirSegmentSamplerTest.java
index 25b18015468..e6387a5a343 100644
--- a/server/src/test/java/org/apache/druid/server/coordinator/balancer/ReservoirSegmentSamplerTest.java
+++ b/server/src/test/java/org/apache/druid/server/coordinator/balancer/ReservoirSegmentSamplerTest.java
@@ -19,6 +19,7 @@
 
 package org.apache.druid.server.coordinator.balancer;
 
+import com.google.common.collect.Lists;
 import org.apache.druid.client.DruidServer;
 import org.apache.druid.java.util.common.granularity.Granularities;
 import org.apache.druid.server.coordination.ServerType;
@@ -31,16 +32,16 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
+import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.function.Function;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 public class ReservoirSegmentSamplerTest
 {
@@ -55,9 +56,6 @@ public class ReservoirSegmentSamplerTest
                         .withNumPartitions(10)
                         .eachOfSizeInMb(100);
 
-  private final Function<ServerHolder, Collection<DataSegment>> GET_SERVED_SEGMENTS
-      = serverHolder -> serverHolder.getServer().iterateAllSegments();
-
   @Before
   public void setUp()
   {
@@ -80,7 +78,7 @@ public class ReservoirSegmentSamplerTest
       // due to the pseudo-randomness of this method, we may not select a segment every single time no matter what.
       segmentCountMap.compute(
           ReservoirSegmentSampler
-              .pickMovableSegmentsFrom(servers, 1, GET_SERVED_SEGMENTS, Collections.emptySet())
+              .pickMovableSegmentsFrom(servers, 1, ServerHolder::getServedSegments, Collections.emptySet())
               .get(0).getSegment(),
           (segment, count) -> count == null ? 1 : count + 1
       );
@@ -151,9 +149,16 @@ public class ReservoirSegmentSamplerTest
     Assert.assertTrue(pickedSegments.containsAll(loadingSegments));
 
     // Pick only loaded segments
-    pickedSegments = ReservoirSegmentSampler
-        .pickMovableSegmentsFrom(Arrays.asList(server1, server2), 10, GET_SERVED_SEGMENTS, Collections.emptySet())
-        .stream().map(BalancerSegmentHolder::getSegment).collect(Collectors.toSet());
+    List<BalancerSegmentHolder> pickedHolders = ReservoirSegmentSampler.pickMovableSegmentsFrom(
+        Arrays.asList(server1, server2),
+        10,
+        ServerHolder::getServedSegments,
+        Collections.emptySet()
+    );
+    pickedSegments = pickedHolders
+        .stream()
+        .map(BalancerSegmentHolder::getSegment)
+        .collect(Collectors.toSet());
 
     // Verify that only loaded segments are picked
     Assert.assertEquals(loadedSegments.size(), pickedSegments.size());
@@ -177,7 +182,7 @@ public class ReservoirSegmentSamplerTest
     List<BalancerSegmentHolder> pickedSegments = ReservoirSegmentSampler.pickMovableSegmentsFrom(
         Arrays.asList(historical, broker),
         10,
-        GET_SERVED_SEGMENTS,
+        ServerHolder::getServedSegments,
         Collections.emptySet()
     );
 
@@ -206,8 +211,12 @@ public class ReservoirSegmentSamplerTest
     );
 
     // Try to pick all the segments on the servers
-    List<BalancerSegmentHolder> pickedSegments = ReservoirSegmentSampler
-        .pickMovableSegmentsFrom(servers, 10, GET_SERVED_SEGMENTS, Collections.singleton(broadcastDatasource));
+    List<BalancerSegmentHolder> pickedSegments = ReservoirSegmentSampler.pickMovableSegmentsFrom(
+        servers,
+        10,
+        ServerHolder::getServedSegments,
+        Collections.singleton(broadcastDatasource)
+    );
 
     // Verify that none of the broadcast segments are picked
     Assert.assertEquals(2, pickedSegments.size());
@@ -216,21 +225,83 @@ public class ReservoirSegmentSamplerTest
     }
   }
 
-  @Test(timeout = 60_000)
-  public void testNumberOfIterationsToCycleThroughAllSegments()
+  @Test
+  public void testSegmentsFromAllServersAreEquallyLikelyToBePicked()
   {
-    // The number of runs required for each sample percentage
+    // Create 4 servers, each having an equal number of segments
+    final List<List<DataSegment>> subSegmentLists = Lists.partition(segments, segments.size() / 4);
+    final List<ServerHolder> servers = IntStream.range(0, 4).mapToObj(
+        i -> createHistorical("server_" + i, subSegmentLists.get(i).toArray(new DataSegment[0]))
+    ).collect(Collectors.toList());
+
+    // Get the distribution of picked segments for different sample percentages
+    final int[] samplePercentages = {50, 20, 10, 5};
+    for (int samplePercentage : samplePercentages) {
+      final int[] numSegmentsPickedFromServer
+          = pickSegmentsAndGetPickedCountPerServer(servers, samplePercentage, 50);
+
+      final int totalSegmentsPicked = Arrays.stream(numSegmentsPickedFromServer).sum();
+
+      // Number of segments picked from each server is ~25% of total
+      final double expectedPickedSegments = totalSegmentsPicked * 0.25;
+      final double error = totalSegmentsPicked * 0.02;
+      for (int pickedSegments : numSegmentsPickedFromServer) {
+        Assert.assertEquals(expectedPickedSegments, pickedSegments, error);
+      }
+    }
+  }
+
+  @Test
+  public void testSegmentsFromMorePopulousServerAreMoreLikelyToBePicked()
+  {
+    // Create 4 servers, first one having twice as many segments as the rest
+    final List<List<DataSegment>> subSegmentLists = Lists.partition(segments, segments.size() / 5);
+
+    final List<ServerHolder> servers = new ArrayList<>();
+    List<DataSegment> segmentsForServer0 = new ArrayList<>(subSegmentLists.get(0));
+    segmentsForServer0.addAll(subSegmentLists.get(1));
+    servers.add(createHistorical("server_" + 0, segmentsForServer0));
+
+    IntStream.range(1, 4).mapToObj(
+        i -> createHistorical("server_" + i, subSegmentLists.get(i + 1))
+    ).forEach(servers::add);
+
+    final int[] samplePercentages = {50, 20, 10, 5};
+    for (int samplePercentage : samplePercentages) {
+      final int[] numSegmentsPickedFromServer
+          = pickSegmentsAndGetPickedCountPerServer(servers, samplePercentage, 50);
+
+      final int totalSegmentsPicked = Arrays.stream(numSegmentsPickedFromServer).sum();
+
+      // Number of segments picked from server0 are ~40% of total and
+      // number of segments picked from other servers are each ~20% of total
+      double error = totalSegmentsPicked * 0.02;
+      Assert.assertEquals(totalSegmentsPicked * 0.40, numSegmentsPickedFromServer[0], error);
+
+      for (int serverId = 1; serverId < servers.size(); ++serverId) {
+        Assert.assertEquals(totalSegmentsPicked * 0.20, numSegmentsPickedFromServer[serverId], error);
+      }
+    }
+  }
+
+  @Test(timeout = 60_000)
+  public void testNumberOfSamplingsRequiredToPickAllSegments()
+  {
+    // The number of sampling iterations required for each sample percentage
     // remains more or less fixed, even with a larger number of segments
     final int[] samplePercentages = {100, 50, 10, 5, 1};
     final int[] expectedIterations = {1, 20, 100, 200, 1000};
 
     final int[] totalObservedIterations = new int[5];
+
+    // For every sample percentage, count the minimum number of required samplings
     for (int i = 0; i < 50; ++i) {
       for (int j = 0; j < samplePercentages.length; ++j) {
-        totalObservedIterations[j] += countMinRunsWithSamplePercent(samplePercentages[j]);
+        totalObservedIterations[j] += countMinRunsToPickAllSegments(samplePercentages[j]);
       }
     }
 
+    // Compute the avg value from the 50 observations for each sample percentage
     for (int j = 0; j < samplePercentages.length; ++j) {
       double avgObservedIterations = totalObservedIterations[j] / 50.0;
       Assert.assertTrue(avgObservedIterations <= expectedIterations[j]);
@@ -244,7 +315,7 @@ public class ReservoirSegmentSamplerTest
    * <p>
    * {@code k = sampleSize = totalNumSegments * samplePercentage}
    */
-  private int countMinRunsWithSamplePercent(int samplePercentage)
+  private int countMinRunsToPickAllSegments(int samplePercentage)
   {
     final int numSegments = segments.size();
     final List<ServerHolder> servers = Arrays.asList(
@@ -259,7 +330,7 @@ public class ReservoirSegmentSamplerTest
     int numIterations = 1;
     for (; numIterations < 10000; ++numIterations) {
       ReservoirSegmentSampler
-          .pickMovableSegmentsFrom(servers, sampleSize, GET_SERVED_SEGMENTS, Collections.emptySet())
+          .pickMovableSegmentsFrom(servers, sampleSize, ServerHolder::getServedSegments, Collections.emptySet())
           .forEach(holder -> pickedSegments.add(holder.getSegment()));
 
       if (pickedSegments.size() >= numSegments) {
@@ -270,6 +341,38 @@ public class ReservoirSegmentSamplerTest
     return numIterations;
   }
 
+  private int[] pickSegmentsAndGetPickedCountPerServer(
+      List<ServerHolder> servers,
+      int samplePercentage,
+      int numIterations
+  )
+  {
+    final int numSegmentsToPick = (int) (segments.size() * samplePercentage / 100.0);
+    final int[] numSegmentsPickedFromServer = new int[servers.size()];
+
+    for (int i = 0; i < numIterations; ++i) {
+      List<BalancerSegmentHolder> pickedSegments = ReservoirSegmentSampler.pickMovableSegmentsFrom(
+          servers,
+          numSegmentsToPick,
+          ServerHolder::getServedSegments,
+          Collections.emptySet()
+      );
+
+      // Get the number of segments picked from each server
+      for (BalancerSegmentHolder pickedSegment : pickedSegments) {
+        int serverIndex = servers.indexOf(pickedSegment.getServer());
+        numSegmentsPickedFromServer[serverIndex]++;
+      }
+    }
+
+    return numSegmentsPickedFromServer;
+  }
+
+  private ServerHolder createHistorical(String serverName, List<DataSegment> loadedSegments)
+  {
+    return createHistorical(serverName, loadedSegments.toArray(new DataSegment[0]));
+  }
+
   private ServerHolder createHistorical(String serverName, DataSegment... loadedSegments)
   {
     final DruidServer server =