Update ClusterByStatisticsCollectorImpl to use bytes instead of keys (#12998)

* Update clusterByStatistics to use bytes instead of keys

* Address review comments

* Resolve checkstyle

* Increase test coverage

* Update test

* Update thresholds

* Update retained keys function

* Update docs

* Fix spelling
This commit is contained in:
Adarsh Sanjeev 2022-10-03 12:08:23 +05:30 committed by GitHub
parent ebfe1c0c90
commit 92d2633ae6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 261 additions and 90 deletions

View File

@ -252,6 +252,9 @@ Worker tasks use both JVM heap memory and off-heap ("direct") memory.
On Peons launched by Middle Managers, the bulk of the JVM heap (75%) is split up into two bundles of equal size: one On Peons launched by Middle Managers, the bulk of the JVM heap (75%) is split up into two bundles of equal size: one
processor bundle and one worker bundle. Each one comprises 37.5% of the available JVM heap. processor bundle and one worker bundle. Each one comprises 37.5% of the available JVM heap.
Depending on the type of query, each worker and controller task can use a sketch for generating partition boundaries.
Each sketch uses at most approximately 300 MB.
The processor memory bundle is used for query processing and segment generation. Each processor bundle must also The processor memory bundle is used for query processing and segment generation. Each processor bundle must also
provides space to buffer I/O between stages. Specifically, each downstream stage requires 1 MB of buffer space for each provides space to buffer I/O between stages. Specifically, each downstream stage requires 1 MB of buffer space for each
upstream worker. For example, if you have 100 workers running in stage 0, and stage 1 reads from stage 0, then each upstream worker. For example, if you have 100 workers running in stage 0, and stage 1 reads from stage 0, then each

View File

@ -74,7 +74,7 @@ import java.util.function.Supplier;
*/ */
public class StageDefinition public class StageDefinition
{ {
private static final int PARTITION_STATS_MAX_KEYS = 2 << 15; // Avoid immediate downsample of single-bucket collectors private static final int PARTITION_STATS_MAX_BYTES = 300_000_000; // Avoid immediate downsample of single-bucket collectors
private static final int PARTITION_STATS_MAX_BUCKETS = 5_000; // Limit for TooManyBuckets private static final int PARTITION_STATS_MAX_BUCKETS = 5_000; // Limit for TooManyBuckets
private static final int MAX_PARTITIONS = 25_000; // Limit for TooManyPartitions private static final int MAX_PARTITIONS = 25_000; // Limit for TooManyPartitions
@ -289,7 +289,7 @@ public class StageDefinition
return ClusterByStatisticsCollectorImpl.create( return ClusterByStatisticsCollectorImpl.create(
shuffleSpec.getClusterBy(), shuffleSpec.getClusterBy(),
signature, signature,
PARTITION_STATS_MAX_KEYS, PARTITION_STATS_MAX_BYTES,
PARTITION_STATS_MAX_BUCKETS, PARTITION_STATS_MAX_BUCKETS,
shuffleSpec.doesAggregateByClusterKey(), shuffleSpec.doesAggregateByClusterKey(),
shuffleCheckHasMultipleValues shuffleCheckHasMultipleValues

View File

@ -56,17 +56,15 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
private final boolean[] hasMultipleValues; private final boolean[] hasMultipleValues;
// This can be reworked to accommodate maxSize instead of maxRetainedKeys to account for the skewness in the size of hte private final int maxRetainedBytes;
// keys depending on the datasource
private final int maxRetainedKeys;
private final int maxBuckets; private final int maxBuckets;
private int totalRetainedKeys; private double totalRetainedBytes;
private ClusterByStatisticsCollectorImpl( private ClusterByStatisticsCollectorImpl(
final ClusterBy clusterBy, final ClusterBy clusterBy,
final RowKeyReader keyReader, final RowKeyReader keyReader,
final KeyCollectorFactory<?, ?> keyCollectorFactory, final KeyCollectorFactory<?, ?> keyCollectorFactory,
final int maxRetainedKeys, final int maxRetainedBytes,
final int maxBuckets, final int maxBuckets,
final boolean checkHasMultipleValues final boolean checkHasMultipleValues
) )
@ -74,21 +72,21 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
this.clusterBy = clusterBy; this.clusterBy = clusterBy;
this.keyReader = keyReader; this.keyReader = keyReader;
this.keyCollectorFactory = keyCollectorFactory; this.keyCollectorFactory = keyCollectorFactory;
this.maxRetainedKeys = maxRetainedKeys; this.maxRetainedBytes = maxRetainedBytes;
this.buckets = new TreeMap<>(clusterBy.bucketComparator()); this.buckets = new TreeMap<>(clusterBy.bucketComparator());
this.maxBuckets = maxBuckets; this.maxBuckets = maxBuckets;
this.checkHasMultipleValues = checkHasMultipleValues; this.checkHasMultipleValues = checkHasMultipleValues;
this.hasMultipleValues = checkHasMultipleValues ? new boolean[clusterBy.getColumns().size()] : null; this.hasMultipleValues = checkHasMultipleValues ? new boolean[clusterBy.getColumns().size()] : null;
if (maxBuckets > maxRetainedKeys) { if (maxBuckets > maxRetainedBytes) {
throw new IAE("maxBuckets[%s] cannot be larger than maxRetainedKeys[%s]", maxBuckets, maxRetainedKeys); throw new IAE("maxBuckets[%s] cannot be larger than maxRetainedBytes[%s]", maxBuckets, maxRetainedBytes);
} }
} }
public static ClusterByStatisticsCollector create( public static ClusterByStatisticsCollector create(
final ClusterBy clusterBy, final ClusterBy clusterBy,
final RowSignature signature, final RowSignature signature,
final int maxRetainedKeys, final int maxRetainedBytes,
final int maxBuckets, final int maxBuckets,
final boolean aggregate, final boolean aggregate,
final boolean checkHasMultipleValues final boolean checkHasMultipleValues
@ -101,7 +99,7 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
clusterBy, clusterBy,
keyReader, keyReader,
keyCollectorFactory, keyCollectorFactory,
maxRetainedKeys, maxRetainedBytes,
maxBuckets, maxBuckets,
checkHasMultipleValues checkHasMultipleValues
); );
@ -126,8 +124,8 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
bucketHolder.keyCollector.add(key, weight); bucketHolder.keyCollector.add(key, weight);
totalRetainedKeys += bucketHolder.updateRetainedKeys(); totalRetainedBytes += bucketHolder.updateRetainedBytes();
if (totalRetainedKeys > maxRetainedKeys) { if (totalRetainedBytes > maxRetainedBytes) {
downSample(); downSample();
} }
@ -147,15 +145,15 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
//noinspection rawtypes, unchecked //noinspection rawtypes, unchecked
((KeyCollector) bucketHolder.keyCollector).addAll(otherBucketEntry.getValue().keyCollector); ((KeyCollector) bucketHolder.keyCollector).addAll(otherBucketEntry.getValue().keyCollector);
totalRetainedKeys += bucketHolder.updateRetainedKeys(); totalRetainedBytes += bucketHolder.updateRetainedBytes();
if (totalRetainedKeys > maxRetainedKeys) { if (totalRetainedBytes > maxRetainedBytes) {
downSample(); downSample();
} }
} }
if (checkHasMultipleValues) { if (checkHasMultipleValues) {
for (int i = 0; i < clusterBy.getColumns().size(); i++) { for (int i = 0; i < clusterBy.getColumns().size(); i++) {
hasMultipleValues[i] |= that.hasMultipleValues[i]; hasMultipleValues[i] = hasMultipleValues[i] || that.hasMultipleValues[i];
} }
} }
} else { } else {
@ -178,8 +176,8 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
//noinspection rawtypes, unchecked //noinspection rawtypes, unchecked
((KeyCollector) bucketHolder.keyCollector).addAll(otherKeyCollector); ((KeyCollector) bucketHolder.keyCollector).addAll(otherKeyCollector);
totalRetainedKeys += bucketHolder.updateRetainedKeys(); totalRetainedBytes += bucketHolder.updateRetainedBytes();
if (totalRetainedKeys > maxRetainedKeys) { if (totalRetainedBytes > maxRetainedBytes) {
downSample(); downSample();
} }
} }
@ -221,7 +219,7 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
public ClusterByStatisticsCollector clear() public ClusterByStatisticsCollector clear()
{ {
buckets.clear(); buckets.clear();
totalRetainedKeys = 0; totalRetainedBytes = 0;
return this; return this;
} }
@ -232,7 +230,7 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
throw new IAE("Target weight must be positive"); throw new IAE("Target weight must be positive");
} }
assertRetainedKeyCountsAreTrackedCorrectly(); assertRetainedByteCountsAreTrackedCorrectly();
if (buckets.isEmpty()) { if (buckets.isEmpty()) {
return ClusterByPartitions.oneUniversalPartition(); return ClusterByPartitions.oneUniversalPartition();
@ -315,7 +313,7 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
@Override @Override
public ClusterByStatisticsSnapshot snapshot() public ClusterByStatisticsSnapshot snapshot()
{ {
assertRetainedKeyCountsAreTrackedCorrectly(); assertRetainedByteCountsAreTrackedCorrectly();
final List<ClusterByStatisticsSnapshot.Bucket> bucketSnapshots = new ArrayList<>(); final List<ClusterByStatisticsSnapshot.Bucket> bucketSnapshots = new ArrayList<>();
@ -365,20 +363,20 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
} }
/** /**
* Reduce the number of retained keys by about half, if possible. May reduce by less than that, or keep the * Reduce the number of retained bytes by about half, if possible. May reduce by less than that, or keep the
* number the same, if downsampling is not possible. (For example: downsampling is not possible if all buckets * number the same, if downsampling is not possible. (For example: downsampling is not possible if all buckets
* have been downsampled all the way to one key each.) * have been downsampled all the way to one key each.)
*/ */
private void downSample() private void downSample()
{ {
int newTotalRetainedKeys = totalRetainedKeys; double newTotalRetainedBytes = totalRetainedBytes;
final int targetTotalRetainedKeys = totalRetainedKeys / 2; final double targetTotalRetainedBytes = totalRetainedBytes / 2;
final List<BucketHolder> sortedHolders = new ArrayList<>(buckets.size()); final List<BucketHolder> sortedHolders = new ArrayList<>(buckets.size());
// Only consider holders with more than one retained key. Holders with a single retained key cannot be downsampled. // Only consider holders with more than one retained key. Holders with a single retained key cannot be downsampled.
for (final BucketHolder holder : buckets.values()) { for (final BucketHolder holder : buckets.values()) {
if (holder.retainedKeys > 1) { if (holder.keyCollector.estimatedRetainedKeys() > 1) {
sortedHolders.add(holder); sortedHolders.add(holder);
} }
} }
@ -386,54 +384,54 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
// Downsample least-dense buckets first. (They're less likely to need high resolution.) // Downsample least-dense buckets first. (They're less likely to need high resolution.)
sortedHolders.sort( sortedHolders.sort(
Comparator.comparing((BucketHolder holder) -> Comparator.comparing((BucketHolder holder) ->
(double) holder.keyCollector.estimatedTotalWeight() / holder.retainedKeys) (double) holder.keyCollector.estimatedTotalWeight() / holder.keyCollector.estimatedRetainedKeys())
); );
int i = 0; int i = 0;
while (i < sortedHolders.size() && newTotalRetainedKeys > targetTotalRetainedKeys) { while (i < sortedHolders.size() && newTotalRetainedBytes > targetTotalRetainedBytes) {
final BucketHolder bucketHolder = sortedHolders.get(i); final BucketHolder bucketHolder = sortedHolders.get(i);
// Ignore false return, because we wrap all collectors in DelegateOrMinKeyCollector and can be assured that // Ignore false return, because we wrap all collectors in DelegateOrMinKeyCollector and can be assured that
// it will downsample all the way to one if needed. Can't do better than that. // it will downsample all the way to one if needed. Can't do better than that.
bucketHolder.keyCollector.downSample(); bucketHolder.keyCollector.downSample();
newTotalRetainedKeys += bucketHolder.updateRetainedKeys(); newTotalRetainedBytes += bucketHolder.updateRetainedBytes();
if (i == sortedHolders.size() - 1 || sortedHolders.get(i + 1).retainedKeys > bucketHolder.retainedKeys) { if (i == sortedHolders.size() - 1 || sortedHolders.get(i + 1).retainedBytes > bucketHolder.retainedBytes) {
i++; i++;
} }
} }
totalRetainedKeys = newTotalRetainedKeys; totalRetainedBytes = newTotalRetainedBytes;
} }
private void assertRetainedKeyCountsAreTrackedCorrectly() private void assertRetainedByteCountsAreTrackedCorrectly()
{ {
// Check cached value of retainedKeys in each holder. // Check cached value of retainedKeys in each holder.
assert buckets.values() assert buckets.values()
.stream() .stream()
.allMatch(holder -> holder.retainedKeys == holder.keyCollector.estimatedRetainedKeys()); .allMatch(holder -> holder.retainedBytes == holder.keyCollector.estimatedRetainedBytes());
// Check cached value of totalRetainedKeys. // Check cached value of totalRetainedBytes.
assert totalRetainedKeys == assert totalRetainedBytes ==
buckets.values().stream().mapToInt(holder -> holder.keyCollector.estimatedRetainedKeys()).sum(); buckets.values().stream().mapToDouble(holder -> holder.keyCollector.estimatedRetainedBytes()).sum();
} }
private static class BucketHolder private static class BucketHolder
{ {
private final KeyCollector<?> keyCollector; private final KeyCollector<?> keyCollector;
private int retainedKeys; private double retainedBytes;
public BucketHolder(final KeyCollector<?> keyCollector) public BucketHolder(final KeyCollector<?> keyCollector)
{ {
this.keyCollector = keyCollector; this.keyCollector = keyCollector;
this.retainedKeys = keyCollector.estimatedRetainedKeys(); this.retainedBytes = keyCollector.estimatedRetainedBytes();
} }
public int updateRetainedKeys() public double updateRetainedBytes()
{ {
final int newRetainedKeys = keyCollector.estimatedRetainedKeys(); final double newRetainedBytes = keyCollector.estimatedRetainedBytes();
final int difference = newRetainedKeys - retainedKeys; final double difference = newRetainedBytes - retainedBytes;
retainedKeys = newRetainedKeys; retainedBytes = newRetainedBytes;
return difference; return difference;
} }
} }

View File

@ -127,6 +127,16 @@ public class DelegateOrMinKeyCollector<TDelegate extends KeyCollector<TDelegate>
} }
} }
@Override
public double estimatedRetainedBytes()
{
if (delegate != null) {
return delegate.estimatedRetainedBytes();
} else {
return minKey != null ? minKey.getNumberOfBytes() : 0;
}
}
@Override @Override
public boolean downSample() public boolean downSample()
{ {

View File

@ -43,8 +43,8 @@ import java.util.Map;
*/ */
public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector> public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
{ {
static final int INITIAL_MAX_KEYS = 2 << 15 /* 65,536 */; static final int INITIAL_MAX_BYTES = 134_217_728;
static final int SMALLEST_MAX_KEYS = 16; static final int SMALLEST_MAX_BYTES = 5000;
private static final int MISSING_KEY_WEIGHT = 0; private static final int MISSING_KEY_WEIGHT = 0;
private final Comparator<RowKey> comparator; private final Comparator<RowKey> comparator;
@ -71,7 +71,8 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
* collector type, which is based on a more solid statistical foundation. * collector type, which is based on a more solid statistical foundation.
*/ */
private final Object2LongSortedMap<RowKey> retainedKeys; private final Object2LongSortedMap<RowKey> retainedKeys;
private int maxKeys; private int maxBytes;
private int retainedBytes;
/** /**
* Each key is retained with probability 2^(-spaceReductionFactor). This value is incremented on calls to * Each key is retained with probability 2^(-spaceReductionFactor). This value is incremented on calls to
@ -92,7 +93,7 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
this.comparator = Preconditions.checkNotNull(comparator, "comparator"); this.comparator = Preconditions.checkNotNull(comparator, "comparator");
this.retainedKeys = Preconditions.checkNotNull(retainedKeys, "retainedKeys"); this.retainedKeys = Preconditions.checkNotNull(retainedKeys, "retainedKeys");
this.retainedKeys.defaultReturnValue(MISSING_KEY_WEIGHT); this.retainedKeys.defaultReturnValue(MISSING_KEY_WEIGHT);
this.maxKeys = INITIAL_MAX_KEYS; this.maxBytes = INITIAL_MAX_BYTES;
this.spaceReductionFactor = spaceReductionFactor; this.spaceReductionFactor = spaceReductionFactor;
this.totalWeightUnadjusted = 0; this.totalWeightUnadjusted = 0;
@ -120,14 +121,16 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
if (isNewMin && !retainedKeys.isEmpty() && !isKeySelected(retainedKeys.firstKey())) { if (isNewMin && !retainedKeys.isEmpty() && !isKeySelected(retainedKeys.firstKey())) {
// Old min should be kicked out. // Old min should be kicked out.
totalWeightUnadjusted -= retainedKeys.removeLong(retainedKeys.firstKey()); totalWeightUnadjusted -= retainedKeys.removeLong(retainedKeys.firstKey());
retainedBytes -= retainedKeys.firstKey().getNumberOfBytes();
} }
if (retainedKeys.putIfAbsent(key, weight) == MISSING_KEY_WEIGHT) { if (retainedKeys.putIfAbsent(key, weight) == MISSING_KEY_WEIGHT) {
// We did add this key. (Previous value was zero, meaning absent.) // We did add this key. (Previous value was zero, meaning absent.)
totalWeightUnadjusted += weight; totalWeightUnadjusted += weight;
retainedBytes += key.getNumberOfBytes();
} }
while (retainedKeys.size() >= maxKeys) { while (retainedBytes >= maxBytes) {
increaseSpaceReductionFactorIfPossible(); increaseSpaceReductionFactorIfPossible();
} }
} }
@ -168,6 +171,12 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
return retainedKeys.size(); return retainedKeys.size();
} }
@Override
public double estimatedRetainedBytes()
{
return retainedBytes;
}
@Override @Override
public RowKey minKey() public RowKey minKey()
{ {
@ -182,13 +191,13 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
return true; return true;
} }
if (maxKeys == SMALLEST_MAX_KEYS) { if (maxBytes <= SMALLEST_MAX_BYTES) {
return false; return false;
} }
maxKeys /= 2; maxBytes /= 2;
while (retainedKeys.size() >= maxKeys) { while (retainedBytes >= maxBytes) {
if (!increaseSpaceReductionFactorIfPossible()) { if (!increaseSpaceReductionFactorIfPossible()) {
return false; return false;
} }
@ -242,10 +251,10 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
return retainedKeys; return retainedKeys;
} }
@JsonProperty("maxKeys") @JsonProperty("maxBytes")
int getMaxKeys() int getMaxBytes()
{ {
return maxKeys; return maxBytes;
} }
@JsonProperty("spaceReductionFactor") @JsonProperty("spaceReductionFactor")
@ -296,6 +305,7 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
if (!isKeySelected(key)) { if (!isKeySelected(key)) {
totalWeightUnadjusted -= entry.getLongValue(); totalWeightUnadjusted -= entry.getLongValue();
retainedBytes -= entry.getKey().getNumberOfBytes();
iterator.remove(); iterator.remove();
} }
} }

View File

@ -53,6 +53,12 @@ public interface KeyCollector<CollectorType extends KeyCollector<CollectorType>>
*/ */
int estimatedRetainedKeys(); int estimatedRetainedKeys();
/**
* Returns an estimate of the number of bytes currently retained by this collector. This may change over time as
* more keys are added.
*/
double estimatedRetainedBytes();
/** /**
* Downsample this collector, dropping about half of the keys that are currently retained. Returns true if * Downsample this collector, dropping about half of the keys that are currently retained. Returns true if
* the collector was downsampled, or if it is already retaining zero or one keys. Returns false if the collector is * the collector was downsampled, or if it is already retaining zero or one keys. Returns false if the collector is

View File

@ -37,28 +37,39 @@ import java.util.NoSuchElementException;
/** /**
* A key collector that is used when not aggregating. It uses a quantiles sketch to track keys. * A key collector that is used when not aggregating. It uses a quantiles sketch to track keys.
*
* The collector maintains the averageKeyLength for all keys added through {@link #add(RowKey, long)} or
* {@link #addAll(QuantilesSketchKeyCollector)}. The average is calculated as a running average and accounts for
* weight of the key added. The averageKeyLength is assumed to be unaffected by {@link #downSample()}.
*/ */
public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketchKeyCollector> public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketchKeyCollector>
{ {
private final Comparator<RowKey> comparator; private final Comparator<RowKey> comparator;
private ItemsSketch<RowKey> sketch; private ItemsSketch<RowKey> sketch;
private double averageKeyLength;
QuantilesSketchKeyCollector( QuantilesSketchKeyCollector(
final Comparator<RowKey> comparator, final Comparator<RowKey> comparator,
@Nullable final ItemsSketch<RowKey> sketch @Nullable final ItemsSketch<RowKey> sketch,
double averageKeyLength
) )
{ {
this.comparator = comparator; this.comparator = comparator;
this.sketch = sketch; this.sketch = sketch;
this.averageKeyLength = averageKeyLength;
} }
@Override @Override
public void add(RowKey key, long weight) public void add(RowKey key, long weight)
{ {
double estimatedTotalSketchSizeInBytes = averageKeyLength * sketch.getN();
// The key is added "weight" times to the sketch, we can update the total weight directly.
estimatedTotalSketchSizeInBytes += key.getNumberOfBytes() * weight;
for (int i = 0; i < weight; i++) { for (int i = 0; i < weight; i++) {
// Add the same key multiple times to make it "heavier". // Add the same key multiple times to make it "heavier".
sketch.update(key); sketch.update(key);
} }
averageKeyLength = (estimatedTotalSketchSizeInBytes / sketch.getN());
} }
@Override @Override
@ -69,6 +80,10 @@ public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketch
comparator comparator
); );
double sketchBytesCount = averageKeyLength * sketch.getN();
double otherBytesCount = other.averageKeyLength * other.getSketch().getN();
averageKeyLength = ((sketchBytesCount + otherBytesCount) / (sketch.getN() + other.sketch.getN()));
union.update(sketch); union.update(sketch);
union.update(other.sketch); union.update(other.sketch);
sketch = union.getResultAndReset(); sketch = union.getResultAndReset();
@ -86,15 +101,16 @@ public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketch
return sketch.getN(); return sketch.getN();
} }
@Override
public double estimatedRetainedBytes()
{
return averageKeyLength * estimatedRetainedKeys();
}
@Override @Override
public int estimatedRetainedKeys() public int estimatedRetainedKeys()
{ {
// Rough estimation of retained keys for a given K for ~billions of total items, based on the table from return sketch.getRetainedItems();
// https://datasketches.apache.org/docs/Quantiles/OrigQuantilesSketch.html.
final int estimatedMaxRetainedKeys = 11 * sketch.getK();
// Cast to int is safe because estimatedMaxRetainedKeys is always within int range.
return (int) Math.min(sketch.getN(), estimatedMaxRetainedKeys);
} }
@Override @Override
@ -165,4 +181,12 @@ public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketch
{ {
return sketch; return sketch;
} }
/**
* Retrieves the average key length. Exists for usage by {@link QuantilesSketchKeyCollectorFactory}.
*/
double getAverageKeyLength()
{
return averageKeyLength;
}
} }

