BufferArrayGrouper: Fix potential overflow in requiredBufferCapacity. (#9435)

* BufferArrayGrouper: Fix potential overflow in requiredBufferCapacity.

If cardinality was high, the computation could overflow an int. There
were tests for this, but the tests were wrong.

* Nicer.
This commit is contained in:
Gian Merlino 2020-02-28 14:27:52 -08:00 committed by GitHub
parent 81d8be6e39
commit 1fd865b7c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 17 deletions

View File

@ -21,6 +21,7 @@ package org.apache.druid.query.groupby.epinephelinae;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.base.Supplier; import com.google.common.base.Supplier;
import com.google.common.primitives.Ints;
import org.apache.datasketches.memory.Memory; import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.memory.WritableMemory; import org.apache.datasketches.memory.WritableMemory;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
@ -73,19 +74,24 @@ public class BufferArrayGrouper implements VectorGrouper, IntGrouper
AggregatorFactory[] aggregatorFactories AggregatorFactory[] aggregatorFactories
) )
{ {
final int cardinalityWithMissingValue = cardinality + 1; final long cardinalityWithMissingValue = computeCardinalityWithMissingValue(cardinality);
final int recordSize = Arrays.stream(aggregatorFactories) final long recordSize = Arrays.stream(aggregatorFactories)
.mapToInt(AggregatorFactory::getMaxIntermediateSizeWithNulls) .mapToLong(AggregatorFactory::getMaxIntermediateSizeWithNulls)
.sum(); .sum();
return getUsedFlagBufferCapacity(cardinalityWithMissingValue) + // total used flags size return getUsedFlagBufferCapacity(cardinalityWithMissingValue) + // total used flags size
(long) cardinalityWithMissingValue * recordSize; // total values size cardinalityWithMissingValue * recordSize; // total values size
}
private static long computeCardinalityWithMissingValue(int cardinality)
{
return (long) cardinality + 1;
} }
/** /**
* Compute the number of bytes to store all used flag bits. * Compute the number of bytes to store all used flag bits.
*/ */
private static int getUsedFlagBufferCapacity(int cardinalityWithMissingValue) private static long getUsedFlagBufferCapacity(long cardinalityWithMissingValue)
{ {
return (cardinalityWithMissingValue + Byte.SIZE - 1) / Byte.SIZE; return (cardinalityWithMissingValue + Byte.SIZE - 1) / Byte.SIZE;
} }
@ -102,7 +108,7 @@ public class BufferArrayGrouper implements VectorGrouper, IntGrouper
this.bufferSupplier = Preconditions.checkNotNull(bufferSupplier, "bufferSupplier"); this.bufferSupplier = Preconditions.checkNotNull(bufferSupplier, "bufferSupplier");
this.aggregators = aggregators; this.aggregators = aggregators;
this.cardinalityWithMissingValue = cardinality + 1; this.cardinalityWithMissingValue = Ints.checkedCast(computeCardinalityWithMissingValue(cardinality));
this.recordSize = aggregators.spaceNeeded(); this.recordSize = aggregators.spaceNeeded();
} }
@ -112,7 +118,7 @@ public class BufferArrayGrouper implements VectorGrouper, IntGrouper
if (!initialized) { if (!initialized) {
final ByteBuffer buffer = bufferSupplier.get(); final ByteBuffer buffer = bufferSupplier.get();
final int usedFlagBufferEnd = getUsedFlagBufferCapacity(cardinalityWithMissingValue); final int usedFlagBufferEnd = Ints.checkedCast(getUsedFlagBufferCapacity(cardinalityWithMissingValue));
// Sanity check on buffer capacity. // Sanity check on buffer capacity.
if (usedFlagBufferEnd + (long) cardinalityWithMissingValue * recordSize > buffer.capacity()) { if (usedFlagBufferEnd + (long) cardinalityWithMissingValue * recordSize > buffer.capacity()) {

View File

@ -27,12 +27,14 @@ import com.google.common.collect.Ordering;
import com.google.common.primitives.Ints; import com.google.common.primitives.Ints;
import org.apache.druid.common.config.NullHandling; import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.MapBasedRow; import org.apache.druid.data.input.MapBasedRow;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorAdapters; import org.apache.druid.query.aggregation.AggregatorAdapters;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.groupby.epinephelinae.Grouper.Entry; import org.apache.druid.query.groupby.epinephelinae.Grouper.Entry;
import org.junit.Assert; import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test; import org.junit.Test;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -41,6 +43,12 @@ import java.util.List;
public class BufferArrayGrouperTest public class BufferArrayGrouperTest
{ {
@BeforeClass
public static void setUpStatic()
{
NullHandling.initializeForTests();
}
@Test @Test
public void testAggregate() public void testAggregate()
{ {
@ -94,23 +102,29 @@ public class BufferArrayGrouperTest
@Test @Test
public void testRequiredBufferCapacity() public void testRequiredBufferCapacity()
{ {
int[] cardinalityArray = new int[]{1, 10, Integer.MAX_VALUE - 1}; final int[] cardinalityArray = new int[]{1, 10, Integer.MAX_VALUE - 1, Integer.MAX_VALUE};
AggregatorFactory[] aggregatorFactories = new AggregatorFactory[]{ final AggregatorFactory[] aggregatorFactories = new AggregatorFactory[]{
new LongSumAggregatorFactory("sum", "sum") new LongSumAggregatorFactory("sum", "sum")
}; };
long[] requiredSizes;
final long[] requiredSizes;
if (NullHandling.sqlCompatible()) { if (NullHandling.sqlCompatible()) {
// We need additional size to store nullability information. // We need additional size to store nullability information.
requiredSizes = new long[]{19, 101, 19058917368L}; requiredSizes = new long[]{19, 101, 19595788279L, 19595788288L};
} else { } else {
requiredSizes = new long[]{17, 90, 16911433721L}; requiredSizes = new long[]{17, 90, 17448304632L, 17448304640L};
} }
for (int i = 0; i < cardinalityArray.length; i++) { for (int i = 0; i < cardinalityArray.length; i++) {
Assert.assertEquals(requiredSizes[i], BufferArrayGrouper.requiredBufferCapacity( Assert.assertEquals(
StringUtils.format("cardinality[%d]", cardinalityArray[i]),
requiredSizes[i],
BufferArrayGrouper.requiredBufferCapacity(
cardinalityArray[i], cardinalityArray[i],
aggregatorFactories aggregatorFactories
)); )
);
} }
} }
} }