diff --git a/benchmarks/pom.xml b/benchmarks/pom.xml index c53b80f0336..c9327170c07 100644 --- a/benchmarks/pom.xml +++ b/benchmarks/pom.xml @@ -158,7 +158,10 @@ org.apache.datasketches datasketches-java - 1.1.0-incubating + + + org.apache.datasketches + datasketches-memory junit diff --git a/benchmarks/src/main/java/org/apache/druid/benchmark/MemoryBenchmark.java b/benchmarks/src/main/java/org/apache/druid/benchmark/MemoryBenchmark.java new file mode 100644 index 00000000000..56fd68c922a --- /dev/null +++ b/benchmarks/src/main/java/org/apache/druid/benchmark/MemoryBenchmark.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.benchmark; + +import org.apache.datasketches.memory.WritableMemory; +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.query.groupby.epinephelinae.collection.HashTableUtils; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 10) +@Measurement(iterations = 15) +public class MemoryBenchmark +{ + static { + NullHandling.initializeForTests(); + } + + @Param({"4", "5", "8", "9", "12", "16", "31", "32", "64", "128"}) + public int numBytes; + + @Param({"offheap"}) + public String where; + + private ByteBuffer buffer1; + private ByteBuffer buffer2; + private ByteBuffer buffer3; + private WritableMemory memory1; + private WritableMemory memory2; + private WritableMemory memory3; + + @Setup + public void setUp() + { + if ("onheap".equals(where)) { + buffer1 = ByteBuffer.allocate(numBytes).order(ByteOrder.nativeOrder()); + buffer2 = ByteBuffer.allocate(numBytes).order(ByteOrder.nativeOrder()); + buffer3 = ByteBuffer.allocate(numBytes).order(ByteOrder.nativeOrder()); + } else if ("offheap".equals(where)) { + buffer1 = ByteBuffer.allocateDirect(numBytes).order(ByteOrder.nativeOrder()); + buffer2 = ByteBuffer.allocateDirect(numBytes).order(ByteOrder.nativeOrder()); + buffer3 = ByteBuffer.allocateDirect(numBytes).order(ByteOrder.nativeOrder()); + } + + memory1 = WritableMemory.wrap(buffer1, ByteOrder.nativeOrder()); + memory2 = WritableMemory.wrap(buffer2, ByteOrder.nativeOrder()); + 
memory3 = WritableMemory.wrap(buffer3, ByteOrder.nativeOrder()); + + // Scribble in some random but consistent (same seed) garbage. + final Random random = new Random(0); + for (int i = 0; i < numBytes; i++) { + memory1.putByte(i, (byte) random.nextInt()); + } + + // memory1 == memory2 + memory1.copyTo(0, memory2, 0, numBytes); + + // memory1 != memory3, but only slightly (different in a middle byte; an attempt to not favor leftward moving vs + // rightward moving equality checks). + memory1.copyTo(0, memory3, 0, numBytes); + memory3.putByte(numBytes / 2, (byte) (~memory3.getByte(numBytes / 2))); + } + + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public void equals_byteBuffer_whenEqual(Blackhole blackhole) + { + blackhole.consume(buffer1.equals(buffer2)); + } + + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public void equals_byteBuffer_whenDifferent(Blackhole blackhole) + { + blackhole.consume(buffer1.equals(buffer3)); + } + + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public void equals_hashTableUtils_whenEqual(Blackhole blackhole) + { + blackhole.consume(HashTableUtils.memoryEquals(memory1, 0, memory2, 0, numBytes)); + } + + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public void equals_hashTableUtils_whenDifferent(Blackhole blackhole) + { + blackhole.consume(HashTableUtils.memoryEquals(memory1, 0, memory3, 0, numBytes)); + } + + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public void equals_memoryEqualTo_whenEqual(Blackhole blackhole) + { + blackhole.consume(memory1.equalTo(0, memory2, 0, numBytes)); + } + + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public void equals_memoryEqualTo_whenDifferent(Blackhole blackhole) + { + blackhole.consume(memory1.equalTo(0, memory3, 0, numBytes)); + } + + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public void hash_byteBufferHashCode(Blackhole blackhole) + { + blackhole.consume(buffer1.hashCode()); + } + + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public void hash_hashTableUtils(Blackhole blackhole) + { + blackhole.consume(HashTableUtils.hashMemory(memory1, 0, numBytes)); + } + + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public void hash_memoryXxHash64(Blackhole blackhole) + { + blackhole.consume(memory1.xxHash64(0, numBytes, 0)); + } +} diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java index 28cd7051a54..74376a22ac9 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java @@ -474,6 +474,7 @@ public class HllSketchSqlAggregatorTest extends CalciteTestBase + " Log Config K : 12\n" + " Hll Target : HLL_4\n" + " Current Mode : LIST\n" + + " Memory : false\n" + " LB : 2.0\n" + " Estimate : 2.000000004967054\n" + " UB : 2.000099863468538\n" @@ -483,6 +484,7 @@ public class HllSketchSqlAggregatorTest extends 
CalciteTestBase + " LOG CONFIG K : 12\n" + " HLL TARGET : HLL_4\n" + " CURRENT MODE : LIST\n" + + " MEMORY : FALSE\n" + " LB : 2.0\n" + " ESTIMATE : 2.000000004967054\n" + " UB : 2.000099863468538\n" @@ -611,6 +613,7 @@ public class HllSketchSqlAggregatorTest extends CalciteTestBase + " Log Config K : 12\n" + " Hll Target : HLL_4\n" + " Current Mode : LIST\n" + + " Memory : false\n" + " LB : 2.0\n" + " Estimate : 2.000000004967054\n" + " UB : 2.000099863468538\n" diff --git a/pom.xml b/pom.xml index e34f3222392..10d3014f217 100644 --- a/pom.xml +++ b/pom.xml @@ -82,6 +82,7 @@ 1.15.0 1.9.1 1.21.0 + 1.2.0-incubating 10.14.2.0 4.0.0 16.0.1 @@ -1000,12 +1001,12 @@ org.apache.datasketches datasketches-java - 1.1.0-incubating + ${datasketches.version} org.apache.datasketches datasketches-memory - 1.2.0-incubating + ${datasketches.version} org.apache.calcite diff --git a/processing/pom.xml b/processing/pom.xml index 0629c5837ff..e1b6b8cf39e 100644 --- a/processing/pom.xml +++ b/processing/pom.xml @@ -156,6 +156,10 @@ javax.validation validation-api + + org.apache.datasketches + datasketches-memory + diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/Groupers.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/Groupers.java index a1d8dbf816e..2ed89161819 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/Groupers.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/Groupers.java @@ -56,7 +56,7 @@ public class Groupers * MurmurHash3 was written by Austin Appleby, and is placed in the public domain. The author * hereby disclaims copyright to this source code. */ - private static int smear(int hashCode) + public static int smear(int hashCode) { return C2 * Integer.rotateLeft(hashCode * C1, 15); } diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/collection/HashTableUtils.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/collection/HashTableUtils.java new file mode 100644 index 00000000000..574bd080a48 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/collection/HashTableUtils.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.groupby.epinephelinae.collection; + +import org.apache.datasketches.memory.Memory; + +public class HashTableUtils +{ + private HashTableUtils() + { + // No instantiation. + } + + /** + * Computes the previous power of two less than or equal to a given "n". + * + * The integer should be between 1 (inclusive) and {@link Integer#MAX_VALUE} for best results. Other parameters will + * return {@link Integer#MIN_VALUE}. 
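+   *
+   * For example, {@code previousPowerOfTwo(5)} and {@code previousPowerOfTwo(7)} both return 4, while
+   * {@code previousPowerOfTwo(8)} returns 8; zero and negative inputs fall into the {@link Integer#MIN_VALUE} case.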
+ */ + public static int previousPowerOfTwo(final int n) + { + if (n > 0) { + return Integer.highestOneBit(n); + } else { + return Integer.MIN_VALUE; + } + } + + /** + * Compute a simple, fast hash code of some memory range. + * + * @param memory a region of memory + * @param position position within the memory region + * @param length length of memory to hash, starting at the position + */ + public static int hashMemory(final Memory memory, final long position, final int length) + { + // Special cases for small, common key sizes to speed them up: e.g. one int key, two int keys, one long key, etc. + // The plus-one sizes (9, 13) are for nullable dimensions. The specific choices of special cases were chosen based + // on benchmarking (see MemoryBenchmark) on a Skylake-based cloud instance. + + switch (length) { + case 4: + return memory.getInt(position); + + case 8: + return 31 * (31 + memory.getInt(position)) + memory.getInt(position + Integer.BYTES); + + case 9: + return 31 * (31 * (31 + memory.getInt(position)) + memory.getInt(position + Integer.BYTES)) + + memory.getByte(position + 2 * Integer.BYTES); + + case 12: + return 31 * (31 * (31 + memory.getInt(position)) + memory.getInt(position + Integer.BYTES)) + + memory.getInt(position + 2 * Integer.BYTES); + + case 13: + return 31 * (31 * (31 * (31 + memory.getInt(position)) + memory.getInt(position + Integer.BYTES)) + + memory.getInt(position + 2 * Integer.BYTES)) + memory.getByte(position + 3 * Integer.BYTES); + + case 16: + return 31 * (31 * (31 * (31 + memory.getInt(position)) + memory.getInt(position + Integer.BYTES)) + + memory.getInt(position + 2 * Integer.BYTES)) + memory.getInt(position + 3 * Integer.BYTES); + + default: + int hashCode = 1; + int remainingBytes = length; + long pos = position; + + while (remainingBytes >= Integer.BYTES) { + hashCode = 31 * hashCode + memory.getInt(pos); + remainingBytes -= Integer.BYTES; + pos += Integer.BYTES; + } + + if (remainingBytes == 1) { + hashCode = 31 * hashCode + memory.getByte(pos); + } else if (remainingBytes == 2) { + hashCode = 31 * hashCode + memory.getByte(pos); + hashCode = 31 * hashCode + memory.getByte(pos + 1); + } else if (remainingBytes == 3) { + hashCode = 31 * hashCode + memory.getByte(pos); + hashCode = 31 * hashCode + memory.getByte(pos + 1); + hashCode = 31 * hashCode + memory.getByte(pos + 2); + } + + return hashCode; + } + } + + /** + * Compare two memory ranges for equality. + * + * The purpose of this function is to be faster than {@link Memory#equalTo} for the small memory ranges that + * typically comprise keys in hash tables. As of this writing, it is. See "MemoryBenchmark" in the druid-benchmarks + * module for performance evaluation code. + * + * @param memory1 a region of memory + * @param offset1 position within the first memory region + * @param memory2 another region of memory + * @param offset2 position within the second memory region + * @param length length of memory to compare, starting at the positions + */ + public static boolean memoryEquals( + final Memory memory1, + final long offset1, + final Memory memory2, + final long offset2, + final int length + ) + { + // Special cases for small, common key sizes to speed them up: e.g. one int key, two int keys, one long key, etc. + // The plus-one sizes (9, 13) are for nullable dimensions. The specific choices of special cases were chosen based + // on benchmarking (see MemoryBenchmark) on a Skylake-based cloud instance. 
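+    // For instance, the 9-byte case below compares a long plus the trailing byte (a nullable long key, under the
+    // usual null-indicator layout) with one long read and one byte read rather than byte-by-byte, and the 16-byte
+    // case uses two long reads.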
+ + switch (length) { + case 4: + return memory1.getInt(offset1) == memory2.getInt(offset2); + + case 8: + return memory1.getLong(offset1) == memory2.getLong(offset2); + + case 9: + return memory1.getLong(offset1) == memory2.getLong(offset2) + && memory1.getByte(offset1 + Long.BYTES) == memory2.getByte(offset2 + Long.BYTES); + + case 12: + return memory1.getInt(offset1) == memory2.getInt(offset2) + && memory1.getLong(offset1 + Integer.BYTES) == memory2.getLong(offset2 + Integer.BYTES); + + case 13: + return memory1.getLong(offset1) == memory2.getLong(offset2) + && memory1.getInt(offset1 + Long.BYTES) == memory2.getInt(offset2 + Long.BYTES) + && (memory1.getByte(offset1 + Integer.BYTES + Long.BYTES) + == memory2.getByte(offset2 + Integer.BYTES + Long.BYTES)); + + case 16: + return memory1.getLong(offset1) == memory2.getLong(offset2) + && memory1.getLong(offset1 + Long.BYTES) == memory2.getLong(offset2 + Long.BYTES); + + default: + return memory1.equalTo(offset1, memory2, offset2, length); + } + } +} diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/collection/MemoryOpenHashTable.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/collection/MemoryOpenHashTable.java new file mode 100644 index 00000000000..f8ae15c6f3c --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/collection/MemoryOpenHashTable.java @@ -0,0 +1,433 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.groupby.epinephelinae.collection; + +import it.unimi.dsi.fastutil.ints.IntIterator; +import org.apache.datasketches.memory.Memory; +import org.apache.datasketches.memory.WritableMemory; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.query.groupby.epinephelinae.Groupers; + +import javax.annotation.Nullable; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.NoSuchElementException; + +/** + * An open-addressed hash table with linear probing backed by {@link WritableMemory}. Does not offer a similar + * interface to {@link java.util.Map} because this is meant to be useful to lower-level, high-performance callers. + * There is no copying or serde of keys and values: callers access the backing memory of the table directly. + * + * This table will not grow itself. Callers must handle growing if required; the {@link #copyTo} method is provided + * to assist. 
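+ *
+ * Each bucket is laid out as a single "used" byte, followed by the key, followed by the value; see
+ * {@link #bucketKeyOffset()} and {@link #bucketValueOffset()} for the corresponding offsets within a bucket.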
+ */ +public class MemoryOpenHashTable +{ + private static final byte USED_BYTE = 1; + private static final int USED_BYTE_SIZE = Byte.BYTES; + + private final WritableMemory tableMemory; + private final int keySize; + private final int valueSize; + private final int bucketSize; + + // Maximum number of elements in the table (based on numBuckets and maxLoadFactor). + private final int maxSize; + + // Number of available/used buckets in the table. Always a power of two. + private final int numBuckets; + + // Mask that clips a number to [0, numBuckets). Used when searching through buckets. + private final int bucketMask; + + // Number of elements in the table right now. + private int size; + + /** + * Create a new table. + * + * @param tableMemory backing memory for the table; must be exactly large enough to hold "numBuckets" + * @param numBuckets number of buckets for the table + * @param maxSize maximum number of elements for the table; must be less than numBuckets + * @param keySize key size in bytes + * @param valueSize value size in bytes + */ + public MemoryOpenHashTable( + final WritableMemory tableMemory, + final int numBuckets, + final int maxSize, + final int keySize, + final int valueSize + ) + { + this.tableMemory = tableMemory; + this.numBuckets = numBuckets; + this.bucketMask = numBuckets - 1; + this.maxSize = maxSize; + this.keySize = keySize; + this.valueSize = valueSize; + this.bucketSize = bucketSize(keySize, valueSize); + + // Our main intended users (VectorGrouper implementations) need the tableMemory to be backed by a big-endian + // ByteBuffer that is coterminous with the tableMemory, since it's going to feed that buffer into VectorAggregators + // instead of interacting with our WritableMemory directly. Nothing about this class actually requires that the + // Memory be backed by a ByteBuffer, but we'll check it here anyway for the benefit of our biggest customer. + verifyMemoryIsByteBuffer(tableMemory); + + if (!tableMemory.getTypeByteOrder().equals(ByteOrder.nativeOrder())) { + throw new ISE("tableMemory must be native byte order"); + } + + if (tableMemory.getCapacity() != memoryNeeded(numBuckets, bucketSize)) { + throw new ISE( + "tableMemory must be size[%,d] but was[%,d]", + memoryNeeded(numBuckets, bucketSize), + tableMemory.getCapacity() + ); + } + + if (maxSize >= numBuckets) { + throw new ISE("maxSize must be less than numBuckets"); + } + + if (Integer.bitCount(numBuckets) != 1) { + throw new ISE("numBuckets must be a power of two but was[%,d]", numBuckets); + } + + clear(); + } + + /** + * Returns the amount of memory needed for a table. + * + * This is just a multiplication, which is easy enough to do on your own, but sometimes it's nice for clarity's sake + * to call a function with a name that indicates why the multiplication is happening. + * + * @param numBuckets number of buckets + * @param bucketSize size per bucket (in bytes) + * + * @return size of table (in bytes) + */ + public static int memoryNeeded(final int numBuckets, final int bucketSize) + { + return numBuckets * bucketSize; + } + + /** + * Returns the size of each bucket in a table. + * + * @param keySize size of keys (in bytes) + * @param valueSize size of values (in bytes) + * + * @return size of buckets (in bytes) + */ + public static int bucketSize(final int keySize, final int valueSize) + { + return USED_BYTE_SIZE + keySize + valueSize; + } + + /** + * Clear the table, resetting size to zero. + */ + public void clear() + { + size = 0; + + // Clear used flags. 
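+    // Only the "used" byte at the start of each bucket needs to be reset; stale key and value bytes are simply
+    // ignored until a bucket is marked used again.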
+ for (int bucket = 0; bucket < numBuckets; bucket++) { + tableMemory.putByte((long) bucket * bucketSize, (byte) 0); + } + } + + /** + * Copy this table into another one. The other table must be large enough to hold all the copied buckets. The other + * table will be cleared before the copy takes place. + * + * @param other the other table + * @param copyHandler a callback that is notified for each copied bucket + */ + public void copyTo(final MemoryOpenHashTable other, @Nullable final BucketCopyHandler copyHandler) + { + if (other.size() > 0) { + other.clear(); + } + + for (int bucket = 0; bucket < numBuckets; bucket++) { + final int bucketOffset = bucket * bucketSize; + if (isOffsetUsed(bucketOffset)) { + final int keyPosition = bucketOffset + USED_BYTE_SIZE; + final int keyHash = Groupers.smear(HashTableUtils.hashMemory(tableMemory, keyPosition, keySize)); + final int newBucket = other.findBucket(keyHash, tableMemory, keyPosition); + + if (newBucket >= 0) { + // Not expected to happen, since we cleared the other table first. + throw new ISE("Found already-used bucket while copying"); + } + + if (!other.canInsertNewBucket()) { + throw new ISE("Unable to copy bucket to new table, size[%,d]", other.size()); + } + + final int newBucketOffset = -(newBucket + 1) * bucketSize; + assert !other.isOffsetUsed(newBucketOffset); + tableMemory.copyTo(bucketOffset, other.tableMemory, newBucketOffset, bucketSize); + other.size++; + + if (copyHandler != null) { + copyHandler.bucketCopied(bucket, -(newBucket + 1), this, other); + } + } + } + + // Sanity check. + if (other.size() != size) { + throw new ISE("New table size[%,d] != old table size[%,d] after copying", other.size(), size); + } + } + + /** + * Finds the bucket for a particular key. + * + * @param keyHash result of calling {@link HashTableUtils#hashMemory} on this key + * @param keySpace memory containing the key + * @param keySpacePosition position of the key within keySpace + * + * @return bucket number if currently occupied, or {@code -bucket - 1} if not occupied (yet) + */ + public int findBucket(final int keyHash, final Memory keySpace, final int keySpacePosition) + { + int bucket = keyHash & bucketMask; + + while (true) { + final int bucketOffset = bucket * bucketSize; + + if (tableMemory.getByte(bucketOffset) == 0) { + // Found unused bucket before finding our key. + return -bucket - 1; + } + + final boolean keyFound = HashTableUtils.memoryEquals( + tableMemory, + bucketOffset + USED_BYTE_SIZE, + keySpace, + keySpacePosition, + keySize + ); + + if (keyFound) { + return bucket; + } + + bucket = (bucket + 1) & bucketMask; + } + } + + /** + * Returns whether this table can accept a new bucket. + */ + public boolean canInsertNewBucket() + { + return size < maxSize; + } + + /** + * Initialize a bucket with a particular key. + * + * Do not call this method unless the bucket is currently unused and {@link #canInsertNewBucket()} returns true. + * + * @param bucket bucket number + * @param keySpace memory containing the key + * @param keySpacePosition position of the key within keySpace + */ + public void initBucket(final int bucket, final Memory keySpace, final int keySpacePosition) + { + final int bucketOffset = bucket * bucketSize; + + // Method preconditions. + assert canInsertNewBucket() && !isOffsetUsed(bucketOffset); + + // Mark the bucket used and write in the key. 
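+    // Note that the value portion of the bucket is not written here; callers are expected to initialize it
+    // themselves through bucketValueOffset() once the bucket has been initialized.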
+ tableMemory.putByte(bucketOffset, USED_BYTE); + keySpace.copyTo(keySpacePosition, tableMemory, bucketOffset + USED_BYTE_SIZE, keySize); + size++; + } + + /** + * Returns the number of elements currently in the table. + */ + public int size() + { + return size; + } + + /** + * Returns the number of buckets in this table. Note that not all of these can actually be used. The amount that + * can be used depends on the "maxSize" parameter provided during construction. + */ + public int numBuckets() + { + return numBuckets; + } + + /** + * Returns the size of keys, in bytes. + */ + public int keySize() + { + return keySize; + } + + /** + * Returns the size of values, in bytes. + */ + public int valueSize() + { + return valueSize; + } + + /** + * Returns the offset within each bucket where the key starts. + */ + public int bucketKeyOffset() + { + return USED_BYTE_SIZE; + } + + /** + * Returns the offset within each bucket where the value starts. + */ + public int bucketValueOffset() + { + return USED_BYTE_SIZE + keySize; + } + + /** + * Returns the size in bytes of each bucket. + */ + public int bucketSize() + { + return bucketSize; + } + + /** + * Returns the position within {@link #memory()} where a particular bucket starts. + */ + public int bucketMemoryPosition(final int bucket) + { + return bucket * bucketSize; + } + + /** + * Returns the memory backing this table. + */ + public WritableMemory memory() + { + return tableMemory; + } + + /** + * Iterates over all used buckets, returning bucket numbers for each one. + * + * The intent is that callers will pass the bucket numbers to {@link #bucketMemoryPosition} and then use + * {@link #bucketKeyOffset()} and {@link #bucketValueOffset()} to extract keys and values from the buckets as needed. + */ + public IntIterator bucketIterator() + { + return new IntIterator() + { + private int curr = 0; + private int currBucket = -1; + + @Override + public boolean hasNext() + { + return curr < size; + } + + @Override + public int nextInt() + { + if (curr >= size) { + throw new NoSuchElementException(); + } + + currBucket++; + + while (!isOffsetUsed(currBucket * bucketSize)) { + currBucket++; + } + + curr++; + return currBucket; + } + }; + } + + /** + * Returns whether the bucket at position "bucketOffset" is used or not. Note that this is a bucket position (in + * bytes), not a bucket number. + */ + private boolean isOffsetUsed(final int bucketOffset) + { + return tableMemory.getByte(bucketOffset) == USED_BYTE; + } + + /** + * Validates that some Memory is coterminous with a backing big-endian ByteBuffer. Returns quietly if so, throws an + * exception otherwise. + */ + private static void verifyMemoryIsByteBuffer(final Memory memory) + { + final ByteBuffer buffer = memory.getByteBuffer(); + + if (buffer == null) { + throw new ISE("tableMemory must be ByteBuffer-backed"); + } + + if (!buffer.order().equals(ByteOrder.BIG_ENDIAN)) { + throw new ISE("tableMemory's ByteBuffer must be in big-endian order"); + } + + if (buffer.capacity() != memory.getCapacity() || buffer.remaining() != buffer.capacity()) { + throw new ISE("tableMemory's ByteBuffer must be coterminous"); + } + } + + /** + * Callback used by {@link #copyTo}. + */ + public interface BucketCopyHandler + { + /** + * Indicates that "oldBucket" in "oldTable" was copied to "newBucket" in "newTable". 
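+     * Bucket numbers generally change during a copy, so this callback can be used by callers that keep
+     * per-bucket state keyed by bucket number and need to carry it over to the new table.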
+ * + * @param oldBucket old bucket number + * @param newBucket new bucket number + * @param oldTable old table + * @param newTable new table + */ + void bucketCopied( + int oldBucket, + int newBucket, + MemoryOpenHashTable oldTable, + MemoryOpenHashTable newTable + ); + } +} diff --git a/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/collection/HashTableUtilsTest.java b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/collection/HashTableUtilsTest.java new file mode 100644 index 00000000000..226b6152c26 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/collection/HashTableUtilsTest.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.groupby.epinephelinae.collection; + +import it.unimi.dsi.fastutil.ints.Int2IntLinkedOpenHashMap; +import it.unimi.dsi.fastutil.ints.Int2IntMap; +import org.apache.datasketches.memory.Memory; +import org.apache.datasketches.memory.WritableMemory; +import org.apache.druid.java.util.common.StringUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Random; + +public class HashTableUtilsTest +{ + @Test + public void test_previousPowerOfTwo() + { + final Int2IntMap expectedResults = new Int2IntLinkedOpenHashMap(); + expectedResults.put(Integer.MIN_VALUE, Integer.MIN_VALUE); + expectedResults.put(Integer.MIN_VALUE + 1, Integer.MIN_VALUE); + expectedResults.put(-4, Integer.MIN_VALUE); + expectedResults.put(-3, Integer.MIN_VALUE); + expectedResults.put(-2, Integer.MIN_VALUE); + expectedResults.put(-1, Integer.MIN_VALUE); + expectedResults.put(0, Integer.MIN_VALUE); + expectedResults.put(1, 1); + expectedResults.put(2, 2); + expectedResults.put(3, 2); + expectedResults.put(4, 4); + expectedResults.put(5, 4); + expectedResults.put(6, 4); + expectedResults.put(7, 4); + expectedResults.put(8, 8); + expectedResults.put((1 << 30) - 1, 1 << 29); + expectedResults.put(1 << 30, 1 << 30); + expectedResults.put((1 << 30) + 1, 1073741824); + expectedResults.put(Integer.MAX_VALUE - 1, 1073741824); + expectedResults.put(Integer.MAX_VALUE, 1073741824); + + for (final Int2IntMap.Entry entry : expectedResults.int2IntEntrySet()) { + Assert.assertEquals( + entry.getIntKey() + " => " + entry.getIntValue(), + entry.getIntValue(), + HashTableUtils.previousPowerOfTwo(entry.getIntKey()) + ); + } + } + + private static WritableMemory generateRandomButNotReallyRandomMemory(final int length) + { + final WritableMemory randomMemory = WritableMemory.allocate(length); + + // Fill with random bytes, but use the same seed every run for consistency. This test should pass with + // any seed unless something really pathological happens. 
+ final Random random = new Random(0); + final byte[] randomBytes = new byte[length]; + random.nextBytes(randomBytes); + randomMemory.putByteArray(0, randomBytes, 0, length); + return randomMemory; + } + + @Test + public void test_hashMemory_allByteLengthsUpTo128() + { + // This test validates that we *can* hash any amount of memory up to 128 bytes, and that if any bit is flipped + // in the memory then the hash changes. It doesn't validate that the hash function is actually good at dispersion. + // That also has a big impact on performance and needs to be checked separately if the hash function is changed. + + final int maxBytes = 128; + final WritableMemory randomMemory = generateRandomButNotReallyRandomMemory(maxBytes); + + for (int numBytes = 0; numBytes < maxBytes; numBytes++) { + // Grab "numBytes" bytes from the end of randomMemory. + final Memory regionToHash = randomMemory.region(maxBytes - numBytes, numBytes); + + // Validate that hashing regionAtEnd is equivalent to hashing the end of a region. This helps validate + // that using a nonzero position is effective. + Assert.assertEquals( + StringUtils.format("numBytes[%s] nonzero position check", numBytes), + HashTableUtils.hashMemory(regionToHash, 0, numBytes), + HashTableUtils.hashMemory(randomMemory, maxBytes - numBytes, numBytes) + ); + + // Copy the memory and make sure we did it right. + final WritableMemory copyOfRegion = WritableMemory.allocate(numBytes); + regionToHash.copyTo(0, copyOfRegion, 0, numBytes); + Assert.assertTrue( + StringUtils.format("numBytes[%s] copy equality check", numBytes), + regionToHash.equalTo(0, copyOfRegion, 0, numBytes) + ); + + // Validate that flipping any bit affects the hash. + for (int bit = 0; bit < numBytes * Byte.SIZE; bit++) { + final int bytePosition = bit / Byte.SIZE; + final byte mask = (byte) (1 << (bit % Byte.SIZE)); + + copyOfRegion.putByte( + bytePosition, + (byte) (copyOfRegion.getByte(bytePosition) ^ mask) + ); + + Assert.assertNotEquals( + StringUtils.format("numBytes[%s] bit[%s] flip check", numBytes, bit), + HashTableUtils.hashMemory(regionToHash, 0, numBytes), + HashTableUtils.hashMemory(copyOfRegion, 0, numBytes) + ); + + // Set it back and make sure we did it right. + copyOfRegion.putByte( + bytePosition, + (byte) (copyOfRegion.getByte(bytePosition) ^ mask) + ); + + Assert.assertTrue( + StringUtils.format("numBytes[%s] bit[%s] reset check", numBytes, bit), + regionToHash.equalTo(0, copyOfRegion, 0, numBytes) + ); + } + } + } + + @Test + public void test_memoryEquals_allByteLengthsUpTo128() + { + // This test validates that we can compare any two slices of memory of size up to 128 bytes, and that if any bit + // is flipped in two identical memory slices, then the comparison correctly returns not equal. + + final int maxBytes = 128; + final WritableMemory randomMemory = generateRandomButNotReallyRandomMemory(maxBytes); + + for (int numBytes = 0; numBytes < maxBytes; numBytes++) { + // Copy "numBytes" from the end of randomMemory. + final WritableMemory copyOfRegion = WritableMemory.allocate(numBytes); + randomMemory.copyTo(maxBytes - numBytes, copyOfRegion, 0, numBytes); + + // Compare the two. + Assert.assertTrue( + StringUtils.format("numBytes[%s] nonzero position check", numBytes), + HashTableUtils.memoryEquals(randomMemory, maxBytes - numBytes, copyOfRegion, 0, numBytes) + ); + + // Validate that flipping any bit affects equality. 
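+      // Since numBytes ranges from 0 through 127, this exercises both the specialized small-key branches and the
+      // generic fallback in memoryEquals.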
+ for (int bit = 0; bit < numBytes * Byte.SIZE; bit++) { + final int bytePosition = bit / Byte.SIZE; + final byte mask = (byte) (1 << (bit % Byte.SIZE)); + + copyOfRegion.putByte( + bytePosition, + (byte) (copyOfRegion.getByte(bytePosition) ^ mask) + ); + + Assert.assertFalse( + StringUtils.format("numBytes[%s] bit[%s] flip check", numBytes, bit), + HashTableUtils.memoryEquals(randomMemory, maxBytes - numBytes, copyOfRegion, 0, numBytes) + ); + + // Set it back and make sure we did it right. + copyOfRegion.putByte( + bytePosition, + (byte) (copyOfRegion.getByte(bytePosition) ^ mask) + ); + + Assert.assertTrue( + StringUtils.format("numBytes[%s] bit[%s] reset check", numBytes, bit), + HashTableUtils.memoryEquals(randomMemory, maxBytes - numBytes, copyOfRegion, 0, numBytes) + ); + } + } + } +} diff --git a/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/collection/MemoryOpenHashTableTest.java b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/collection/MemoryOpenHashTableTest.java new file mode 100644 index 00000000000..f7f76de747a --- /dev/null +++ b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/collection/MemoryOpenHashTableTest.java @@ -0,0 +1,352 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.groupby.epinephelinae.collection; + +import com.google.common.collect.ImmutableMap; +import it.unimi.dsi.fastutil.ints.Int2IntMap; +import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap; +import it.unimi.dsi.fastutil.ints.IntIterator; +import org.apache.datasketches.memory.WritableMemory; +import org.apache.druid.java.util.common.Pair; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; + +public class MemoryOpenHashTableTest +{ + @Test + public void testMemoryNeeded() + { + Assert.assertEquals(512, MemoryOpenHashTable.memoryNeeded(128, 4)); + } + + @Test + public void testEmptyTable() + { + final MemoryOpenHashTable table = createTable(8, .75, 4, 4); + + Assert.assertEquals(8, table.numBuckets()); + Assert.assertEquals(72, table.memory().getCapacity()); + Assert.assertEquals(9, table.bucketSize()); + assertEqualsMap(ImmutableMap.of(), table); + } + + @Test + public void testInsertRepeatedKeys() + { + final MemoryOpenHashTable table = createTable(8, .7, Integer.BYTES, Integer.BYTES); + final WritableMemory keyMemory = WritableMemory.allocate(Integer.BYTES); + + // Insert the following keys repeatedly. 
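+      // Each of the three passes below adds the key itself to the stored value, so the expected end state is
+      // key 0 -> 0, key 1 -> 3, key 2 -> 6.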
+ final int[] keys = {0, 1, 2}; + + for (int i = 0; i < 3; i++) { + for (int key : keys) { + // Find bucket for key. + keyMemory.putInt(0, key); + int bucket = table.findBucket(HashTableUtils.hashMemory(keyMemory, 0, Integer.BYTES), keyMemory, 0); + + if (bucket < 0) { + Assert.assertTrue(table.canInsertNewBucket()); + bucket = -(bucket + 1); + table.initBucket(bucket, keyMemory, 0); + final int valuePosition = table.bucketMemoryPosition(bucket) + table.bucketValueOffset(); + + // Initialize to zero. + table.memory().putInt(valuePosition, 0); + } + + // Add the key. + final int valuePosition = table.bucketMemoryPosition(bucket) + table.bucketValueOffset(); + table.memory().putInt( + valuePosition, + table.memory().getInt(valuePosition) + key + ); + } + } + + final Map expectedMap = new HashMap<>(); + expectedMap.put(expectedKey(0), expectedValue(0)); + expectedMap.put(expectedKey(1), expectedValue(3)); + expectedMap.put(expectedKey(2), expectedValue(6)); + + assertEqualsMap(expectedMap, table); + } + + @Test + public void testInsertDifferentKeysUntilFull() + { + final MemoryOpenHashTable table = createTable(256, .999, Integer.BYTES, Integer.BYTES); + final Map expectedMap = new HashMap<>(); + + int key = 0; + while (table.canInsertNewBucket()) { + final int value = Integer.MAX_VALUE - key; + + // Find bucket for key (which should not already exist). + final int bucket = findAndInitBucket(table, key); + Assert.assertTrue("bucket < 0 for key " + key, bucket < 0); + + // Insert bucket and write value. + writeValueToBucket(table, -(bucket + 1), value); + expectedMap.put(expectedKey(key), expectedValue(value)); + + key += 7; + } + + // This table should fill up at 255 elements (256 buckets, .999 load factor) + Assert.assertEquals("expected size", 255, table.size()); + assertEqualsMap(expectedMap, table); + } + + @Test + public void testCopyTo() + { + final MemoryOpenHashTable table1 = createTable(64, .7, Integer.BYTES, Integer.BYTES); + final MemoryOpenHashTable table2 = createTable(128, .7, Integer.BYTES, Integer.BYTES); + + final Int2IntMap expectedMap = new Int2IntOpenHashMap(); + expectedMap.put(0, 1); + expectedMap.put(-1, 2); + expectedMap.put(111, 123); + expectedMap.put(Integer.MAX_VALUE, Integer.MIN_VALUE); + expectedMap.put(Integer.MIN_VALUE, Integer.MAX_VALUE); + + // Populate table1. + for (Int2IntMap.Entry entry : expectedMap.int2IntEntrySet()) { + final int bucket = findAndInitBucket(table1, entry.getIntKey()); + Assert.assertTrue("bucket < 0 for key " + entry.getIntKey(), bucket < 0); + writeValueToBucket(table1, -(bucket + 1), entry.getIntValue()); + } + + // Copy to table2. + table1.copyTo(table2, ((oldBucket, newBucket, oldTable, newTable) -> { + Assert.assertSame(table1, oldTable); + Assert.assertSame(table2, newTable); + })); + + // Compute expected map to compare these tables to. + final Map expectedByteBufferMap = + expectedMap.int2IntEntrySet() + .stream() + .collect( + Collectors.toMap( + entry -> expectedKey(entry.getIntKey()), + entry -> expectedValue(entry.getIntValue()) + ) + ); + + assertEqualsMap(expectedByteBufferMap, table1); + assertEqualsMap(expectedByteBufferMap, table2); + } + + @Test + public void testClear() + { + final MemoryOpenHashTable table = createTable(64, .7, Integer.BYTES, Integer.BYTES); + + final Int2IntMap expectedMap = new Int2IntOpenHashMap(); + expectedMap.put(0, 1); + expectedMap.put(-1, 2); + + // Populate table. 
+ for (Int2IntMap.Entry entry : expectedMap.int2IntEntrySet()) { + final int bucket = findAndInitBucket(table, entry.getIntKey()); + Assert.assertTrue("bucket < 0 for key " + entry.getIntKey(), bucket < 0); + writeValueToBucket(table, -(bucket + 1), entry.getIntValue()); + } + + // Compute expected map to compare these tables to. + final Map expectedByteBufferMap = + expectedMap.int2IntEntrySet() + .stream() + .collect( + Collectors.toMap( + entry -> expectedKey(entry.getIntKey()), + entry -> expectedValue(entry.getIntValue()) + ) + ); + + assertEqualsMap(expectedByteBufferMap, table); + + // Clear and verify. + table.clear(); + + assertEqualsMap(ImmutableMap.of(), table); + } + + /** + * Finds the bucket for the provided key using {@link MemoryOpenHashTable#findBucket} and initializes it if empty + * using {@link MemoryOpenHashTable#initBucket}. Same return value as {@link MemoryOpenHashTable#findBucket}. + */ + private static int findAndInitBucket(final MemoryOpenHashTable table, final int key) + { + final int keyMemoryPosition = 1; // Helps verify that offsets work + final WritableMemory keyMemory = WritableMemory.allocate(Integer.BYTES + 1); + + keyMemory.putInt(keyMemoryPosition, key); + + final int bucket = table.findBucket( + HashTableUtils.hashMemory(keyMemory, keyMemoryPosition, Integer.BYTES), + keyMemory, + keyMemoryPosition + ); + + if (bucket < 0) { + table.initBucket(-(bucket + 1), keyMemory, keyMemoryPosition); + } + + return bucket; + } + + /** + * Writes a value to a bucket. The bucket must have already been initialized by calling + * {@link MemoryOpenHashTable#initBucket}. + */ + private static void writeValueToBucket(final MemoryOpenHashTable table, final int bucket, final int value) + { + final int valuePosition = table.bucketMemoryPosition(bucket) + table.bucketValueOffset(); + table.memory().putInt(valuePosition, value); + } + + /** + * Returns a set of key, value pairs from the provided table. Uses the table's {@link MemoryOpenHashTable#bucketIterator()} + * method. + */ + private static Set pairSet(final MemoryOpenHashTable table) + { + final Set retVal = new HashSet<>(); + + final IntIterator bucketIterator = table.bucketIterator(); + + while (bucketIterator.hasNext()) { + final int bucket = bucketIterator.nextInt(); + final ByteBuffer entryBuffer = table.memory().getByteBuffer().duplicate(); + entryBuffer.position(table.bucketMemoryPosition(bucket)); + entryBuffer.limit(entryBuffer.position() + table.bucketSize()); + + // Must copy since we're materializing, and the buffer will get reused. 
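+      // Copy the key and value portions of each bucket into standalone buffers so the pairs stay valid after the
+      // shared entry buffer moves on to the next bucket.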
+ final ByteBuffer keyBuffer = ByteBuffer.allocate(table.keySize()); + final ByteBuffer keyDup = entryBuffer.duplicate(); + final int keyPosition = keyDup.position() + table.bucketKeyOffset(); + keyDup.position(keyPosition); + keyDup.limit(keyPosition + table.keySize()); + keyBuffer.put(keyDup); + keyBuffer.position(0); + + final ByteBuffer valueBuffer = ByteBuffer.allocate(table.valueSize()); + final ByteBuffer valueDup = entryBuffer.duplicate(); + final int valuePosition = valueDup.position() + table.bucketValueOffset(); + valueDup.position(valuePosition); + valueDup.limit(valuePosition + table.valueSize()); + valueBuffer.put(valueDup); + valueBuffer.position(0); + + retVal.add(new ByteBufferPair(keyBuffer, valueBuffer)); + } + + return retVal; + } + + private static MemoryOpenHashTable createTable( + final int numBuckets, + final double loadFactor, + final int keySize, + final int valueSize + ) + { + final int maxSize = (int) Math.floor(numBuckets * loadFactor); + + final ByteBuffer buffer = ByteBuffer.allocate( + MemoryOpenHashTable.memoryNeeded( + numBuckets, + MemoryOpenHashTable.bucketSize(keySize, valueSize) + ) + ); + + // Scribble garbage to make sure that we don't depend on the buffer being clear. + for (int i = 0; i < buffer.capacity(); i++) { + buffer.put(i, (byte) ThreadLocalRandom.current().nextInt()); + } + + return new MemoryOpenHashTable( + WritableMemory.wrap(buffer, ByteOrder.nativeOrder()), + numBuckets, + maxSize, + keySize, + valueSize + ); + } + + private static ByteBuffer expectedKey(final int key) + { + return ByteBuffer.allocate(Integer.BYTES).order(ByteOrder.nativeOrder()).putInt(0, key); + } + + private static ByteBuffer expectedValue(final int value) + { + return ByteBuffer.allocate(Integer.BYTES).order(ByteOrder.nativeOrder()).putInt(0, value); + } + + private static void assertEqualsMap(final Map expected, final MemoryOpenHashTable actual) + { + Assert.assertEquals("size", expected.size(), actual.size()); + Assert.assertEquals( + "entries", + expected.entrySet() + .stream() + .map(entry -> new ByteBufferPair(entry.getKey(), entry.getValue())) + .collect(Collectors.toSet()), + pairSet(actual) + ); + } + + private static class ByteBufferPair extends Pair + { + public ByteBufferPair(ByteBuffer lhs, ByteBuffer rhs) + { + super(lhs, rhs); + } + + @Override + public String toString() + { + final byte[] lhsBytes = new byte[lhs.remaining()]; + lhs.duplicate().get(lhsBytes); + + final byte[] rhsBytes = new byte[rhs.remaining()]; + rhs.duplicate().get(rhsBytes); + + return "ByteBufferPair{" + + "lhs=" + Arrays.toString(lhsBytes) + + ", rhs=" + Arrays.toString(rhsBytes) + + '}'; + } + } +}