View File

@ -38,9 +38,9 @@ import java.util.Comparator;
public class QuantilesSketchKeyCollectorFactory public class QuantilesSketchKeyCollectorFactory
implements KeyCollectorFactory<QuantilesSketchKeyCollector, QuantilesSketchKeyCollectorSnapshot> implements KeyCollectorFactory<QuantilesSketchKeyCollector, QuantilesSketchKeyCollectorSnapshot>
{ {
// smallest value with normalized rank error < 0.1%; retain up to ~86k elements // Maximum value of K possible.
@VisibleForTesting @VisibleForTesting
static final int SKETCH_INITIAL_K = 1 << 12; static final int SKETCH_INITIAL_K = 1 << 15;
private final Comparator<RowKey> comparator; private final Comparator<RowKey> comparator;
@ -57,7 +57,7 @@ public class QuantilesSketchKeyCollectorFactory
@Override @Override
public QuantilesSketchKeyCollector newKeyCollector() public QuantilesSketchKeyCollector newKeyCollector()
{ {
return new QuantilesSketchKeyCollector(comparator, ItemsSketch.getInstance(SKETCH_INITIAL_K, comparator)); return new QuantilesSketchKeyCollector(comparator, ItemsSketch.getInstance(SKETCH_INITIAL_K, comparator), 0);
} }
@Override @Override
@ -79,7 +79,7 @@ public class QuantilesSketchKeyCollectorFactory
{ {
final String encodedSketch = final String encodedSketch =
StringUtils.encodeBase64String(collector.getSketch().toByteArray(RowKeySerde.INSTANCE)); StringUtils.encodeBase64String(collector.getSketch().toByteArray(RowKeySerde.INSTANCE));
return new QuantilesSketchKeyCollectorSnapshot(encodedSketch); return new QuantilesSketchKeyCollectorSnapshot(encodedSketch, collector.getAverageKeyLength());
} }
@Override @Override
@ -89,7 +89,7 @@ public class QuantilesSketchKeyCollectorFactory
final byte[] bytes = StringUtils.decodeBase64String(encodedSketch); final byte[] bytes = StringUtils.decodeBase64String(encodedSketch);
final ItemsSketch<RowKey> sketch = final ItemsSketch<RowKey> sketch =
ItemsSketch.getInstance(Memory.wrap(bytes), comparator, RowKeySerde.INSTANCE); ItemsSketch.getInstance(Memory.wrap(bytes), comparator, RowKeySerde.INSTANCE);
return new QuantilesSketchKeyCollector(comparator, sketch); return new QuantilesSketchKeyCollector(comparator, sketch, snapshot.getAverageKeyLength());
} }
private static class RowKeySerde extends ArrayOfItemsSerDe<RowKey> private static class RowKeySerde extends ArrayOfItemsSerDe<RowKey>
@ -106,7 +106,7 @@ public class QuantilesSketchKeyCollectorFactory
int serializedSize = Integer.BYTES * items.length; int serializedSize = Integer.BYTES * items.length;
for (final RowKey key : items) { for (final RowKey key : items) {
serializedSize += key.array().length; serializedSize += key.getNumberOfBytes();
} }
final byte[] serializedBytes = new byte[serializedSize]; final byte[] serializedBytes = new byte[serializedSize];

