Fix integer overflow in BufferGrouper. (#4333)

Would have led to out of bounds buffer access with large buffers.
Also added tests using large buffers.
This commit is contained in:
Gian Merlino 2017-05-26 15:30:20 +09:00 committed by Jonathan Wei
parent 2c55a935f8
commit 1eaa7887bd
2 changed files with 50 additions and 3 deletions

View File

@ -411,7 +411,7 @@ public class BufferGrouper<KeyType> implements Grouper<KeyType>
final int newMaxSize; final int newMaxSize;
final int newTableStart; final int newTableStart;
if ((tableStart + buckets * 3 * bucketSize) > tableArenaSize) { if ((long) buckets * 3 * bucketSize > (long) tableArenaSize - tableStart) {
// Not enough space to grow upwards, start back from zero // Not enough space to grow upwards, start back from zero
newTableStart = 0; newTableStart = 0;
newBuckets = tableStart / bucketSize; newBuckets = tableStart / bucketSize;

View File

@ -20,24 +20,34 @@
package io.druid.query.groupby.epinephelinae; package io.druid.query.groupby.epinephelinae;
import com.google.common.base.Suppliers; import com.google.common.base.Suppliers;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Ordering; import com.google.common.collect.Ordering;
import com.google.common.io.Files;
import com.google.common.primitives.Ints; import com.google.common.primitives.Ints;
import io.druid.data.input.MapBasedRow; import io.druid.data.input.MapBasedRow;
import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.CountAggregatorFactory; import io.druid.query.aggregation.CountAggregatorFactory;
import io.druid.query.aggregation.LongSumAggregatorFactory; import io.druid.query.aggregation.LongSumAggregatorFactory;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
public class BufferGrouperTest public class BufferGrouperTest
{ {
@Rule
public TemporaryFolder temporaryFolder = new TemporaryFolder();
@Test @Test
public void testSimple() public void testSimple()
{ {
@ -116,6 +126,34 @@ public class BufferGrouperTest
Assert.assertEquals(expected, Lists.newArrayList(grouper.iterator(true))); Assert.assertEquals(expected, Lists.newArrayList(grouper.iterator(true)));
} }
@Test
public void testGrowing2()
{
final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory();
final Grouper<Integer> grouper = makeGrouper(columnSelectorFactory, 2_000_000_000, 2);
final int expectedMaxSize = 40988516;
columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.<String, Object>of("value", 10L)));
for (int i = 0; i < expectedMaxSize; i++) {
Assert.assertTrue(String.valueOf(i), grouper.aggregate(i).isOk());
}
Assert.assertFalse(grouper.aggregate(expectedMaxSize).isOk());
}
@Test
public void testGrowing3()
{
final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory();
final Grouper<Integer> grouper = makeGrouper(columnSelectorFactory, Integer.MAX_VALUE, 2);
final int expectedMaxSize = 44938972;
columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.<String, Object>of("value", 10L)));
for (int i = 0; i < expectedMaxSize; i++) {
Assert.assertTrue(String.valueOf(i), grouper.aggregate(i).isOk());
}
Assert.assertFalse(grouper.aggregate(expectedMaxSize).isOk());
}
@Test @Test
public void testNoGrowing() public void testNoGrowing()
{ {
@ -144,14 +182,23 @@ public class BufferGrouperTest
Assert.assertEquals(expected, Lists.newArrayList(grouper.iterator(true))); Assert.assertEquals(expected, Lists.newArrayList(grouper.iterator(true)));
} }
private static BufferGrouper<Integer> makeGrouper( private BufferGrouper<Integer> makeGrouper(
TestColumnSelectorFactory columnSelectorFactory, TestColumnSelectorFactory columnSelectorFactory,
int bufferSize, int bufferSize,
int initialBuckets int initialBuckets
) )
{ {
final MappedByteBuffer buffer;
try {
buffer = Files.map(temporaryFolder.newFile(), FileChannel.MapMode.READ_WRITE, bufferSize);
}
catch (IOException e) {
throw Throwables.propagate(e);
}
final BufferGrouper<Integer> grouper = new BufferGrouper<>( final BufferGrouper<Integer> grouper = new BufferGrouper<>(
Suppliers.ofInstance(ByteBuffer.allocate(bufferSize)), Suppliers.ofInstance(buffer),
GrouperTestUtil.intKeySerde(), GrouperTestUtil.intKeySerde(),
columnSelectorFactory, columnSelectorFactory,
new AggregatorFactory[]{ new AggregatorFactory[]{