Add MemoryOpenHashTable, a table similar to ByteBufferHashTable. (#9308)

* Add MemoryOpenHashTable, a table similar to ByteBufferHashTable.

It has some key differences that improve speed and design simplicity:

1) Uses Memory rather than ByteBuffer for its backing storage.
2) Uses faster hashing and comparison routines (see HashTableUtils).
3) Capacity is always a power of two, allowing simpler design and more
   efficient implementation of findBucket.
4) Does not implement growability; instead, it leaves growing to its callers.
   The idea is that this removes the need for subclasses, while still giving
   callers flexibility in how to handle table-full scenarios (see the sketch below).
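
   For reference, this is the intended calling pattern (a minimal sketch that
   mirrors the new MemoryOpenHashTableTest; "table" is a MemoryOpenHashTable
   with int keys and int values, and ISE is Druid's IllegalStateException):

     // Stage the key in scratch memory.
     final WritableMemory keyMemory = WritableMemory.allocate(Integer.BYTES);
     keyMemory.putInt(0, key);

     // findBucket returns the bucket holding the key, or -(bucket + 1) for the
     // empty bucket the key would occupy.
     int bucket = table.findBucket(HashTableUtils.hashMemory(keyMemory, 0, Integer.BYTES), keyMemory, 0);

     if (bucket < 0) {
       if (!table.canInsertNewBucket()) {
         // Caller's policy: grow via copyTo, spill, etc.
         throw new ISE("table full");
       }
       bucket = -(bucket + 1);
       table.initBucket(bucket, keyMemory, 0);
       // Fresh buckets have uninitialized value regions; zero before use.
       table.memory().putInt(table.bucketMemoryPosition(bucket) + table.bucketValueOffset(), 0);
     }

     // Read and write the value directly in the table's backing memory.
     final int valuePosition = table.bucketMemoryPosition(bucket) + table.bucketValueOffset();
     table.memory().putInt(valuePosition, table.memory().getInt(valuePosition) + 1);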

* Fix LGTM warnings.

* Adjust dependencies.

* Remove easymock from druid-benchmarks.

* Adjustments from review.

* Fix datasketches unit tests.

* Fix checkstyle.
Gian Merlino 2020-02-04 19:58:00 -08:00 committed by GitHub
parent 0d2b16c1d0
commit 3ef5c2f2e8
10 changed files with 1320 additions and 4 deletions

View File (pom.xml)

@@ -158,7 +158,10 @@
  <dependency>
    <groupId>org.apache.datasketches</groupId>
    <artifactId>datasketches-java</artifactId>
-   <version>1.1.0-incubating</version>
  </dependency>
+ <dependency>
+   <groupId>org.apache.datasketches</groupId>
+   <artifactId>datasketches-memory</artifactId>
+ </dependency>
  <dependency>
    <groupId>junit</groupId>

View File (MemoryBenchmark.java)

@@ -0,0 +1,169 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.benchmark;
import org.apache.datasketches.memory.WritableMemory;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.groupby.epinephelinae.collection.HashTableUtils;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Random;
import java.util.concurrent.TimeUnit;
@State(Scope.Benchmark)
@Fork(value = 1)
@Warmup(iterations = 10)
@Measurement(iterations = 15)
public class MemoryBenchmark
{
static {
NullHandling.initializeForTests();
}
@Param({"4", "5", "8", "9", "12", "16", "31", "32", "64", "128"})
public int numBytes;
@Param({"offheap"})
public String where;
private ByteBuffer buffer1;
private ByteBuffer buffer2;
private ByteBuffer buffer3;
private WritableMemory memory1;
private WritableMemory memory2;
private WritableMemory memory3;
@Setup
public void setUp()
{
if ("onheap".equals(where)) {
buffer1 = ByteBuffer.allocate(numBytes).order(ByteOrder.nativeOrder());
buffer2 = ByteBuffer.allocate(numBytes).order(ByteOrder.nativeOrder());
buffer3 = ByteBuffer.allocate(numBytes).order(ByteOrder.nativeOrder());
} else if ("offheap".equals(where)) {
buffer1 = ByteBuffer.allocateDirect(numBytes).order(ByteOrder.nativeOrder());
buffer2 = ByteBuffer.allocateDirect(numBytes).order(ByteOrder.nativeOrder());
buffer3 = ByteBuffer.allocateDirect(numBytes).order(ByteOrder.nativeOrder());
}
memory1 = WritableMemory.wrap(buffer1, ByteOrder.nativeOrder());
memory2 = WritableMemory.wrap(buffer2, ByteOrder.nativeOrder());
memory3 = WritableMemory.wrap(buffer3, ByteOrder.nativeOrder());
// Scribble in some random but consistent (same seed) garbage.
final Random random = new Random(0);
for (int i = 0; i < numBytes; i++) {
memory1.putByte(i, (byte) random.nextInt());
}
// memory1 == memory2
memory1.copyTo(0, memory2, 0, numBytes);
// memory1 != memory3, but only slightly (they differ in one middle byte; an attempt to avoid favoring
// left-to-right vs. right-to-left equality checks).
memory1.copyTo(0, memory3, 0, numBytes);
memory3.putByte(numBytes / 2, (byte) (~memory3.getByte(numBytes / 2)));
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public void equals_byteBuffer_whenEqual(Blackhole blackhole)
{
blackhole.consume(buffer1.equals(buffer2));
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public void equals_byteBuffer_whenDifferent(Blackhole blackhole)
{
blackhole.consume(buffer1.equals(buffer3));
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public void equals_hashTableUtils_whenEqual(Blackhole blackhole)
{
blackhole.consume(HashTableUtils.memoryEquals(memory1, 0, memory2, 0, numBytes));
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public void equals_hashTableUtils_whenDifferent(Blackhole blackhole)
{
blackhole.consume(HashTableUtils.memoryEquals(memory1, 0, memory3, 0, numBytes));
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public void equals_memoryEqualTo_whenEqual(Blackhole blackhole)
{
blackhole.consume(memory1.equalTo(0, memory2, 0, numBytes));
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public void equals_memoryEqualTo_whenDifferent(Blackhole blackhole)
{
blackhole.consume(memory1.equalTo(0, memory3, 0, numBytes));
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public void hash_byteBufferHashCode(Blackhole blackhole)
{
blackhole.consume(buffer1.hashCode());
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public void hash_hashTableUtils(Blackhole blackhole)
{
blackhole.consume(HashTableUtils.hashMemory(memory1, 0, numBytes));
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public void hash_memoryXxHash64(Blackhole blackhole)
{
blackhole.consume(memory1.xxHash64(0, numBytes, 0));
}
}
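
To run this benchmark on its own, a plain JMH runner works (a hypothetical harness shown for illustration; the druid-benchmarks module may have its own preferred invocation):

  import org.openjdk.jmh.runner.Runner;
  import org.openjdk.jmh.runner.RunnerException;
  import org.openjdk.jmh.runner.options.Options;
  import org.openjdk.jmh.runner.options.OptionsBuilder;

  public class MemoryBenchmarkRunner
  {
    public static void main(String[] args) throws RunnerException
    {
      // Select the benchmark class above; fork/warmup/measurement come from its annotations.
      final Options options = new OptionsBuilder()
          .include(MemoryBenchmark.class.getSimpleName())
          .build();
      new Runner(options).run();
    }
  }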

View File (HllSketchSqlAggregatorTest.java)

@@ -474,6 +474,7 @@ public class HllSketchSqlAggregatorTest extends CalciteTestBase
  + " Log Config K : 12\n"
  + " Hll Target : HLL_4\n"
  + " Current Mode : LIST\n"
+ + " Memory : false\n"
  + " LB : 2.0\n"
  + " Estimate : 2.000000004967054\n"
  + " UB : 2.000099863468538\n"
@@ -483,6 +484,7 @@
  + " LOG CONFIG K : 12\n"
  + " HLL TARGET : HLL_4\n"
  + " CURRENT MODE : LIST\n"
+ + " MEMORY : FALSE\n"
  + " LB : 2.0\n"
  + " ESTIMATE : 2.000000004967054\n"
  + " UB : 2.000099863468538\n"
@@ -611,6 +613,7 @@
  + " Log Config K : 12\n"
  + " Hll Target : HLL_4\n"
  + " Current Mode : LIST\n"
+ + " Memory : false\n"
  + " LB : 2.0\n"
  + " Estimate : 2.000000004967054\n"
  + " UB : 2.000099863468538\n"

View File (pom.xml)

@@ -82,6 +82,7 @@
  <avatica.version>1.15.0</avatica.version>
  <avro.version>1.9.1</avro.version>
  <calcite.version>1.21.0</calcite.version>
+ <datasketches.version>1.2.0-incubating</datasketches.version>
  <derby.version>10.14.2.0</derby.version>
  <dropwizard.metrics.version>4.0.0</dropwizard.metrics.version>
  <guava.version>16.0.1</guava.version>
@@ -1000,12 +1001,12 @@
  <dependency>
    <groupId>org.apache.datasketches</groupId>
    <artifactId>datasketches-java</artifactId>
-   <version>1.1.0-incubating</version>
+   <version>${datasketches.version}</version>
  </dependency>
  <dependency>
    <groupId>org.apache.datasketches</groupId>
    <artifactId>datasketches-memory</artifactId>
-   <version>1.2.0-incubating</version>
+   <version>${datasketches.version}</version>
  </dependency>
  <dependency>
    <groupId>org.apache.calcite</groupId>

View File (pom.xml)

@@ -156,6 +156,10 @@
    <groupId>javax.validation</groupId>
    <artifactId>validation-api</artifactId>
  </dependency>
+ <dependency>
+   <groupId>org.apache.datasketches</groupId>
+   <artifactId>datasketches-memory</artifactId>
+ </dependency>

  <!-- Tests -->
  <dependency>

View File (Groupers.java)

@@ -56,7 +56,7 @@ public class Groupers
 * MurmurHash3 was written by Austin Appleby, and is placed in the public domain. The author
 * hereby disclaims copyright to this source code.
 */
-private static int smear(int hashCode)
+public static int smear(int hashCode)
{
return C2 * Integer.rotateLeft(hashCode * C1, 15);
}
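
smear is made public because the new hash-table code in this commit calls it from outside this class: MemoryOpenHashTable#copyTo (below) applies it on top of HashTableUtils.hashMemory before probing, like so:

  final int keyHash = Groupers.smear(HashTableUtils.hashMemory(tableMemory, keyPosition, keySize));
  final int newBucket = other.findBucket(keyHash, tableMemory, keyPosition);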

View File (HashTableUtils.java)

@@ -0,0 +1,162 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.groupby.epinephelinae.collection;
import org.apache.datasketches.memory.Memory;
public class HashTableUtils
{
private HashTableUtils()
{
// No instantiation.
}
/**
* Computes the largest power of two less than or equal to a given "n".
*
* The input should be between 1 and {@link Integer#MAX_VALUE}, inclusive. Any other input returns
* {@link Integer#MIN_VALUE}.
*/
public static int previousPowerOfTwo(final int n)
{
if (n > 0) {
return Integer.highestOneBit(n);
} else {
return Integer.MIN_VALUE;
}
}
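// For example: previousPowerOfTwo(5) == 4, previousPowerOfTwo(8) == 8, and
// previousPowerOfTwo(0) == Integer.MIN_VALUE (out of range).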
/**
* Compute a simple, fast hash code of some memory range.
*
* @param memory a region of memory
* @param position position within the memory region
* @param length length of memory to hash, starting at the position
*/
public static int hashMemory(final Memory memory, final long position, final int length)
{
// Special cases for small, common key sizes to speed them up: e.g. one int key, two int keys, one long key, etc.
// The plus-one sizes (9, 13) are for nullable dimensions. The specific choices of special cases were chosen based
// on benchmarking (see MemoryBenchmark) on a Skylake-based cloud instance.
switch (length) {
case 4:
return memory.getInt(position);
case 8:
return 31 * (31 + memory.getInt(position)) + memory.getInt(position + Integer.BYTES);
case 9:
return 31 * (31 * (31 + memory.getInt(position)) + memory.getInt(position + Integer.BYTES))
+ memory.getByte(position + 2 * Integer.BYTES);
case 12:
return 31 * (31 * (31 + memory.getInt(position)) + memory.getInt(position + Integer.BYTES))
+ memory.getInt(position + 2 * Integer.BYTES);
case 13:
return 31 * (31 * (31 * (31 + memory.getInt(position)) + memory.getInt(position + Integer.BYTES))
+ memory.getInt(position + 2 * Integer.BYTES)) + memory.getByte(position + 3 * Integer.BYTES);
case 16:
return 31 * (31 * (31 * (31 + memory.getInt(position)) + memory.getInt(position + Integer.BYTES))
+ memory.getInt(position + 2 * Integer.BYTES)) + memory.getInt(position + 3 * Integer.BYTES);
default:
int hashCode = 1;
int remainingBytes = length;
long pos = position;
while (remainingBytes >= Integer.BYTES) {
hashCode = 31 * hashCode + memory.getInt(pos);
remainingBytes -= Integer.BYTES;
pos += Integer.BYTES;
}
if (remainingBytes == 1) {
hashCode = 31 * hashCode + memory.getByte(pos);
} else if (remainingBytes == 2) {
hashCode = 31 * hashCode + memory.getByte(pos);
hashCode = 31 * hashCode + memory.getByte(pos + 1);
} else if (remainingBytes == 3) {
hashCode = 31 * hashCode + memory.getByte(pos);
hashCode = 31 * hashCode + memory.getByte(pos + 1);
hashCode = 31 * hashCode + memory.getByte(pos + 2);
}
return hashCode;
}
}
/**
* Compare two memory ranges for equality.
*
* The purpose of this function is to be faster than {@link Memory#equalTo} for the small memory ranges that
* typically comprise keys in hash tables. As of this writing, it is. See "MemoryBenchmark" in the druid-benchmarks
* module for performance evaluation code.
*
* @param memory1 a region of memory
* @param offset1 position within the first memory region
* @param memory2 another region of memory
* @param offset2 position within the second memory region
* @param length length of memory to compare, starting at the positions
*/
public static boolean memoryEquals(
final Memory memory1,
final long offset1,
final Memory memory2,
final long offset2,
final int length
)
{
// Special cases for small, common key sizes to speed them up: e.g. one int key, two int keys, one long key, etc.
// The plus-one sizes (9, 13) are for nullable dimensions. The specific choices of special cases were chosen based
// on benchmarking (see MemoryBenchmark) on a Skylake-based cloud instance.
switch (length) {
case 4:
return memory1.getInt(offset1) == memory2.getInt(offset2);
case 8:
return memory1.getLong(offset1) == memory2.getLong(offset2);
case 9:
return memory1.getLong(offset1) == memory2.getLong(offset2)
&& memory1.getByte(offset1 + Long.BYTES) == memory2.getByte(offset2 + Long.BYTES);
case 12:
return memory1.getInt(offset1) == memory2.getInt(offset2)
&& memory1.getLong(offset1 + Integer.BYTES) == memory2.getLong(offset2 + Integer.BYTES);
case 13:
return memory1.getLong(offset1) == memory2.getLong(offset2)
&& memory1.getInt(offset1 + Long.BYTES) == memory2.getInt(offset2 + Long.BYTES)
&& (memory1.getByte(offset1 + Integer.BYTES + Long.BYTES)
== memory2.getByte(offset2 + Integer.BYTES + Long.BYTES));
case 16:
return memory1.getLong(offset1) == memory2.getLong(offset2)
&& memory1.getLong(offset1 + Long.BYTES) == memory2.getLong(offset2 + Long.BYTES);
default:
return memory1.equalTo(offset1, memory2, offset2, length);
}
}
}
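
A worked note on the special cases (illustration only, not part of the commit; assumes the WritableMemory class imported above): each special case is the default 31-based polynomial loop unrolled, so the fast paths change speed, not results. For length 8, the loop starts at hashCode = 1, folds in the first int to get 31 + i0, then the second to get 31 * (31 + i0) + i1, which is exactly the "case 8" expression:

  final WritableMemory mem = WritableMemory.allocate(8);
  mem.putInt(0, 12345);
  mem.putInt(4, 67890);

  // Unrolled "case 8" path.
  final int unrolled = 31 * (31 + mem.getInt(0)) + mem.getInt(4);

  // General default-loop equivalent.
  int hashCode = 1;
  hashCode = 31 * hashCode + mem.getInt(0);
  hashCode = 31 * hashCode + mem.getInt(4);

  assert unrolled == hashCode;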

View File (MemoryOpenHashTable.java)

@@ -0,0 +1,433 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.groupby.epinephelinae.collection;
import it.unimi.dsi.fastutil.ints.IntIterator;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.memory.WritableMemory;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.query.groupby.epinephelinae.Groupers;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.NoSuchElementException;
/**
* An open-addressed hash table with linear probing backed by {@link WritableMemory}. Does not offer a similar
* interface to {@link java.util.Map} because this is meant to be useful to lower-level, high-performance callers.
* There is no copying or serde of keys and values: callers access the backing memory of the table directly.
*
* This table will not grow itself. Callers must handle growing if required; the {@link #copyTo} method is provided
* to assist.
*/
public class MemoryOpenHashTable
{
private static final byte USED_BYTE = 1;
private static final int USED_BYTE_SIZE = Byte.BYTES;
private final WritableMemory tableMemory;
private final int keySize;
private final int valueSize;
private final int bucketSize;
// Maximum number of elements in the table (typically derived by the caller from numBuckets and a load factor).
private final int maxSize;
// Number of available/used buckets in the table. Always a power of two.
private final int numBuckets;
// Mask that clips a number to [0, numBuckets). Used when searching through buckets.
private final int bucketMask;
// Number of elements in the table right now.
private int size;
/**
* Create a new table.
*
* @param tableMemory backing memory for the table; must be exactly large enough to hold "numBuckets" buckets
* @param numBuckets number of buckets for the table
* @param maxSize maximum number of elements for the table; must be less than numBuckets
* @param keySize key size in bytes
* @param valueSize value size in bytes
*/
public MemoryOpenHashTable(
final WritableMemory tableMemory,
final int numBuckets,
final int maxSize,
final int keySize,
final int valueSize
)
{
this.tableMemory = tableMemory;
this.numBuckets = numBuckets;
this.bucketMask = numBuckets - 1;
this.maxSize = maxSize;
this.keySize = keySize;
this.valueSize = valueSize;
this.bucketSize = bucketSize(keySize, valueSize);
// Our main intended users (VectorGrouper implementations) need the tableMemory to be backed by a big-endian
// ByteBuffer that is coterminous with the tableMemory, since it's going to feed that buffer into VectorAggregators
// instead of interacting with our WritableMemory directly. Nothing about this class actually requires that the
// Memory be backed by a ByteBuffer, but we'll check it here anyway for the benefit of our biggest customer.
verifyMemoryIsByteBuffer(tableMemory);
if (!tableMemory.getTypeByteOrder().equals(ByteOrder.nativeOrder())) {
throw new ISE("tableMemory must be native byte order");
}
if (tableMemory.getCapacity() != memoryNeeded(numBuckets, bucketSize)) {
throw new ISE(
"tableMemory must be size[%,d] but was[%,d]",
memoryNeeded(numBuckets, bucketSize),
tableMemory.getCapacity()
);
}
if (maxSize >= numBuckets) {
throw new ISE("maxSize must be less than numBuckets");
}
if (Integer.bitCount(numBuckets) != 1) {
throw new ISE("numBuckets must be a power of two but was[%,d]", numBuckets);
}
clear();
}
/**
* Returns the amount of memory needed for a table.
*
* This is just a multiplication, which is easy enough to do on your own, but sometimes it's nice for clarity's sake
* to call a function with a name that indicates why the multiplication is happening.
*
* @param numBuckets number of buckets
* @param bucketSize size per bucket (in bytes)
*
* @return size of table (in bytes)
*/
public static int memoryNeeded(final int numBuckets, final int bucketSize)
{
return numBuckets * bucketSize;
}
/**
* Returns the size of each bucket in a table.
*
* @param keySize size of keys (in bytes)
* @param valueSize size of values (in bytes)
*
* @return size of buckets (in bytes)
*/
public static int bucketSize(final int keySize, final int valueSize)
{
return USED_BYTE_SIZE + keySize + valueSize;
}
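// For example, with int keys and int values: 1 (used flag) + 4 + 4 = 9 bytes per bucket.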
/**
* Clear the table, resetting size to zero.
*/
public void clear()
{
size = 0;
// Clear used flags.
for (int bucket = 0; bucket < numBuckets; bucket++) {
tableMemory.putByte((long) bucket * bucketSize, (byte) 0);
}
}
/**
* Copy this table into another one. The other table must be large enough to hold all the copied buckets. The other
* table will be cleared before the copy takes place.
*
* @param other the other table
* @param copyHandler a callback that is notified for each copied bucket
*/
public void copyTo(final MemoryOpenHashTable other, @Nullable final BucketCopyHandler copyHandler)
{
if (other.size() > 0) {
other.clear();
}
for (int bucket = 0; bucket < numBuckets; bucket++) {
final int bucketOffset = bucket * bucketSize;
if (isOffsetUsed(bucketOffset)) {
final int keyPosition = bucketOffset + USED_BYTE_SIZE;
final int keyHash = Groupers.smear(HashTableUtils.hashMemory(tableMemory, keyPosition, keySize));
final int newBucket = other.findBucket(keyHash, tableMemory, keyPosition);
if (newBucket >= 0) {
// Not expected to happen, since we cleared the other table first.
throw new ISE("Found already-used bucket while copying");
}
if (!other.canInsertNewBucket()) {
throw new ISE("Unable to copy bucket to new table, size[%,d]", other.size());
}
final int newBucketOffset = -(newBucket + 1) * bucketSize;
assert !other.isOffsetUsed(newBucketOffset);
tableMemory.copyTo(bucketOffset, other.tableMemory, newBucketOffset, bucketSize);
other.size++;
if (copyHandler != null) {
copyHandler.bucketCopied(bucket, -(newBucket + 1), this, other);
}
}
}
// Sanity check.
if (other.size() != size) {
throw new ISE("New table size[%,d] != old table size[%,d] after copying", other.size(), size);
}
}
/**
* Finds the bucket for a particular key.
*
* @param keyHash result of calling {@link HashTableUtils#hashMemory} on this key
* @param keySpace memory containing the key
* @param keySpacePosition position of the key within keySpace
*
* @return bucket number if the key is currently in the table, or {@code -bucket - 1} where "bucket" is the
* empty bucket the key would occupy if inserted
*/
public int findBucket(final int keyHash, final Memory keySpace, final int keySpacePosition)
{
int bucket = keyHash & bucketMask;
while (true) {
final int bucketOffset = bucket * bucketSize;
if (tableMemory.getByte(bucketOffset) == 0) {
// Found unused bucket before finding our key.
return -bucket - 1;
}
final boolean keyFound = HashTableUtils.memoryEquals(
tableMemory,
bucketOffset + USED_BYTE_SIZE,
keySpace,
keySpacePosition,
keySize
);
if (keyFound) {
return bucket;
}
bucket = (bucket + 1) & bucketMask;
}
}
/**
* Returns whether this table can accept a new bucket.
*/
public boolean canInsertNewBucket()
{
return size < maxSize;
}
/**
* Initialize a bucket with a particular key.
*
* Do not call this method unless the bucket is currently unused and {@link #canInsertNewBucket()} returns true.
*
* @param bucket bucket number
* @param keySpace memory containing the key
* @param keySpacePosition position of the key within keySpace
*/
public void initBucket(final int bucket, final Memory keySpace, final int keySpacePosition)
{
final int bucketOffset = bucket * bucketSize;
// Method preconditions.
assert canInsertNewBucket() && !isOffsetUsed(bucketOffset);
// Mark the bucket used and write in the key.
tableMemory.putByte(bucketOffset, USED_BYTE);
keySpace.copyTo(keySpacePosition, tableMemory, bucketOffset + USED_BYTE_SIZE, keySize);
size++;
}
/**
* Returns the number of elements currently in the table.
*/
public int size()
{
return size;
}
/**
* Returns the number of buckets in this table. Note that not all of these can actually be used. The amount that
* can be used depends on the "maxSize" parameter provided during construction.
*/
public int numBuckets()
{
return numBuckets;
}
/**
* Returns the size of keys, in bytes.
*/
public int keySize()
{
return keySize;
}
/**
* Returns the size of values, in bytes.
*/
public int valueSize()
{
return valueSize;
}
/**
* Returns the offset within each bucket where the key starts.
*/
public int bucketKeyOffset()
{
return USED_BYTE_SIZE;
}
/**
* Returns the offset within each bucket where the value starts.
*/
public int bucketValueOffset()
{
return USED_BYTE_SIZE + keySize;
}
/**
* Returns the size in bytes of each bucket.
*/
public int bucketSize()
{
return bucketSize;
}
/**
* Returns the position within {@link #memory()} where a particular bucket starts.
*/
public int bucketMemoryPosition(final int bucket)
{
return bucket * bucketSize;
}
/**
* Returns the memory backing this table.
*/
public WritableMemory memory()
{
return tableMemory;
}
/**
* Iterates over all used buckets, returning bucket numbers for each one.
*
* The intent is that callers will pass the bucket numbers to {@link #bucketMemoryPosition} and then use
* {@link #bucketKeyOffset()} and {@link #bucketValueOffset()} to extract keys and values from the buckets as needed.
*/
public IntIterator bucketIterator()
{
return new IntIterator()
{
private int curr = 0;
private int currBucket = -1;
@Override
public boolean hasNext()
{
return curr < size;
}
@Override
public int nextInt()
{
if (curr >= size) {
throw new NoSuchElementException();
}
currBucket++;
while (!isOffsetUsed(currBucket * bucketSize)) {
currBucket++;
}
curr++;
return currBucket;
}
};
}
/**
* Returns whether the bucket at position "bucketOffset" is used or not. Note that this is a bucket position (in
* bytes), not a bucket number.
*/
private boolean isOffsetUsed(final int bucketOffset)
{
return tableMemory.getByte(bucketOffset) == USED_BYTE;
}
/**
* Validates that some Memory is coterminous with a backing big-endian ByteBuffer. Returns quietly if so, throws an
* exception otherwise.
*/
private static void verifyMemoryIsByteBuffer(final Memory memory)
{
final ByteBuffer buffer = memory.getByteBuffer();
if (buffer == null) {
throw new ISE("tableMemory must be ByteBuffer-backed");
}
if (!buffer.order().equals(ByteOrder.BIG_ENDIAN)) {
throw new ISE("tableMemory's ByteBuffer must be in big-endian order");
}
if (buffer.capacity() != memory.getCapacity() || buffer.remaining() != buffer.capacity()) {
throw new ISE("tableMemory's ByteBuffer must be coterminous");
}
}
/**
* Callback used by {@link #copyTo}.
*/
public interface BucketCopyHandler
{
/**
* Indicates that "oldBucket" in "oldTable" was copied to "newBucket" in "newTable".
*
* @param oldBucket old bucket number
* @param newBucket new bucket number
* @param oldTable old table
* @param newTable new table
*/
void bucketCopied(
int oldBucket,
int newBucket,
MemoryOpenHashTable oldTable,
MemoryOpenHashTable newTable
);
}
}
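
Because the table never grows itself, a caller that hits the table-full case typically allocates a larger table and copies into it. A hedged sketch of that growth step, using only the public API above plus java.nio; the doubling policy and 3/4 load factor are illustrative choices, not part of this commit:

  static MemoryOpenHashTable grow(final MemoryOpenHashTable table)
  {
    // Doubling keeps numBuckets a power of two, as the constructor requires.
    final int newNumBuckets = table.numBuckets() * 2;
    final int newMaxSize = newNumBuckets * 3 / 4; // must stay below newNumBuckets

    // ByteBuffer.allocate yields a big-endian, coterminous buffer, which
    // verifyMemoryIsByteBuffer expects; the Memory itself uses native order.
    final ByteBuffer newBuffer = ByteBuffer.allocate(
        MemoryOpenHashTable.memoryNeeded(newNumBuckets, table.bucketSize())
    );
    final MemoryOpenHashTable newTable = new MemoryOpenHashTable(
        WritableMemory.wrap(newBuffer, ByteOrder.nativeOrder()),
        newNumBuckets,
        newMaxSize,
        table.keySize(),
        table.valueSize()
    );

    // Rehashes every used bucket into the new table; null means no per-bucket callback.
    table.copyTo(newTable, null);
    return newTable;
  }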

View File (HashTableUtilsTest.java)

@@ -0,0 +1,189 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.groupby.epinephelinae.collection;
import it.unimi.dsi.fastutil.ints.Int2IntLinkedOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.memory.WritableMemory;
import org.apache.druid.java.util.common.StringUtils;
import org.junit.Assert;
import org.junit.Test;
import java.util.Random;
public class HashTableUtilsTest
{
@Test
public void test_previousPowerOfTwo()
{
final Int2IntMap expectedResults = new Int2IntLinkedOpenHashMap();
expectedResults.put(Integer.MIN_VALUE, Integer.MIN_VALUE);
expectedResults.put(Integer.MIN_VALUE + 1, Integer.MIN_VALUE);
expectedResults.put(-4, Integer.MIN_VALUE);
expectedResults.put(-3, Integer.MIN_VALUE);
expectedResults.put(-2, Integer.MIN_VALUE);
expectedResults.put(-1, Integer.MIN_VALUE);
expectedResults.put(0, Integer.MIN_VALUE);
expectedResults.put(1, 1);
expectedResults.put(2, 2);
expectedResults.put(3, 2);
expectedResults.put(4, 4);
expectedResults.put(5, 4);
expectedResults.put(6, 4);
expectedResults.put(7, 4);
expectedResults.put(8, 8);
expectedResults.put((1 << 30) - 1, 1 << 29);
expectedResults.put(1 << 30, 1 << 30);
expectedResults.put((1 << 30) + 1, 1 << 30);
expectedResults.put(Integer.MAX_VALUE - 1, 1 << 30);
expectedResults.put(Integer.MAX_VALUE, 1 << 30);
for (final Int2IntMap.Entry entry : expectedResults.int2IntEntrySet()) {
Assert.assertEquals(
entry.getIntKey() + " => " + entry.getIntValue(),
entry.getIntValue(),
HashTableUtils.previousPowerOfTwo(entry.getIntKey())
);
}
}
private static WritableMemory generateRandomButNotReallyRandomMemory(final int length)
{
final WritableMemory randomMemory = WritableMemory.allocate(length);
// Fill with random bytes, but use the same seed every run for consistency. This test should pass with
// any seed unless something really pathological happens.
final Random random = new Random(0);
final byte[] randomBytes = new byte[length];
random.nextBytes(randomBytes);
randomMemory.putByteArray(0, randomBytes, 0, length);
return randomMemory;
}
@Test
public void test_hashMemory_allByteLengthsUpTo128()
{
// This test validates that we *can* hash any amount of memory up to 128 bytes, and that if any bit is flipped
// in the memory then the hash changes. It doesn't validate that the hash function is actually good at dispersion.
// That also has a big impact on performance and needs to be checked separately if the hash function is changed.
final int maxBytes = 128;
final WritableMemory randomMemory = generateRandomButNotReallyRandomMemory(maxBytes);
for (int numBytes = 0; numBytes < maxBytes; numBytes++) {
// Grab "numBytes" bytes from the end of randomMemory.
final Memory regionToHash = randomMemory.region(maxBytes - numBytes, numBytes);
// Validate that hashing regionToHash is equivalent to hashing the end of randomMemory. This helps validate
// that using a nonzero position is effective.
Assert.assertEquals(
StringUtils.format("numBytes[%s] nonzero position check", numBytes),
HashTableUtils.hashMemory(regionToHash, 0, numBytes),
HashTableUtils.hashMemory(randomMemory, maxBytes - numBytes, numBytes)
);
// Copy the memory and make sure we did it right.
final WritableMemory copyOfRegion = WritableMemory.allocate(numBytes);
regionToHash.copyTo(0, copyOfRegion, 0, numBytes);
Assert.assertTrue(
StringUtils.format("numBytes[%s] copy equality check", numBytes),
regionToHash.equalTo(0, copyOfRegion, 0, numBytes)
);
// Validate that flipping any bit affects the hash.
for (int bit = 0; bit < numBytes * Byte.SIZE; bit++) {
final int bytePosition = bit / Byte.SIZE;
final byte mask = (byte) (1 << (bit % Byte.SIZE));
copyOfRegion.putByte(
bytePosition,
(byte) (copyOfRegion.getByte(bytePosition) ^ mask)
);
Assert.assertNotEquals(
StringUtils.format("numBytes[%s] bit[%s] flip check", numBytes, bit),
HashTableUtils.hashMemory(regionToHash, 0, numBytes),
HashTableUtils.hashMemory(copyOfRegion, 0, numBytes)
);
// Set it back and make sure we did it right.
copyOfRegion.putByte(
bytePosition,
(byte) (copyOfRegion.getByte(bytePosition) ^ mask)
);
Assert.assertTrue(
StringUtils.format("numBytes[%s] bit[%s] reset check", numBytes, bit),
regionToHash.equalTo(0, copyOfRegion, 0, numBytes)
);
}
}
}
@Test
public void test_memoryEquals_allByteLengthsUpTo128()
{
// This test validates that we can compare any two slices of memory of size up to 128 bytes, and that if any bit
// is flipped in two identical memory slices, then the comparison correctly returns not equal.
final int maxBytes = 128;
final WritableMemory randomMemory = generateRandomButNotReallyRandomMemory(maxBytes);
for (int numBytes = 0; numBytes < maxBytes; numBytes++) {
// Copy "numBytes" from the end of randomMemory.
final WritableMemory copyOfRegion = WritableMemory.allocate(numBytes);
randomMemory.copyTo(maxBytes - numBytes, copyOfRegion, 0, numBytes);
// Compare the two.
Assert.assertTrue(
StringUtils.format("numBytes[%s] nonzero position check", numBytes),
HashTableUtils.memoryEquals(randomMemory, maxBytes - numBytes, copyOfRegion, 0, numBytes)
);
// Validate that flipping any bit affects equality.
for (int bit = 0; bit < numBytes * Byte.SIZE; bit++) {
final int bytePosition = bit / Byte.SIZE;
final byte mask = (byte) (1 << (bit % Byte.SIZE));
copyOfRegion.putByte(
bytePosition,
(byte) (copyOfRegion.getByte(bytePosition) ^ mask)
);
Assert.assertFalse(
StringUtils.format("numBytes[%s] bit[%s] flip check", numBytes, bit),
HashTableUtils.memoryEquals(randomMemory, maxBytes - numBytes, copyOfRegion, 0, numBytes)
);
// Set it back and make sure we did it right.
copyOfRegion.putByte(
bytePosition,
(byte) (copyOfRegion.getByte(bytePosition) ^ mask)
);
Assert.assertTrue(
StringUtils.format("numBytes[%s] bit[%s] reset check", numBytes, bit),
HashTableUtils.memoryEquals(randomMemory, maxBytes - numBytes, copyOfRegion, 0, numBytes)
);
}
}
}
}

View File (MemoryOpenHashTableTest.java)

@@ -0,0 +1,352 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.groupby.epinephelinae.collection;
import com.google.common.collect.ImmutableMap;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntIterator;
import org.apache.datasketches.memory.WritableMemory;
import org.apache.druid.java.util.common.Pair;
import org.junit.Assert;
import org.junit.Test;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
public class MemoryOpenHashTableTest
{
@Test
public void testMemoryNeeded()
{
Assert.assertEquals(512, MemoryOpenHashTable.memoryNeeded(128, 4));
}
@Test
public void testEmptyTable()
{
final MemoryOpenHashTable table = createTable(8, .75, 4, 4);
Assert.assertEquals(8, table.numBuckets());
Assert.assertEquals(72, table.memory().getCapacity());
Assert.assertEquals(9, table.bucketSize());
assertEqualsMap(ImmutableMap.of(), table);
}
@Test
public void testInsertRepeatedKeys()
{
final MemoryOpenHashTable table = createTable(8, .7, Integer.BYTES, Integer.BYTES);
final WritableMemory keyMemory = WritableMemory.allocate(Integer.BYTES);
// Insert the following keys repeatedly.
final int[] keys = {0, 1, 2};
for (int i = 0; i < 3; i++) {
for (int key : keys) {
// Find bucket for key.
keyMemory.putInt(0, key);
int bucket = table.findBucket(HashTableUtils.hashMemory(keyMemory, 0, Integer.BYTES), keyMemory, 0);
if (bucket < 0) {
Assert.assertTrue(table.canInsertNewBucket());
bucket = -(bucket + 1);
table.initBucket(bucket, keyMemory, 0);
final int valuePosition = table.bucketMemoryPosition(bucket) + table.bucketValueOffset();
// Initialize to zero.
table.memory().putInt(valuePosition, 0);
}
// Add the key.
final int valuePosition = table.bucketMemoryPosition(bucket) + table.bucketValueOffset();
table.memory().putInt(
valuePosition,
table.memory().getInt(valuePosition) + key
);
}
}
final Map<ByteBuffer, ByteBuffer> expectedMap = new HashMap<>();
expectedMap.put(expectedKey(0), expectedValue(0));
expectedMap.put(expectedKey(1), expectedValue(3));
expectedMap.put(expectedKey(2), expectedValue(6));
assertEqualsMap(expectedMap, table);
}
@Test
public void testInsertDifferentKeysUntilFull()
{
final MemoryOpenHashTable table = createTable(256, .999, Integer.BYTES, Integer.BYTES);
final Map<ByteBuffer, ByteBuffer> expectedMap = new HashMap<>();
int key = 0;
while (table.canInsertNewBucket()) {
final int value = Integer.MAX_VALUE - key;
// Find bucket for key (which should not already exist).
final int bucket = findAndInitBucket(table, key);
Assert.assertTrue("bucket < 0 for key " + key, bucket < 0);
// Insert bucket and write value.
writeValueToBucket(table, -(bucket + 1), value);
expectedMap.put(expectedKey(key), expectedValue(value));
key += 7;
}
// This table should fill up at 255 elements (256 buckets, .999 load factor)
Assert.assertEquals("expected size", 255, table.size());
assertEqualsMap(expectedMap, table);
}
@Test
public void testCopyTo()
{
final MemoryOpenHashTable table1 = createTable(64, .7, Integer.BYTES, Integer.BYTES);
final MemoryOpenHashTable table2 = createTable(128, .7, Integer.BYTES, Integer.BYTES);
final Int2IntMap expectedMap = new Int2IntOpenHashMap();
expectedMap.put(0, 1);
expectedMap.put(-1, 2);
expectedMap.put(111, 123);
expectedMap.put(Integer.MAX_VALUE, Integer.MIN_VALUE);
expectedMap.put(Integer.MIN_VALUE, Integer.MAX_VALUE);
// Populate table1.
for (Int2IntMap.Entry entry : expectedMap.int2IntEntrySet()) {
final int bucket = findAndInitBucket(table1, entry.getIntKey());
Assert.assertTrue("bucket < 0 for key " + entry.getIntKey(), bucket < 0);
writeValueToBucket(table1, -(bucket + 1), entry.getIntValue());
}
// Copy to table2.
table1.copyTo(table2, ((oldBucket, newBucket, oldTable, newTable) -> {
Assert.assertSame(table1, oldTable);
Assert.assertSame(table2, newTable);
}));
// Compute expected map to compare these tables to.
final Map<ByteBuffer, ByteBuffer> expectedByteBufferMap =
expectedMap.int2IntEntrySet()
.stream()
.collect(
Collectors.toMap(
entry -> expectedKey(entry.getIntKey()),
entry -> expectedValue(entry.getIntValue())
)
);
assertEqualsMap(expectedByteBufferMap, table1);
assertEqualsMap(expectedByteBufferMap, table2);
}
@Test
public void testClear()
{
final MemoryOpenHashTable table = createTable(64, .7, Integer.BYTES, Integer.BYTES);
final Int2IntMap expectedMap = new Int2IntOpenHashMap();
expectedMap.put(0, 1);
expectedMap.put(-1, 2);
// Populate table.
for (Int2IntMap.Entry entry : expectedMap.int2IntEntrySet()) {
final int bucket = findAndInitBucket(table, entry.getIntKey());
Assert.assertTrue("bucket < 0 for key " + entry.getIntKey(), bucket < 0);
writeValueToBucket(table, -(bucket + 1), entry.getIntValue());
}
// Compute expected map to compare these tables to.
final Map<ByteBuffer, ByteBuffer> expectedByteBufferMap =
expectedMap.int2IntEntrySet()
.stream()
.collect(
Collectors.toMap(
entry -> expectedKey(entry.getIntKey()),
entry -> expectedValue(entry.getIntValue())
)
);
assertEqualsMap(expectedByteBufferMap, table);
// Clear and verify.
table.clear();
assertEqualsMap(ImmutableMap.of(), table);
}
/**
* Finds the bucket for the provided key using {@link MemoryOpenHashTable#findBucket} and initializes it if empty
* using {@link MemoryOpenHashTable#initBucket}. Same return value as {@link MemoryOpenHashTable#findBucket}.
*/
private static int findAndInitBucket(final MemoryOpenHashTable table, final int key)
{
final int keyMemoryPosition = 1; // Helps verify that offsets work
final WritableMemory keyMemory = WritableMemory.allocate(Integer.BYTES + 1);
keyMemory.putInt(keyMemoryPosition, key);
final int bucket = table.findBucket(
HashTableUtils.hashMemory(keyMemory, keyMemoryPosition, Integer.BYTES),
keyMemory,
keyMemoryPosition
);
if (bucket < 0) {
table.initBucket(-(bucket + 1), keyMemory, keyMemoryPosition);
}
return bucket;
}
/**
* Writes a value to a bucket. The bucket must have already been initialized by calling
* {@link MemoryOpenHashTable#initBucket}.
*/
private static void writeValueToBucket(final MemoryOpenHashTable table, final int bucket, final int value)
{
final int valuePosition = table.bucketMemoryPosition(bucket) + table.bucketValueOffset();
table.memory().putInt(valuePosition, value);
}
/**
* Returns a set of key, value pairs from the provided table. Uses the table's {@link MemoryOpenHashTable#bucketIterator()}
* method.
*/
private static Set<ByteBufferPair> pairSet(final MemoryOpenHashTable table)
{
final Set<ByteBufferPair> retVal = new HashSet<>();
final IntIterator bucketIterator = table.bucketIterator();
while (bucketIterator.hasNext()) {
final int bucket = bucketIterator.nextInt();
final ByteBuffer entryBuffer = table.memory().getByteBuffer().duplicate();
entryBuffer.position(table.bucketMemoryPosition(bucket));
entryBuffer.limit(entryBuffer.position() + table.bucketSize());
// Must copy since we're materializing, and the buffer will get reused.
final ByteBuffer keyBuffer = ByteBuffer.allocate(table.keySize());
final ByteBuffer keyDup = entryBuffer.duplicate();
final int keyPosition = keyDup.position() + table.bucketKeyOffset();
keyDup.position(keyPosition);
keyDup.limit(keyPosition + table.keySize());
keyBuffer.put(keyDup);
keyBuffer.position(0);
final ByteBuffer valueBuffer = ByteBuffer.allocate(table.valueSize());
final ByteBuffer valueDup = entryBuffer.duplicate();
final int valuePosition = valueDup.position() + table.bucketValueOffset();
valueDup.position(valuePosition);
valueDup.limit(valuePosition + table.valueSize());
valueBuffer.put(valueDup);
valueBuffer.position(0);
retVal.add(new ByteBufferPair(keyBuffer, valueBuffer));
}
return retVal;
}
private static MemoryOpenHashTable createTable(
final int numBuckets,
final double loadFactor,
final int keySize,
final int valueSize
)
{
final int maxSize = (int) Math.floor(numBuckets * loadFactor);
final ByteBuffer buffer = ByteBuffer.allocate(
MemoryOpenHashTable.memoryNeeded(
numBuckets,
MemoryOpenHashTable.bucketSize(keySize, valueSize)
)
);
// Scribble garbage to make sure that we don't depend on the buffer being clear.
for (int i = 0; i < buffer.capacity(); i++) {
buffer.put(i, (byte) ThreadLocalRandom.current().nextInt());
}
return new MemoryOpenHashTable(
WritableMemory.wrap(buffer, ByteOrder.nativeOrder()),
numBuckets,
maxSize,
keySize,
valueSize
);
}
private static ByteBuffer expectedKey(final int key)
{
return ByteBuffer.allocate(Integer.BYTES).order(ByteOrder.nativeOrder()).putInt(0, key);
}
private static ByteBuffer expectedValue(final int value)
{
return ByteBuffer.allocate(Integer.BYTES).order(ByteOrder.nativeOrder()).putInt(0, value);
}
private static void assertEqualsMap(final Map<ByteBuffer, ByteBuffer> expected, final MemoryOpenHashTable actual)
{
Assert.assertEquals("size", expected.size(), actual.size());
Assert.assertEquals(
"entries",
expected.entrySet()
.stream()
.map(entry -> new ByteBufferPair(entry.getKey(), entry.getValue()))
.collect(Collectors.toSet()),
pairSet(actual)
);
}
private static class ByteBufferPair extends Pair<ByteBuffer, ByteBuffer>
{
public ByteBufferPair(ByteBuffer lhs, ByteBuffer rhs)
{
super(lhs, rhs);
}
@Override
public String toString()
{
final byte[] lhsBytes = new byte[lhs.remaining()];
lhs.duplicate().get(lhsBytes);
final byte[] rhsBytes = new byte[rhs.remaining()];
rhs.duplicate().get(rhsBytes);
return "ByteBufferPair{" +
"lhs=" + Arrays.toString(lhsBytes) +
", rhs=" + Arrays.toString(rhsBytes) +
'}';
}
}
}