View File

@ -20,7 +20,7 @@
package org.apache.druid.msq.statistics; package org.apache.druid.msq.statistics;
import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue; import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Objects; import java.util.Objects;
@ -28,18 +28,27 @@ public class QuantilesSketchKeyCollectorSnapshot implements KeyCollectorSnapshot
{ {
private final String encodedSketch; private final String encodedSketch;
private final double averageKeyLength;
@JsonCreator @JsonCreator
public QuantilesSketchKeyCollectorSnapshot(String encodedSketch) public QuantilesSketchKeyCollectorSnapshot(@JsonProperty("encodedSketch") String encodedSketch, @JsonProperty("averageKeyLength") double averageKeyLength)
{ {
this.encodedSketch = encodedSketch; this.encodedSketch = encodedSketch;
this.averageKeyLength = averageKeyLength;
} }
@JsonValue @JsonProperty("encodedSketch")
public String getEncodedSketch() public String getEncodedSketch()
{ {
return encodedSketch; return encodedSketch;
} }
@JsonProperty("averageKeyLength")
public double getAverageKeyLength()
{
return averageKeyLength;
}
@Override @Override
public boolean equals(Object o) public boolean equals(Object o)
{ {
@ -50,12 +59,13 @@ public class QuantilesSketchKeyCollectorSnapshot implements KeyCollectorSnapshot
return false; return false;
} }
QuantilesSketchKeyCollectorSnapshot that = (QuantilesSketchKeyCollectorSnapshot) o; QuantilesSketchKeyCollectorSnapshot that = (QuantilesSketchKeyCollectorSnapshot) o;
return Objects.equals(encodedSketch, that.encodedSketch); return Objects.equals(encodedSketch, that.encodedSketch)
&& Double.compare(that.averageKeyLength, averageKeyLength) == 0;
} }
@Override @Override
public int hashCode() public int hashCode()
{ {
return Objects.hash(encodedSketch); return Objects.hash(encodedSketch, averageKeyLength);
} }
} }

