mirror of https://github.com/apache/druid.git
BufferArrayGrouper: Fix potential overflow in requiredBufferCapacity. (#9435)
* BufferArrayGrouper: Fix potential overflow in requiredBufferCapacity. If cardinality was high, the computation could overflow an int. There were tests for this, but the tests were wrong. * Nicer.
This commit is contained in:
parent
81d8be6e39
commit
1fd865b7c1
|
@ -21,6 +21,7 @@ package org.apache.druid.query.groupby.epinephelinae;
|
||||||
|
|
||||||
import com.google.common.base.Preconditions;
|
import com.google.common.base.Preconditions;
|
||||||
import com.google.common.base.Supplier;
|
import com.google.common.base.Supplier;
|
||||||
|
import com.google.common.primitives.Ints;
|
||||||
import org.apache.datasketches.memory.Memory;
|
import org.apache.datasketches.memory.Memory;
|
||||||
import org.apache.datasketches.memory.WritableMemory;
|
import org.apache.datasketches.memory.WritableMemory;
|
||||||
import org.apache.druid.java.util.common.IAE;
|
import org.apache.druid.java.util.common.IAE;
|
||||||
|
@ -73,19 +74,24 @@ public class BufferArrayGrouper implements VectorGrouper, IntGrouper
|
||||||
AggregatorFactory[] aggregatorFactories
|
AggregatorFactory[] aggregatorFactories
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
final int cardinalityWithMissingValue = cardinality + 1;
|
final long cardinalityWithMissingValue = computeCardinalityWithMissingValue(cardinality);
|
||||||
final int recordSize = Arrays.stream(aggregatorFactories)
|
final long recordSize = Arrays.stream(aggregatorFactories)
|
||||||
.mapToInt(AggregatorFactory::getMaxIntermediateSizeWithNulls)
|
.mapToLong(AggregatorFactory::getMaxIntermediateSizeWithNulls)
|
||||||
.sum();
|
.sum();
|
||||||
|
|
||||||
return getUsedFlagBufferCapacity(cardinalityWithMissingValue) + // total used flags size
|
return getUsedFlagBufferCapacity(cardinalityWithMissingValue) + // total used flags size
|
||||||
(long) cardinalityWithMissingValue * recordSize; // total values size
|
cardinalityWithMissingValue * recordSize; // total values size
|
||||||
|
}
|
||||||
|
|
||||||
|
private static long computeCardinalityWithMissingValue(int cardinality)
|
||||||
|
{
|
||||||
|
return (long) cardinality + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute the number of bytes to store all used flag bits.
|
* Compute the number of bytes to store all used flag bits.
|
||||||
*/
|
*/
|
||||||
private static int getUsedFlagBufferCapacity(int cardinalityWithMissingValue)
|
private static long getUsedFlagBufferCapacity(long cardinalityWithMissingValue)
|
||||||
{
|
{
|
||||||
return (cardinalityWithMissingValue + Byte.SIZE - 1) / Byte.SIZE;
|
return (cardinalityWithMissingValue + Byte.SIZE - 1) / Byte.SIZE;
|
||||||
}
|
}
|
||||||
|
@ -102,7 +108,7 @@ public class BufferArrayGrouper implements VectorGrouper, IntGrouper
|
||||||
|
|
||||||
this.bufferSupplier = Preconditions.checkNotNull(bufferSupplier, "bufferSupplier");
|
this.bufferSupplier = Preconditions.checkNotNull(bufferSupplier, "bufferSupplier");
|
||||||
this.aggregators = aggregators;
|
this.aggregators = aggregators;
|
||||||
this.cardinalityWithMissingValue = cardinality + 1;
|
this.cardinalityWithMissingValue = Ints.checkedCast(computeCardinalityWithMissingValue(cardinality));
|
||||||
this.recordSize = aggregators.spaceNeeded();
|
this.recordSize = aggregators.spaceNeeded();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,7 +118,7 @@ public class BufferArrayGrouper implements VectorGrouper, IntGrouper
|
||||||
if (!initialized) {
|
if (!initialized) {
|
||||||
final ByteBuffer buffer = bufferSupplier.get();
|
final ByteBuffer buffer = bufferSupplier.get();
|
||||||
|
|
||||||
final int usedFlagBufferEnd = getUsedFlagBufferCapacity(cardinalityWithMissingValue);
|
final int usedFlagBufferEnd = Ints.checkedCast(getUsedFlagBufferCapacity(cardinalityWithMissingValue));
|
||||||
|
|
||||||
// Sanity check on buffer capacity.
|
// Sanity check on buffer capacity.
|
||||||
if (usedFlagBufferEnd + (long) cardinalityWithMissingValue * recordSize > buffer.capacity()) {
|
if (usedFlagBufferEnd + (long) cardinalityWithMissingValue * recordSize > buffer.capacity()) {
|
||||||
|
|
|
@ -27,12 +27,14 @@ import com.google.common.collect.Ordering;
|
||||||
import com.google.common.primitives.Ints;
|
import com.google.common.primitives.Ints;
|
||||||
import org.apache.druid.common.config.NullHandling;
|
import org.apache.druid.common.config.NullHandling;
|
||||||
import org.apache.druid.data.input.MapBasedRow;
|
import org.apache.druid.data.input.MapBasedRow;
|
||||||
|
import org.apache.druid.java.util.common.StringUtils;
|
||||||
import org.apache.druid.query.aggregation.AggregatorAdapters;
|
import org.apache.druid.query.aggregation.AggregatorAdapters;
|
||||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||||
import org.apache.druid.query.aggregation.CountAggregatorFactory;
|
import org.apache.druid.query.aggregation.CountAggregatorFactory;
|
||||||
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
|
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
|
||||||
import org.apache.druid.query.groupby.epinephelinae.Grouper.Entry;
|
import org.apache.druid.query.groupby.epinephelinae.Grouper.Entry;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
|
import org.junit.BeforeClass;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
|
@ -41,6 +43,12 @@ import java.util.List;
|
||||||
|
|
||||||
public class BufferArrayGrouperTest
|
public class BufferArrayGrouperTest
|
||||||
{
|
{
|
||||||
|
@BeforeClass
|
||||||
|
public static void setUpStatic()
|
||||||
|
{
|
||||||
|
NullHandling.initializeForTests();
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAggregate()
|
public void testAggregate()
|
||||||
{
|
{
|
||||||
|
@ -94,23 +102,29 @@ public class BufferArrayGrouperTest
|
||||||
@Test
|
@Test
|
||||||
public void testRequiredBufferCapacity()
|
public void testRequiredBufferCapacity()
|
||||||
{
|
{
|
||||||
int[] cardinalityArray = new int[]{1, 10, Integer.MAX_VALUE - 1};
|
final int[] cardinalityArray = new int[]{1, 10, Integer.MAX_VALUE - 1, Integer.MAX_VALUE};
|
||||||
AggregatorFactory[] aggregatorFactories = new AggregatorFactory[]{
|
final AggregatorFactory[] aggregatorFactories = new AggregatorFactory[]{
|
||||||
new LongSumAggregatorFactory("sum", "sum")
|
new LongSumAggregatorFactory("sum", "sum")
|
||||||
};
|
};
|
||||||
long[] requiredSizes;
|
|
||||||
|
final long[] requiredSizes;
|
||||||
|
|
||||||
if (NullHandling.sqlCompatible()) {
|
if (NullHandling.sqlCompatible()) {
|
||||||
// We need additional size to store nullability information.
|
// We need additional size to store nullability information.
|
||||||
requiredSizes = new long[]{19, 101, 19058917368L};
|
requiredSizes = new long[]{19, 101, 19595788279L, 19595788288L};
|
||||||
} else {
|
} else {
|
||||||
requiredSizes = new long[]{17, 90, 16911433721L};
|
requiredSizes = new long[]{17, 90, 17448304632L, 17448304640L};
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < cardinalityArray.length; i++) {
|
for (int i = 0; i < cardinalityArray.length; i++) {
|
||||||
Assert.assertEquals(requiredSizes[i], BufferArrayGrouper.requiredBufferCapacity(
|
Assert.assertEquals(
|
||||||
|
StringUtils.format("cardinality[%d]", cardinalityArray[i]),
|
||||||
|
requiredSizes[i],
|
||||||
|
BufferArrayGrouper.requiredBufferCapacity(
|
||||||
cardinalityArray[i],
|
cardinalityArray[i],
|
||||||
aggregatorFactories
|
aggregatorFactories
|
||||||
));
|
)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue