From 1fd865b7c127c140e241db2c0942378e20a5799d Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Fri, 28 Feb 2020 14:27:52 -0800 Subject: [PATCH] BufferArrayGrouper: Fix potential overflow in requiredBufferCapacity. (#9435) * BufferArrayGrouper: Fix potential overflow in requiredBufferCapacity. If cardinality was high, the computation could overflow an int. There were tests for this, but the tests were wrong. * Nicer. --- .../epinephelinae/BufferArrayGrouper.java | 22 ++++++++----- .../epinephelinae/BufferArrayGrouperTest.java | 32 +++++++++++++------ 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/BufferArrayGrouper.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/BufferArrayGrouper.java index d3c2a7ee39a..9d6db299757 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/BufferArrayGrouper.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/BufferArrayGrouper.java @@ -21,6 +21,7 @@ package org.apache.druid.query.groupby.epinephelinae; import com.google.common.base.Preconditions; import com.google.common.base.Supplier; +import com.google.common.primitives.Ints; import org.apache.datasketches.memory.Memory; import org.apache.datasketches.memory.WritableMemory; import org.apache.druid.java.util.common.IAE; @@ -73,19 +74,24 @@ public class BufferArrayGrouper implements VectorGrouper, IntGrouper AggregatorFactory[] aggregatorFactories ) { - final int cardinalityWithMissingValue = cardinality + 1; - final int recordSize = Arrays.stream(aggregatorFactories) - .mapToInt(AggregatorFactory::getMaxIntermediateSizeWithNulls) - .sum(); + final long cardinalityWithMissingValue = computeCardinalityWithMissingValue(cardinality); + final long recordSize = Arrays.stream(aggregatorFactories) + .mapToLong(AggregatorFactory::getMaxIntermediateSizeWithNulls) + .sum(); return getUsedFlagBufferCapacity(cardinalityWithMissingValue) + // total used flags size - (long) cardinalityWithMissingValue * recordSize; // total values size + cardinalityWithMissingValue * recordSize; // total values size + } + + private static long computeCardinalityWithMissingValue(int cardinality) + { + return (long) cardinality + 1; } /** * Compute the number of bytes to store all used flag bits. */ - private static int getUsedFlagBufferCapacity(int cardinalityWithMissingValue) + private static long getUsedFlagBufferCapacity(long cardinalityWithMissingValue) { return (cardinalityWithMissingValue + Byte.SIZE - 1) / Byte.SIZE; } @@ -102,7 +108,7 @@ public class BufferArrayGrouper implements VectorGrouper, IntGrouper this.bufferSupplier = Preconditions.checkNotNull(bufferSupplier, "bufferSupplier"); this.aggregators = aggregators; - this.cardinalityWithMissingValue = cardinality + 1; + this.cardinalityWithMissingValue = Ints.checkedCast(computeCardinalityWithMissingValue(cardinality)); this.recordSize = aggregators.spaceNeeded(); } @@ -112,7 +118,7 @@ public class BufferArrayGrouper implements VectorGrouper, IntGrouper if (!initialized) { final ByteBuffer buffer = bufferSupplier.get(); - final int usedFlagBufferEnd = getUsedFlagBufferCapacity(cardinalityWithMissingValue); + final int usedFlagBufferEnd = Ints.checkedCast(getUsedFlagBufferCapacity(cardinalityWithMissingValue)); // Sanity check on buffer capacity. if (usedFlagBufferEnd + (long) cardinalityWithMissingValue * recordSize > buffer.capacity()) { diff --git a/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/BufferArrayGrouperTest.java b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/BufferArrayGrouperTest.java index 98d46fc9d12..c05d97632af 100644 --- a/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/BufferArrayGrouperTest.java +++ b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/BufferArrayGrouperTest.java @@ -27,12 +27,14 @@ import com.google.common.collect.Ordering; import com.google.common.primitives.Ints; import org.apache.druid.common.config.NullHandling; import org.apache.druid.data.input.MapBasedRow; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorAdapters; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.groupby.epinephelinae.Grouper.Entry; import org.junit.Assert; +import org.junit.BeforeClass; import org.junit.Test; import java.nio.ByteBuffer; @@ -41,6 +43,12 @@ import java.util.List; public class BufferArrayGrouperTest { + @BeforeClass + public static void setUpStatic() + { + NullHandling.initializeForTests(); + } + @Test public void testAggregate() { @@ -94,23 +102,29 @@ public class BufferArrayGrouperTest @Test public void testRequiredBufferCapacity() { - int[] cardinalityArray = new int[]{1, 10, Integer.MAX_VALUE - 1}; - AggregatorFactory[] aggregatorFactories = new AggregatorFactory[]{ + final int[] cardinalityArray = new int[]{1, 10, Integer.MAX_VALUE - 1, Integer.MAX_VALUE}; + final AggregatorFactory[] aggregatorFactories = new AggregatorFactory[]{ new LongSumAggregatorFactory("sum", "sum") }; - long[] requiredSizes; + + final long[] requiredSizes; + if (NullHandling.sqlCompatible()) { // We need additional size to store nullability information. - requiredSizes = new long[]{19, 101, 19058917368L}; + requiredSizes = new long[]{19, 101, 19595788279L, 19595788288L}; } else { - requiredSizes = new long[]{17, 90, 16911433721L}; + requiredSizes = new long[]{17, 90, 17448304632L, 17448304640L}; } for (int i = 0; i < cardinalityArray.length; i++) { - Assert.assertEquals(requiredSizes[i], BufferArrayGrouper.requiredBufferCapacity( - cardinalityArray[i], - aggregatorFactories - )); + Assert.assertEquals( + StringUtils.format("cardinality[%d]", cardinalityArray[i]), + requiredSizes[i], + BufferArrayGrouper.requiredBufferCapacity( + cardinalityArray[i], + aggregatorFactories + ) + ); } } }