View File

@ -80,7 +80,7 @@ public class ClusterByStatisticsCollectorImplTest extends InitializedNullHandlin
); );
// These numbers are roughly 10x lower than authentic production numbers. (See StageDefinition.) // These numbers are roughly 10x lower than authentic production numbers. (See StageDefinition.)
private static final int MAX_KEYS = 5000; private static final int MAX_BYTES = 1_000_000;
private static final int MAX_BUCKETS = 1000; private static final int MAX_BUCKETS = 1000;
@Test @Test
@ -598,7 +598,7 @@ public class ClusterByStatisticsCollectorImplTest extends InitializedNullHandlin
private ClusterByStatisticsCollectorImpl makeCollector(final ClusterBy clusterBy, final boolean aggregate) private ClusterByStatisticsCollectorImpl makeCollector(final ClusterBy clusterBy, final boolean aggregate)
{ {
return (ClusterByStatisticsCollectorImpl) return (ClusterByStatisticsCollectorImpl)
ClusterByStatisticsCollectorImpl.create(clusterBy, SIGNATURE, MAX_KEYS, MAX_BUCKETS, aggregate, false); ClusterByStatisticsCollectorImpl.create(clusterBy, SIGNATURE, MAX_BYTES, MAX_BUCKETS, aggregate, false);
} }
private static void verifyPartitions( private static void verifyPartitions(

View File

@ -58,7 +58,7 @@ public class DelegateOrMinKeyCollectorTest
Assert.assertTrue(collector.getDelegate().isPresent()); Assert.assertTrue(collector.getDelegate().isPresent());
Assert.assertTrue(collector.isEmpty()); Assert.assertTrue(collector.isEmpty());
Assert.assertThrows(NoSuchElementException.class, collector::minKey); Assert.assertThrows(NoSuchElementException.class, collector::minKey);
Assert.assertEquals(0, collector.estimatedRetainedKeys()); Assert.assertEquals(0, collector.estimatedRetainedBytes(), 0);
Assert.assertEquals(0, collector.estimatedTotalWeight()); Assert.assertEquals(0, collector.estimatedTotalWeight());
MatcherAssert.assertThat(collector.getDelegate().get(), CoreMatchers.instanceOf(QuantilesSketchKeyCollector.class)); MatcherAssert.assertThat(collector.getDelegate().get(), CoreMatchers.instanceOf(QuantilesSketchKeyCollector.class));
} }
@ -83,12 +83,13 @@ public class DelegateOrMinKeyCollectorTest
QuantilesSketchKeyCollectorFactory.create(clusterBy) QuantilesSketchKeyCollectorFactory.create(clusterBy)
).newKeyCollector(); ).newKeyCollector();
collector.add(createKey(1L), 1); RowKey key = createKey(1L);
collector.add(key, 1);
Assert.assertTrue(collector.getDelegate().isPresent()); Assert.assertTrue(collector.getDelegate().isPresent());
Assert.assertFalse(collector.isEmpty()); Assert.assertFalse(collector.isEmpty());
Assert.assertEquals(createKey(1L), collector.minKey()); Assert.assertEquals(key, collector.minKey());
Assert.assertEquals(1, collector.estimatedRetainedKeys()); Assert.assertEquals(key.getNumberOfBytes(), collector.estimatedRetainedBytes(), 0);
Assert.assertEquals(1, collector.estimatedTotalWeight()); Assert.assertEquals(1, collector.estimatedTotalWeight());
} }
@ -101,13 +102,15 @@ public class DelegateOrMinKeyCollectorTest
QuantilesSketchKeyCollectorFactory.create(clusterBy) QuantilesSketchKeyCollectorFactory.create(clusterBy)
).newKeyCollector(); ).newKeyCollector();
collector.add(createKey(1L), 1); RowKey key = createKey(1L);
collector.add(key, 1);
Assert.assertTrue(collector.downSample()); Assert.assertTrue(collector.downSample());
Assert.assertTrue(collector.getDelegate().isPresent()); Assert.assertTrue(collector.getDelegate().isPresent());
Assert.assertFalse(collector.isEmpty()); Assert.assertFalse(collector.isEmpty());
Assert.assertEquals(createKey(1L), collector.minKey()); Assert.assertEquals(key, collector.minKey());
Assert.assertEquals(1, collector.estimatedRetainedKeys()); Assert.assertEquals(key.getNumberOfBytes(), collector.estimatedRetainedBytes(), 0);
Assert.assertEquals(1, collector.estimatedTotalWeight()); Assert.assertEquals(1, collector.estimatedTotalWeight());
// Should not have actually downsampled, because the quantiles-based collector does nothing when // Should not have actually downsampled, because the quantiles-based collector does nothing when
@ -127,23 +130,26 @@ public class DelegateOrMinKeyCollectorTest
QuantilesSketchKeyCollectorFactory.create(clusterBy) QuantilesSketchKeyCollectorFactory.create(clusterBy)
).newKeyCollector(); ).newKeyCollector();
collector.add(createKey(1L), 1); RowKey key = createKey(1L);
collector.add(createKey(1L), 1); collector.add(key, 1);
collector.add(key, 1);
int expectedRetainedBytes = 2 * key.getNumberOfBytes();
Assert.assertTrue(collector.getDelegate().isPresent()); Assert.assertTrue(collector.getDelegate().isPresent());
Assert.assertFalse(collector.isEmpty()); Assert.assertFalse(collector.isEmpty());
Assert.assertEquals(createKey(1L), collector.minKey()); Assert.assertEquals(createKey(1L), collector.minKey());
Assert.assertEquals(2, collector.estimatedRetainedKeys()); Assert.assertEquals(expectedRetainedBytes, collector.estimatedRetainedBytes(), 0);
Assert.assertEquals(2, collector.estimatedTotalWeight()); Assert.assertEquals(2, collector.estimatedTotalWeight());
while (collector.getDelegate().isPresent()) { while (collector.getDelegate().isPresent()) {
Assert.assertTrue(collector.downSample()); Assert.assertTrue(collector.downSample());
} }
expectedRetainedBytes = key.getNumberOfBytes();
Assert.assertFalse(collector.getDelegate().isPresent()); Assert.assertFalse(collector.getDelegate().isPresent());
Assert.assertFalse(collector.isEmpty()); Assert.assertFalse(collector.isEmpty());
Assert.assertEquals(createKey(1L), collector.minKey()); Assert.assertEquals(createKey(1L), collector.minKey());
Assert.assertEquals(1, collector.estimatedRetainedKeys()); Assert.assertEquals(expectedRetainedBytes, collector.estimatedRetainedBytes(), 0);
Assert.assertEquals(1, collector.estimatedTotalWeight()); Assert.assertEquals(1, collector.estimatedTotalWeight());
} }

