From 912adcc2845d91c58b0f66621fb03d4cf969811c Mon Sep 17 00:00:00 2001 From: Niketh Sabbineni Date: Wed, 28 Mar 2018 16:37:53 -0700 Subject: [PATCH] ArrayAggregation: Use long to avoid overflow (#5544) * ArrayAggregation: Use long to avoid overflow * Add Tests --- .../epinephelinae/BufferArrayGrouper.java | 4 ++-- .../epinephelinae/GroupByQueryEngineV2.java | 2 +- .../epinephelinae/BufferArrayGrouperTest.java | 17 +++++++++++++++++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferArrayGrouper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferArrayGrouper.java index 6db3f884b20..4e4a30d8fac 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferArrayGrouper.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferArrayGrouper.java @@ -61,7 +61,7 @@ public class BufferArrayGrouper implements IntGrouper private ByteBuffer usedFlagBuffer; private ByteBuffer valBuffer; - static int requiredBufferCapacity( + static long requiredBufferCapacity( int cardinality, AggregatorFactory[] aggregatorFactories ) @@ -72,7 +72,7 @@ public class BufferArrayGrouper implements IntGrouper .sum(); return getUsedFlagBufferCapacity(cardinalityWithMissingValue) + // total used flags size - cardinalityWithMissingValue * recordSize; // total values size + (long) cardinalityWithMissingValue * recordSize; // total values size } /** diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java index 3d48f1ad2b6..8e79e9e73d6 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java @@ -219,7 +219,7 @@ public class GroupByQueryEngineV2 final AggregatorFactory[] aggregatorFactories = query .getAggregatorSpecs() .toArray(new AggregatorFactory[query.getAggregatorSpecs().size()]); - final int requiredBufferCapacity = BufferArrayGrouper.requiredBufferCapacity( + final long requiredBufferCapacity = BufferArrayGrouper.requiredBufferCapacity( cardinality, aggregatorFactories ); diff --git a/processing/src/test/java/io/druid/query/groupby/epinephelinae/BufferArrayGrouperTest.java b/processing/src/test/java/io/druid/query/groupby/epinephelinae/BufferArrayGrouperTest.java index 6ca584f2b8d..acdc2c40604 100644 --- a/processing/src/test/java/io/druid/query/groupby/epinephelinae/BufferArrayGrouperTest.java +++ b/processing/src/test/java/io/druid/query/groupby/epinephelinae/BufferArrayGrouperTest.java @@ -86,4 +86,21 @@ public class BufferArrayGrouperTest grouper.init(); return grouper; } + + @Test + public void testRequiredBufferCapacity() + { + int[] cardinalityArray = new int[] {1, 10, Integer.MAX_VALUE - 1}; + AggregatorFactory[] aggregatorFactories = new AggregatorFactory[] { + new LongSumAggregatorFactory("sum", "sum") + }; + + long[] requiredSizes = new long[] {17, 90, 16911433721L}; + + for (int i = 0; i < cardinalityArray.length; i++) { + Assert.assertEquals(requiredSizes[i], BufferArrayGrouper.requiredBufferCapacity( + cardinalityArray[i], + aggregatorFactories)); + } + } }