From 59d257816b85dbeeca336b8e25d341d67bbc5697 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Thu, 2 Sep 2021 02:25:26 -0700 Subject: [PATCH] fix goldilocks bug with HashVectorGrouper improperly initializing memory (#11649) * fix goldilocks bug with HashVectorGrouper improperly initializing memory that causes failure when there exists room to only grow one time * fix unintended change * cleanup --- .../epinephelinae/HashVectorGrouper.java | 55 ++-- .../epinephelinae/HashVectorGrouperTest.java | 254 ++++++++++++++++++ 2 files changed, 278 insertions(+), 31 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/HashVectorGrouper.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/HashVectorGrouper.java index dae166125a2..53e9b2d9934 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/HashVectorGrouper.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/HashVectorGrouper.java @@ -19,6 +19,7 @@ package org.apache.druid.query.groupby.epinephelinae; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Supplier; import it.unimi.dsi.fastutil.HashCommon; import it.unimi.dsi.fastutil.ints.IntIterator; @@ -202,17 +203,7 @@ public class HashVectorGrouper implements VectorGrouper tableStart = buffer.capacity() - bucketSize * (maxNumBuckets - numBuckets); } - final ByteBuffer tableBuffer = buffer.duplicate(); - tableBuffer.position(0); - tableBuffer.limit(MemoryOpenHashTable.memoryNeeded(numBuckets, bucketSize)); - - this.hashTable = new MemoryOpenHashTable( - WritableMemory.wrap(tableBuffer.slice(), ByteOrder.nativeOrder()), - numBuckets, - Math.max(1, Math.min(bufferGrouperMaxSize, (int) (numBuckets * maxLoadFactor))), - keySize, - aggregators.spaceNeeded() - ); + this.hashTable = createTable(buffer, tableStart, numBuckets); } @Override @@ -268,6 +259,27 @@ public class HashVectorGrouper implements VectorGrouper 
aggregators.close(); } + /** Returns the current {@code tableStart} value; exposed so tests can inspect it. */ @VisibleForTesting + public int getTableStart() + { + return tableStart; + } + /** + * Builds a {@link MemoryOpenHashTable} over a slice of {@code buffer} that begins at {@code tableStart} and spans + * {@code MemoryOpenHashTable.memoryNeeded(numBuckets, bucketSize)} bytes. The table's maximum size is + * {@code numBuckets * maxLoadFactor} truncated to an int, capped at {@code bufferGrouperMaxSize}, and never below 1. + * The {@code assert} only checks that the slice fits inside the buffer when JVM assertions are enabled. + * + * @param buffer buffer to carve the table out of; it is duplicated, so its own position and limit are not modified + * @param tableStart byte offset in {@code buffer} at which the table begins + * @param numBuckets number of buckets in the new table + * @return a hash table backed by the selected slice, wrapped in native byte order + */ + private MemoryOpenHashTable createTable(ByteBuffer buffer, int tableStart, int numBuckets) + { + final ByteBuffer tableBuffer = buffer.duplicate(); + tableBuffer.position(tableStart); + assert tableStart + MemoryOpenHashTable.memoryNeeded(numBuckets, bucketSize) <= buffer.capacity(); + tableBuffer.limit(tableStart + MemoryOpenHashTable.memoryNeeded(numBuckets, bucketSize)); + + return new MemoryOpenHashTable( + WritableMemory.wrap(tableBuffer.slice(), ByteOrder.nativeOrder()), + numBuckets, + Math.max(1, Math.min(bufferGrouperMaxSize, (int) (numBuckets * maxLoadFactor))), + keySize, + aggregators.spaceNeeded() + ); + } /** * Initializes the given bucket with the given key and fresh, empty aggregation state. Must only be called if @@ -307,17 +319,7 @@ public class HashVectorGrouper implements VectorGrouper final int newNumBuckets = nextTableNumBuckets(); final int newTableStart = nextTableStart(); - final ByteBuffer newTableBuffer = buffer.duplicate(); - newTableBuffer.position(newTableStart); - newTableBuffer.limit(newTableStart + MemoryOpenHashTable.memoryNeeded(newNumBuckets, bucketSize)); - - final MemoryOpenHashTable newHashTable = new MemoryOpenHashTable( - WritableMemory.wrap(newTableBuffer.slice(), ByteOrder.nativeOrder()), - newNumBuckets, - maxSizeForNumBuckets(newNumBuckets, maxLoadFactor, bufferGrouperMaxSize), - keySize, - aggregators.spaceNeeded() - ); + final MemoryOpenHashTable newHashTable = createTable(buffer, newTableStart, newNumBuckets); hashTable.copyTo(newHashTable, new HashVectorGrouperBucketCopyHandler(aggregators, hashTable.bucketValueOffset())); hashTable = newHashTable; @@ -382,15 +384,6 @@ public class HashVectorGrouper implements VectorGrouper return nextTableStart; } - /** - * Compute the maximum number of elements (size) for a given number of buckets. 
When the table hits this size, - * we must either grow it or return a table-full error. - */ - private static int maxSizeForNumBuckets(final int numBuckets, final double maxLoadFactor, final int configuredMaxSize) - { - return Math.max(1, Math.min(configuredMaxSize, (int) (numBuckets * maxLoadFactor))); - } - /** * Compute the initial table bucket count given a particular buffer capacity, bucket size, and user-configured * initial bucket count. diff --git a/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/HashVectorGrouperTest.java b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/HashVectorGrouperTest.java index 46cd043289b..d5a863a7542 100644 --- a/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/HashVectorGrouperTest.java +++ b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/HashVectorGrouperTest.java @@ -20,7 +20,9 @@ package org.apache.druid.query.groupby.epinephelinae; import com.google.common.base.Suppliers; +import org.apache.datasketches.memory.WritableMemory; import org.apache.druid.query.aggregation.AggregatorAdapters; +import org.junit.Assert; import org.junit.Test; import org.mockito.Mockito; @@ -45,4 +47,256 @@ public class HashVectorGrouperTest grouper.close(); Mockito.verify(aggregatorAdapters, Mockito.times(1)).close(); } + + @Test + public void testTableStartIsNotMemoryStartIfNotMaxSized() + { + final int maxVectorSize = 512; + final int keySize = 4; + final int bufferSize = 100 * 1024; + final WritableMemory keySpace = WritableMemory.allocate(keySize * maxVectorSize); + final ByteBuffer buffer = ByteBuffer.wrap(new byte[bufferSize]); + final AggregatorAdapters aggregatorAdapters = Mockito.mock(AggregatorAdapters.class); + final HashVectorGrouper grouper = new HashVectorGrouper( + Suppliers.ofInstance(buffer), + keySize, + aggregatorAdapters, + 8, + 0.f, + 4 + ); + grouper.initVectorized(maxVectorSize); + Assert.assertNotEquals(0, grouper.getTableStart()); + } + 
+ /* NOTE(review): the name says "IsNot", yet the assertion below expects getTableStart() to be 0 - confirm whether "testTableStartIsMemoryStartIfIsMaxSized" was intended. */ @Test + public void testTableStartIsNotMemoryStartIfIsMaxSized() + { + final int maxVectorSize = 512; + final int keySize = 10000; + final int bufferSize = 100 * 1024; + final ByteBuffer buffer = ByteBuffer.wrap(new byte[bufferSize]); + final AggregatorAdapters aggregatorAdapters = Mockito.mock(AggregatorAdapters.class); + final HashVectorGrouper grouper = new HashVectorGrouper( + Suppliers.ofInstance(buffer), + keySize, + aggregatorAdapters, + 4, + 0.f, + 4 + ); + grouper.initVectorized(maxVectorSize); + Assert.assertEquals(0, grouper.getTableStart()); + } + /** + * Uses startingNumBuckets = 4 and maxBuckets = 16. Asserts that aggregating 2 distinct keys leaves + * {@code getTableStart()} unchanged and that aggregating 3 distinct keys moves it to 0. + */ + @Test + public void testGrowOnce() + { + final int maxVectorSize = 512; + final int keySize = 4; + final int aggSize = 8; + final WritableMemory keySpace = WritableMemory.allocate(keySize * maxVectorSize); + + final AggregatorAdapters aggregatorAdapters = Mockito.mock(AggregatorAdapters.class); + Mockito.when(aggregatorAdapters.spaceNeeded()).thenReturn(aggSize); + + int startingNumBuckets = 4; + int maxBuckets = 16; + final int bufferSize = (keySize + aggSize) * maxBuckets; + final ByteBuffer buffer = ByteBuffer.wrap(new byte[bufferSize]); + final HashVectorGrouper grouper = new HashVectorGrouper( + Suppliers.ofInstance(buffer), + keySize, + aggregatorAdapters, + maxBuckets, + 0.f, + startingNumBuckets + ); + grouper.initVectorized(maxVectorSize); + + int tableStart = grouper.getTableStart(); + + // two keys should not cause buffer to grow + fillKeyspace(keySpace, maxVectorSize, 2); + AggregateResult result = grouper.aggregateVector(keySpace, 0, maxVectorSize); + Assert.assertTrue(result.isOk()); + Assert.assertEquals(tableStart, grouper.getTableStart()); + + // 3rd key should cause buffer to grow + // buffer should grow to maximum size + fillKeyspace(keySpace, maxVectorSize, 3); + result = grouper.aggregateVector(keySpace, 0, maxVectorSize); + Assert.assertTrue(result.isOk()); + Assert.assertEquals(0, grouper.getTableStart()); + } + /** + * Uses startingNumBuckets = 4 and maxBuckets = 32. Asserts that 2 distinct keys leave {@code getTableStart()} + * unchanged, 3 distinct keys increase it, and 6 distinct keys move it to 0. + */ + @Test + public void testGrowTwice() + { + final int maxVectorSize = 
512; + final int keySize = 4; + final int aggSize = 8; + final WritableMemory keySpace = WritableMemory.allocate(keySize * maxVectorSize); + + final AggregatorAdapters aggregatorAdapters = Mockito.mock(AggregatorAdapters.class); + Mockito.when(aggregatorAdapters.spaceNeeded()).thenReturn(aggSize); + + int startingNumBuckets = 4; + int maxBuckets = 32; + final int bufferSize = (keySize + aggSize) * maxBuckets; + final ByteBuffer buffer = ByteBuffer.wrap(new byte[bufferSize]); + final HashVectorGrouper grouper = new HashVectorGrouper( + Suppliers.ofInstance(buffer), + keySize, + aggregatorAdapters, + maxBuckets, + 0.f, + startingNumBuckets + ); + grouper.initVectorized(maxVectorSize); + + int tableStart = grouper.getTableStart(); + + // two keys should not cause buffer to grow + fillKeyspace(keySpace, maxVectorSize, 2); + AggregateResult result = grouper.aggregateVector(keySpace, 0, maxVectorSize); + Assert.assertTrue(result.isOk()); + Assert.assertEquals(tableStart, grouper.getTableStart()); + + // 3rd key should cause buffer to grow + // buffer should grow to next size, but is not full + fillKeyspace(keySpace, maxVectorSize, 3); + result = grouper.aggregateVector(keySpace, 0, maxVectorSize); + Assert.assertTrue(result.isOk()); + Assert.assertTrue(grouper.getTableStart() > tableStart); + + // this time should be all the way + fillKeyspace(keySpace, maxVectorSize, 6); + result = grouper.aggregateVector(keySpace, 0, maxVectorSize); + Assert.assertTrue(result.isOk()); + Assert.assertEquals(0, grouper.getTableStart()); + } + /** + * Uses startingNumBuckets = 4 and maxBuckets = 64. Asserts that 2 distinct keys leave {@code getTableStart()} + * unchanged, 3 and then 6 distinct keys each increase it, and 14 distinct keys move it to 0. + */ + @Test + public void testGrowThreeTimes() + { + final int maxVectorSize = 512; + final int keySize = 4; + final int aggSize = 8; + final WritableMemory keySpace = WritableMemory.allocate(keySize * maxVectorSize); + + final AggregatorAdapters aggregatorAdapters = Mockito.mock(AggregatorAdapters.class); + Mockito.when(aggregatorAdapters.spaceNeeded()).thenReturn(aggSize); + + int startingNumBuckets = 4; + int maxBuckets = 64; + final int bufferSize = 
(keySize + aggSize) * maxBuckets; + final ByteBuffer buffer = ByteBuffer.wrap(new byte[bufferSize]); + final HashVectorGrouper grouper = new HashVectorGrouper( + Suppliers.ofInstance(buffer), + keySize, + aggregatorAdapters, + maxBuckets, + 0.f, + startingNumBuckets + ); + grouper.initVectorized(maxVectorSize); + + int tableStart = grouper.getTableStart(); + + // two keys should not cause buffer to grow + fillKeyspace(keySpace, maxVectorSize, 2); + AggregateResult result = grouper.aggregateVector(keySpace, 0, maxVectorSize); + Assert.assertTrue(result.isOk()); + Assert.assertEquals(tableStart, grouper.getTableStart()); + + // 3rd key should cause buffer to grow + // buffer should grow to next size, but is not full + fillKeyspace(keySpace, maxVectorSize, 3); + result = grouper.aggregateVector(keySpace, 0, maxVectorSize); + Assert.assertTrue(result.isOk()); + Assert.assertTrue(grouper.getTableStart() > tableStart); + tableStart = grouper.getTableStart(); + + // grow it again + fillKeyspace(keySpace, maxVectorSize, 6); + result = grouper.aggregateVector(keySpace, 0, maxVectorSize); + Assert.assertTrue(result.isOk()); + Assert.assertTrue(grouper.getTableStart() > tableStart); + + // this time should be all the way + fillKeyspace(keySpace, maxVectorSize, 14); + result = grouper.aggregateVector(keySpace, 0, maxVectorSize); + Assert.assertTrue(result.isOk()); + Assert.assertEquals(0, grouper.getTableStart()); + } + /** + * Uses startingNumBuckets = 4 and maxBuckets = 128. Asserts that 2 distinct keys leave {@code getTableStart()} + * unchanged, 3, 6 and then 14 distinct keys each increase it, and 25 distinct keys move it to 0. + */ + @Test + public void testGrowFourTimes() + { + final int maxVectorSize = 512; + final int keySize = 4; + final int aggSize = 8; + final WritableMemory keySpace = WritableMemory.allocate(keySize * maxVectorSize); + + final AggregatorAdapters aggregatorAdapters = Mockito.mock(AggregatorAdapters.class); + Mockito.when(aggregatorAdapters.spaceNeeded()).thenReturn(aggSize); + + int startingNumBuckets = 4; + int maxBuckets = 128; + final int bufferSize = (keySize + aggSize) * maxBuckets; + final ByteBuffer buffer = ByteBuffer.wrap(new byte[bufferSize]); + final 
HashVectorGrouper grouper = new HashVectorGrouper( + Suppliers.ofInstance(buffer), + keySize, + aggregatorAdapters, + maxBuckets, + 0.f, + startingNumBuckets + ); + grouper.initVectorized(maxVectorSize); + + int tableStart = grouper.getTableStart(); + + // two keys should not cause buffer to grow + fillKeyspace(keySpace, maxVectorSize, 2); + AggregateResult result = grouper.aggregateVector(keySpace, 0, maxVectorSize); + Assert.assertTrue(result.isOk()); + Assert.assertEquals(tableStart, grouper.getTableStart()); + + // 3rd key should cause buffer to grow + // buffer should grow to next size, but is not full + fillKeyspace(keySpace, maxVectorSize, 3); + result = grouper.aggregateVector(keySpace, 0, maxVectorSize); + Assert.assertTrue(result.isOk()); + Assert.assertTrue(grouper.getTableStart() > tableStart); + tableStart = grouper.getTableStart(); + + // grow it again + fillKeyspace(keySpace, maxVectorSize, 6); + result = grouper.aggregateVector(keySpace, 0, maxVectorSize); + Assert.assertTrue(result.isOk()); + Assert.assertTrue(grouper.getTableStart() > tableStart); + tableStart = grouper.getTableStart(); + + // more + fillKeyspace(keySpace, maxVectorSize, 14); + result = grouper.aggregateVector(keySpace, 0, maxVectorSize); + Assert.assertTrue(result.isOk()); + Assert.assertTrue(grouper.getTableStart() > tableStart); + + // this time should be all the way + fillKeyspace(keySpace, maxVectorSize, 25); + result = grouper.aggregateVector(keySpace, 0, maxVectorSize); + Assert.assertTrue(result.isOk()); + Assert.assertEquals(0, grouper.getTableStart()); + } + /** + * Writes {@code maxVectorSize} ints into {@code keySpace}: the int at byte offset {@code Integer.BYTES * i} is + * {@code i % distinctKeys}, so the written values cycle through 0 .. distinctKeys - 1. + */ + private void fillKeyspace(WritableMemory keySpace, int maxVectorSize, int distinctKeys) + { + for (int i = 0; i < maxVectorSize; i++) { + int bucket = i % distinctKeys; + keySpace.putInt(((long) Integer.BYTES * i), bucket); + } + } }