View File

@ -20,6 +20,7 @@
package org.apache.druid.msq.statistics; package org.apache.druid.msq.statistics;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition; import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.frame.key.ClusterByPartitions;
@ -43,6 +44,10 @@ public class DistinctKeyCollectorTest
private final Comparator<RowKey> comparator = clusterBy.keyComparator(); private final Comparator<RowKey> comparator = clusterBy.keyComparator();
private final int numKeys = 500_000; private final int numKeys = 500_000;
static {
NullHandling.initializeForTests();
}
@Test @Test
public void test_empty() public void test_empty()
{ {
@ -127,11 +132,11 @@ public class DistinctKeyCollectorTest
// Intentionally empty loop body. // Intentionally empty loop body.
} }
Assert.assertEquals(DistinctKeyCollector.SMALLEST_MAX_KEYS, collector.getMaxKeys()); Assert.assertTrue(DistinctKeyCollector.SMALLEST_MAX_BYTES >= collector.getMaxBytes());
MatcherAssert.assertThat( MatcherAssert.assertThat(
testName, testName,
collector.estimatedRetainedKeys(), (int) collector.estimatedRetainedBytes(),
Matchers.lessThanOrEqualTo(DistinctKeyCollector.SMALLEST_MAX_KEYS) Matchers.lessThanOrEqualTo(DistinctKeyCollector.SMALLEST_MAX_BYTES)
); );
// Don't use verifyCollector, since this collector is downsampled so aggressively that it can't possibly // Don't use verifyCollector, since this collector is downsampled so aggressively that it can't possibly
@ -230,8 +235,7 @@ public class DistinctKeyCollectorTest
final NavigableMap<RowKey, List<Integer>> sortedKeyWeights final NavigableMap<RowKey, List<Integer>> sortedKeyWeights
) )
{ {
Assert.assertEquals(collector.getRetainedKeys().size(), collector.estimatedRetainedKeys()); MatcherAssert.assertThat((int) collector.estimatedRetainedBytes(), Matchers.lessThan(collector.getMaxBytes()));
MatcherAssert.assertThat(collector.getRetainedKeys().size(), Matchers.lessThan(collector.getMaxKeys()));
KeyCollectorTestUtils.verifyCollector( KeyCollectorTestUtils.verifyCollector(
collector, collector,

View File

@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.statistics;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.junit.Assert;
import org.junit.Test;
public class QuantilesSketchKeyCollectorSnapshotTest
{
private final ObjectMapper jsonMapper = new DefaultObjectMapper();
@Test
public void testSnapshotSerde() throws JsonProcessingException
{
QuantilesSketchKeyCollectorSnapshot snapshot = new QuantilesSketchKeyCollectorSnapshot("sketchString", 100);
String jsonStr = jsonMapper.writeValueAsString(snapshot);
Assert.assertEquals(snapshot, jsonMapper.readValue(jsonStr, QuantilesSketchKeyCollectorSnapshot.class));
}
}

View File

@ -24,9 +24,12 @@ import org.apache.druid.common.config.NullHandling;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition; import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.key.KeyTestUtils;
import org.apache.druid.frame.key.RowKey; import org.apache.druid.frame.key.RowKey;
import org.apache.druid.frame.key.SortColumn; import org.apache.druid.frame.key.SortColumn;
import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.Pair;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
@ -119,7 +122,7 @@ public class QuantilesSketchKeyCollectorTest
} }
Assert.assertEquals(testName, 2, collector.getSketch().getK()); Assert.assertEquals(testName, 2, collector.getSketch().getK());
Assert.assertEquals(testName, 22, collector.estimatedRetainedKeys()); Assert.assertEquals(testName, 14, collector.estimatedRetainedKeys());
// Don't use verifyCollector, since this collector is downsampled so aggressively that it can't possibly // Don't use verifyCollector, since this collector is downsampled so aggressively that it can't possibly
// hope to pass those tests. Grade on a curve. // hope to pass those tests. Grade on a curve.
@ -161,6 +164,46 @@ public class QuantilesSketchKeyCollectorTest
); );
} }
@Test
public void testAverageKeyLength()
{
final QuantilesSketchKeyCollector collector = QuantilesSketchKeyCollectorFactory.create(clusterBy).newKeyCollector();
final QuantilesSketchKeyCollector other = QuantilesSketchKeyCollectorFactory.create(clusterBy).newKeyCollector();
RowSignature smallKeySignature = KeyTestUtils.createKeySignature(
new ClusterBy(ImmutableList.of(new SortColumn("x", false)), 0).getColumns(),
RowSignature.builder().add("x", ColumnType.LONG).build()
);
RowKey smallKey = KeyTestUtils.createKey(smallKeySignature, 1L);
RowSignature largeKeySignature = KeyTestUtils.createKeySignature(
new ClusterBy(
ImmutableList.of(
new SortColumn("x", false),
new SortColumn("y", false),
new SortColumn("z", false)
),
0).getColumns(),
RowSignature.builder()
.add("x", ColumnType.LONG)
.add("y", ColumnType.LONG)
.add("z", ColumnType.LONG)
.build()
);
RowKey largeKey = KeyTestUtils.createKey(largeKeySignature, 1L, 2L, 3L);
collector.add(smallKey, 3);
Assert.assertEquals(smallKey.getNumberOfBytes(), collector.getAverageKeyLength(), 0);
other.add(largeKey, 5);
Assert.assertEquals(largeKey.getNumberOfBytes(), other.getAverageKeyLength(), 0);
collector.addAll(other);
Assert.assertEquals((smallKey.getNumberOfBytes() * 3 + largeKey.getNumberOfBytes() * 5) / 8.0, collector.getAverageKeyLength(), 0);
}
@Test @Test
public void test_uniformRandomKeys_inverseBarbellWeighted() public void test_uniformRandomKeys_inverseBarbellWeighted()
{ {

View File

@ -108,4 +108,9 @@ public class RowKey
{ {
return Arrays.toString(key); return Arrays.toString(key);
} }
public int getNumberOfBytes()
{
return array().length;
}
} }

View File

@ -91,4 +91,17 @@ public class RowKeyTest extends InitializedNullHandlingTest
KeyTestUtils.createKey(signatureLongString, 1L, "def").hashCode() KeyTestUtils.createKey(signatureLongString, 1L, "def").hashCode()
); );
} }
@Test
public void testGetNumberOfBytes()
{
final RowSignature signatureLong = RowSignature.builder().add("1", ColumnType.LONG).build();
final RowKey longKey = KeyTestUtils.createKey(signatureLong, 1L, "abc");
Assert.assertEquals(longKey.array().length, longKey.getNumberOfBytes());
final RowSignature signatureLongString =
RowSignature.builder().add("1", ColumnType.LONG).add("2", ColumnType.STRING).build();
final RowKey longStringKey = KeyTestUtils.createKey(signatureLongString, 1L, "abc");
Assert.assertEquals(longStringKey.array().length, longStringKey.getNumberOfBytes());
}
} }