HllSketch Merge Aggregator optimizations (#15162)

* Null byte serde for empty sketches

* Cache for HllSketchMerge

* Check for empty sketches

* Address review comments

* Revert changes to HllSketchHolder

* Handle null sketch holders instead of null sketches

* Add unit test for MSQ HllSketch

* Add comments

* Fix style
Adarsh Sanjeev committed 2023-11-03 08:31:22 +05:30 (via GitHub)
parent fb260f3e41
commit 9576fd3141
13 changed files with 254 additions and 41 deletions


@@ -205,10 +205,14 @@ public abstract class HllSketchAggregatorFactory extends AggregatorFactory
   @Override
   public Object finalizeComputation(@Nullable final Object object)
   {
-    if (!shouldFinalize || object == null) {
+    if (!shouldFinalize) {
       return object;
     }
+    if (object == null) {
+      return 0.0D;
+    }
     final HllSketchHolder sketch = HllSketchHolder.fromObj(object);
     final double estimate = sketch.getEstimate();


@@ -48,13 +48,23 @@ public class HllSketchHolderObjectStrategy implements ObjectStrategy<HllSketchHolder>
   @Override
   public HllSketchHolder fromByteBuffer(final ByteBuffer buf, final int size)
   {
+    if (size == 0 || isSafeToConvertToNullSketch(buf, size)) {
+      return null;
+    }
     return HllSketchHolder.of(HllSketch.wrap(Memory.wrap(buf, ByteOrder.LITTLE_ENDIAN).region(buf.position(), size)));
   }

   @Override
-  public byte[] toBytes(final HllSketchHolder sketch)
+  public byte[] toBytes(final HllSketchHolder holder)
   {
-    return sketch.getSketch().toCompactByteArray();
+    if (holder == null) {
+      return new byte[] {};
+    }
+    HllSketch sketch = holder.getSketch();
+    if (sketch == null || sketch.isEmpty()) {
+      return new byte[] {};
+    }
+    return sketch.toCompactByteArray();
   }

   @Nullable
@@ -67,4 +77,57 @@ public class HllSketchHolderObjectStrategy implements ObjectStrategy<HllSketchHolder>
       )
     );
   }
+
+  /**
+   * Checks if a sketch is empty and can be converted to null. Returns true if it is, and false if it is not or if it
+   * is not possible to say for sure.
+   * Checks the initial 8-byte header to find the type of internal sketch implementation, then uses the logic the
+   * corresponding implementation uses to tell if a sketch is empty while deserializing it.
+   */
+  private static boolean isSafeToConvertToNullSketch(ByteBuffer buf, int size)
+  {
+    if (size < 8) {
+      // Sanity check.
+      // HllSketches as bytes should be at least 8 bytes even with an empty sketch. If this is not the case, return
+      // false since we can't be sure.
+      return false;
+    }
+
+    final int position = buf.position();
+
+    // Get preamble ints. These should correspond to the type of internal implementation as a sanity check.
+    final int preInts = buf.get(position) & 0x3F; // get(PREAMBLE_INTS_BYTE) & PREAMBLE_MASK
+
+    // Get org.apache.datasketches.hll.CurMode. This indicates the type of internal data structure.
+    final int curMode = buf.get(position + 7) & 3; // get(MODE_BYTE) & CUR_MODE_MASK
+
+    switch (curMode) {
+      case 0: // LIST
+        if (preInts != 2) {
+          // preInts should be LIST_PREINTS. Sanity check.
+          return false;
+        }
+        // Based on org.apache.datasketches.hll.PreambleUtil.extractListCount
+        int listCount = buf.get(position + 6) & 0xFF; // get(LIST_COUNT_BYTE) & 0xFF
+        return listCount == 0;
+      case 1: // SET
+        if (preInts != 3 || size < 9) {
+          // preInts should be HASH_SET_PREINTS. Sanity check.
+          // We also need to read an additional byte for Set implementations.
+          return false;
+        }
+        // Based on org.apache.datasketches.hll.PreambleUtil.extractHashSetCount
+        int setCount = buf.get(position + 8); // get(HASH_SET_COUNT_INT)
+        return setCount == 0;
+      case 2: // HLL
+        if (preInts != 10) {
+          // preInts should be HLL_PREINTS. Sanity check.
+          return false;
+        }
+        // Based on org.apache.datasketches.hll.DirectHllArray.isEmpty
+        final int flags = buf.get(position + 5); // get(FLAGS_BYTE)
+        return (flags & 4) > 0; // (flags & EMPTY_FLAG_MASK) > 0
+      default: // Unknown implementation
+        // Can't say for sure, so return false.
+        return false;
+    }
+  }
 }
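Aside: the header fields that isSafeToConvertToNullSketch reads can be checked directly against bytes the DataSketches library emits. Below is a minimal standalone sketch, not part of this commit: the class name is invented, it assumes the datasketches-java artifact on the classpath, and the byte offsets simply mirror the PREAMBLE_INTS_BYTE / MODE_BYTE / LIST_COUNT_BYTE comments in the method above.

import java.nio.ByteBuffer;
import org.apache.datasketches.hll.HllSketch;

public class HllPreambleDemo
{
  private static void dump(String label, byte[] bytes)
  {
    ByteBuffer buf = ByteBuffer.wrap(bytes);
    int preInts = buf.get(0) & 0x3F;   // PREAMBLE_INTS_BYTE & PREAMBLE_MASK
    int curMode = buf.get(7) & 3;      // MODE_BYTE & CUR_MODE_MASK: 0=LIST, 1=SET, 2=HLL
    int listCount = buf.get(6) & 0xFF; // LIST_COUNT_BYTE (meaningful in LIST mode)
    System.out.printf("%s: size=%d preInts=%d curMode=%d listCount=%d%n",
                      label, bytes.length, preInts, curMode, listCount);
  }

  public static void main(String[] args)
  {
    // An empty sketch should print curMode=0 (LIST) with listCount=0: exactly the
    // case the new serde collapses to null.
    HllSketch empty = new HllSketch(12);
    dump("empty", empty.toCompactByteArray());

    // After one update, listCount should become 1, so the serde keeps the bytes.
    HllSketch oneItem = new HllSketch(12);
    oneItem.update(1);
    dump("one item", oneItem.toCompactByteArray());
  }
}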


@@ -22,13 +22,11 @@ package org.apache.druid.query.aggregation.datasketches.hll;
 import org.apache.datasketches.hll.HllSketch;
 import org.apache.datasketches.hll.TgtHllType;
 import org.apache.datasketches.hll.Union;
-import org.apache.datasketches.memory.WritableMemory;
 import org.apache.druid.query.aggregation.BufferAggregator;
 import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
 import org.apache.druid.segment.ColumnValueSelector;

 import java.nio.ByteBuffer;
-import java.nio.ByteOrder;

 /**
  * This aggregator merges existing sketches.
@@ -64,10 +62,7 @@ public class HllSketchMergeBufferAggregator implements BufferAggregator
       return;
     }

-    final WritableMemory mem = WritableMemory.writableWrap(buf, ByteOrder.LITTLE_ENDIAN)
-                                             .writableRegion(position, helper.getSize());
-    final Union union = Union.writableWrap(mem);
+    final Union union = helper.getOrCreateUnion(buf, position);
     union.update(sketch.getSketch());
   }

@@ -80,7 +75,7 @@ public class HllSketchMergeBufferAggregator implements BufferAggregator
   @Override
   public void close()
   {
-    // nothing to close
+    helper.close();
   }

   @Override
@@ -104,4 +99,10 @@ public class HllSketchMergeBufferAggregator implements BufferAggregator
     // See https://github.com/apache/druid/pull/6893#discussion_r250726028
     inspector.visit("lgK", helper.getLgK());
   }
+
+  @Override
+  public void relocate(int oldPosition, int newPosition, ByteBuffer oldBuffer, ByteBuffer newBuffer)
+  {
+    helper.relocate(oldPosition, newPosition, oldBuffer, newBuffer);
+  }
 }


@@ -19,6 +19,8 @@
 package org.apache.druid.query.aggregation.datasketches.hll;

+import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
+import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
 import org.apache.datasketches.hll.HllSketch;
 import org.apache.datasketches.hll.TgtHllType;
 import org.apache.datasketches.hll.Union;
@@ -26,15 +28,18 @@ import org.apache.datasketches.memory.WritableMemory;

 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
+import java.util.IdentityHashMap;

 public class HllSketchMergeBufferAggregatorHelper
 {
   private final int lgK;
   private final TgtHllType tgtHllType;
   private final int size;
+  private final IdentityHashMap<ByteBuffer, Int2ObjectMap<Union>> unions = new IdentityHashMap<>();
+  private final IdentityHashMap<ByteBuffer, WritableMemory> memCache = new IdentityHashMap<>();

   /**
-   * Used by {@link #init(ByteBuffer, int)}. We initialize by copying a prebuilt empty Union image.
+   * Used by {@link #initializeEmptyUnion(ByteBuffer, int)}. We initialize by copying a prebuilt empty Union image.
    * {@link HllSketchBuildBufferAggregator} does something similar, but different enough that we don't share code. The
    * "build" flavor uses {@link HllSketch} objects and the "merge" flavor uses {@link Union} objects.
    */
@@ -57,20 +62,7 @@ public class HllSketchMergeBufferAggregatorHelper
    */
   public void init(final ByteBuffer buf, final int position)
   {
-    // Copy prebuilt empty union object.
-    // Not necessary to cache a Union wrapper around the initialized memory, because:
-    //  - It is cheap to reconstruct by re-wrapping the memory in "aggregate" and "get".
-    //  - Unlike the HllSketch objects used by HllSketchBuildBufferAggregator, our Union objects never exceed the
-    //    max size and therefore do not need to be potentially moved in-heap.
-    final int oldPosition = buf.position();
-    try {
-      buf.position(position);
-      buf.put(emptyUnion);
-    }
-    finally {
-      buf.position(oldPosition);
-    }
+    createNewUnion(buf, position, false);
   }

   /**
@@ -93,4 +85,76 @@ public class HllSketchMergeBufferAggregatorHelper
   {
     return size;
   }
+
+  public void relocate(int oldPosition, int newPosition, ByteBuffer oldBuffer, ByteBuffer newBuffer)
+  {
+    createNewUnion(newBuffer, newPosition, true);
+    Int2ObjectMap<Union> unionMap = unions.get(oldBuffer);
+    if (unionMap != null) {
+      unionMap.remove(oldPosition);
+      if (unionMap.isEmpty()) {
+        unions.remove(oldBuffer);
+        memCache.remove(oldBuffer);
+      }
+    }
+  }
+
+  public Union getOrCreateUnion(ByteBuffer buf, int position)
+  {
+    Int2ObjectMap<Union> unionMap = unions.get(buf);
+    Union union = unionMap != null ? unionMap.get(position) : null;
+    if (union != null) {
+      return union;
+    }
+    return createNewUnion(buf, position, true);
+  }
+
+  private Union createNewUnion(ByteBuffer buf, int position, boolean isWrapped)
+  {
+    if (!isWrapped) {
+      initializeEmptyUnion(buf, position);
+    }
+    final WritableMemory mem = getMemory(buf).writableRegion(position, size);
+    Union union = Union.writableWrap(mem);
+
+    Int2ObjectMap<Union> unionMap = unions.get(buf);
+    if (unionMap == null) {
+      unionMap = new Int2ObjectOpenHashMap<>();
+      unions.put(buf, unionMap);
+    }
+    unionMap.put(position, union);
+
+    return union;
+  }
+
+  /**
+   * Copies the prebuilt empty Union image into the specified buffer at the given position.
+   */
+  private void initializeEmptyUnion(ByteBuffer buf, int position)
+  {
+    final int oldPosition = buf.position();
+    try {
+      buf.position(position);
+      buf.put(emptyUnion);
+    }
+    finally {
+      buf.position(oldPosition);
+    }
+  }
+
+  public void close()
+  {
+    unions.clear();
+    memCache.clear();
+  }
+
+  private WritableMemory getMemory(ByteBuffer buffer)
+  {
+    WritableMemory mem = memCache.get(buffer);
+    if (mem == null) {
+      mem = WritableMemory.writableWrap(buffer, ByteOrder.LITTLE_ENDIAN);
+      memCache.put(buffer, mem);
+    }
+    return mem;
+  }
 }
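One design note on the new cache: unions and memCache key buffers with IdentityHashMap rather than HashMap because ByteBuffer.equals() and hashCode() are content-based, so a content-hashed map would conflate distinct buffers and its keys would effectively change as aggregation writes into the buffer. The pattern, reduced to a standalone sketch (names here are illustrative, not the committed API; requires the fastutil dependency):

import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;

import java.nio.ByteBuffer;
import java.util.IdentityHashMap;
import java.util.function.BiFunction;

// Caches one value per (buffer, position) slot, keyed by buffer *identity*.
class SlotCache<T>
{
  private final IdentityHashMap<ByteBuffer, Int2ObjectMap<T>> slots = new IdentityHashMap<>();

  T getOrCreate(ByteBuffer buf, int position, BiFunction<ByteBuffer, Integer, T> factory)
  {
    Int2ObjectMap<T> perBuffer = slots.computeIfAbsent(buf, b -> new Int2ObjectOpenHashMap<>());
    T value = perBuffer.get(position);
    if (value == null) {
      // Miss: build the wrapper once and reuse it on subsequent calls.
      value = factory.apply(buf, position);
      perBuffer.put(position, value);
    }
    return value;
  }

  // Drop all cached slots for a buffer, e.g. when it is relocated or closed.
  void evict(ByteBuffer buf)
  {
    slots.remove(buf);
  }
}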


@@ -21,7 +21,6 @@ package org.apache.druid.query.aggregation.datasketches.hll;
 import org.apache.datasketches.hll.TgtHllType;
 import org.apache.datasketches.hll.Union;
-import org.apache.datasketches.memory.WritableMemory;
 import org.apache.druid.query.aggregation.VectorAggregator;
 import org.apache.druid.query.aggregation.datasketches.util.ToObjectVectorColumnProcessorFactory;
 import org.apache.druid.segment.ColumnProcessors;
@@ -29,7 +28,6 @@ import org.apache.druid.segment.vector.VectorColumnSelectorFactory;

 import javax.annotation.Nullable;
 import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
 import java.util.function.Supplier;

 public class HllSketchMergeVectorAggregator implements VectorAggregator
@@ -65,10 +63,7 @@ public class HllSketchMergeVectorAggregator implements VectorAggregator
   {
     final Object[] vector = objectSupplier.get();

-    final WritableMemory mem = WritableMemory.writableWrap(buf, ByteOrder.LITTLE_ENDIAN)
-                                             .writableRegion(position, helper.getSize());
-    final Union union = Union.writableWrap(mem);
+    final Union union = helper.getOrCreateUnion(buf, position);
     for (int i = startRow; i < endRow; i++) {
       if (vector[i] != null) {
         union.update(((HllSketchHolder) vector[i]).getSketch());
@@ -85,7 +80,6 @@ public class HllSketchMergeVectorAggregator implements VectorAggregator
       final int positionOffset
   )
   {
-    final WritableMemory mem = WritableMemory.writableWrap(buf, ByteOrder.LITTLE_ENDIAN);
     final Object[] vector = objectSupplier.get();

     for (int i = 0; i < numRows; i++) {
@@ -93,7 +87,7 @@ public class HllSketchMergeVectorAggregator implements VectorAggregator
       if (o != null) {
         final int position = positions[i] + positionOffset;
-        final Union union = Union.writableWrap(mem.writableRegion(position, helper.getSize()));
+        final Union union = helper.getOrCreateUnion(buf, position);
         union.update(o.getSketch());
       }
     }
@@ -108,6 +102,12 @@ public class HllSketchMergeVectorAggregator implements VectorAggregator
   @Override
   public void close()
   {
-    // Nothing to close.
+    helper.close();
   }
+
+  @Override
+  public void relocate(int oldPosition, int newPosition, ByteBuffer oldBuffer, ByteBuffer newBuffer)
+  {
+    helper.relocate(oldPosition, newPosition, oldBuffer, newBuffer);
+  }
 }


@@ -97,7 +97,11 @@ public class HllSketchToEstimatePostAggregator implements PostAggregator
   @Override
   public Object compute(final Map<String, Object> combinedAggregators)
   {
-    final HllSketchHolder holder = HllSketchHolder.fromObj(field.compute(combinedAggregators));
+    Object hllSketchHolderObject = field.compute(combinedAggregators);
+    if (hllSketchHolderObject == null) {
+      return 0.0D;
+    }
+    final HllSketchHolder holder = HllSketchHolder.fromObj(hllSketchHolderObject);

     // The union object always uses an HLL_8 sketch, so we always get that. The target type doesn't actually impact
     // the estimate anyway, so whatever gives us the "cheapest" operation should be good.
     double estimate = holder.getEstimate();


@@ -103,7 +103,11 @@ public class HllSketchToEstimateWithBoundsPostAggregator implements PostAggregator
   @Override
   public double[] compute(final Map<String, Object> combinedAggregators)
   {
-    final HllSketchHolder sketch = HllSketchHolder.fromObj(field.compute(combinedAggregators));
+    Object hllSketchHolderObject = field.compute(combinedAggregators);
+    if (hllSketchHolderObject == null) {
+      return new double[] {0.0D, 0.0D, 0.0D};
+    }
+    final HllSketchHolder sketch = HllSketchHolder.fromObj(hllSketchHolderObject);
     return new double[] {sketch.getEstimate(), sketch.getLowerBound(numStdDevs), sketch.getUpperBound(numStdDevs)};
   }


@@ -83,7 +83,11 @@ public class HllSketchToStringPostAggregator implements PostAggregator
   @Override
   public String compute(final Map<String, Object> combinedAggregators)
   {
-    final HllSketch sketch = HllSketchHolder.fromObj(field.compute(combinedAggregators)).getSketch();
+    Object hllSketchHolderObject = field.compute(combinedAggregators);
+    if (hllSketchHolderObject == null) {
+      return "Null Sketch";
+    }
+    final HllSketch sketch = HllSketchHolder.fromObj(hllSketchHolderObject).getSketch();
     return sketch.toString();
   }


@@ -121,8 +121,11 @@ public class HllSketchUnionPostAggregator implements PostAggregator
   {
     final Union union = new Union(lgK);
     for (final PostAggregator field : fields) {
-      final HllSketchHolder sketch = HllSketchHolder.fromObj(field.compute(combinedAggregators));
-      union.update(sketch.getSketch());
+      Object hllSketchHolderObject = field.compute(combinedAggregators);
+      if (hllSketchHolderObject != null) {
+        final HllSketchHolder holder = HllSketchHolder.fromObj(hllSketchHolderObject);
+        union.update(holder.getSketch());
+      }
     }
     return HllSketchHolder.of(union.getResult(tgtHllType));
   }
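The loop above now simply skips fields whose compute() returns null (e.g. an empty sketch the serde collapsed to a null holder). A quick standalone check of that behavior against the DataSketches Union API; the class name and inputs are invented for the example, and the committed code unwraps HllSketchHolder where this casts directly:

import org.apache.datasketches.hll.HllSketch;
import org.apache.datasketches.hll.Union;

public class UnionSkipNullDemo
{
  public static void main(String[] args)
  {
    Union union = new Union(12); // lgMaxK = 12, standing in for the post-aggregator's lgK
    HllSketch present = new HllSketch(12);
    present.update("a");

    // One field computed to a sketch, the other to null.
    Object[] fieldResults = {present, null};
    for (Object o : fieldResults) {
      if (o != null) {
        union.update((HllSketch) o);
      }
    }
    // Should print an estimate of ~1.0, with no NPE on the null field.
    System.out.println(union.getResult().getEstimate());
  }
}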


@@ -82,7 +82,7 @@ public class HllPostAggExprMacros
     final Object valObj = eval.value();
     if (valObj == null) {
-      return ExprEval.of(null);
+      return ExprEval.of(0.0D);
     }

     HllSketchHolder h = HllSketchHolder.fromObj(valObj);
     double estimate = h.getEstimate();


@@ -82,7 +82,7 @@ public class HllSketchAggregatorFactoryTest
   @Test
   public void testFinalizeComputationNull()
   {
-    Assert.assertNull(target.finalizeComputation(null));
+    Assert.assertEquals(0.0D, target.finalizeComputation(null));
   }

   @Test


@@ -23,6 +23,8 @@ import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import org.apache.datasketches.hll.HllSketch;
+import org.apache.druid.data.input.MapBasedRow;
 import org.apache.druid.java.util.common.Intervals;
 import org.apache.druid.java.util.common.StringEncoding;
 import org.apache.druid.java.util.common.granularity.Granularities;
@@ -36,6 +38,8 @@ import org.apache.druid.query.groupby.GroupByQuery;
 import org.apache.druid.query.groupby.GroupByQueryConfig;
 import org.apache.druid.query.groupby.GroupByQueryRunnerTest;
 import org.apache.druid.query.groupby.ResultRow;
+import org.apache.druid.query.groupby.epinephelinae.GroupByTestColumnSelectorFactory;
+import org.apache.druid.query.groupby.epinephelinae.GrouperTestUtil;
 import org.apache.druid.query.timeseries.TimeseriesResultValue;
 import org.apache.druid.testing.InitializedNullHandlingTest;
 import org.junit.Assert;
@@ -417,6 +421,22 @@ public class HllSketchAggregatorTest extends InitializedNullHandlingTest
     Assert.assertEquals(expectedSummary, ((HllSketchHolder) row.get(4)).getSketch().toString());
   }

+  @Test
+  public void testRelocation()
+  {
+    final GroupByTestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory();
+    HllSketchHolder sketchHolder = new HllSketchHolder(null, new HllSketch());
+    sketchHolder.getSketch().update(1);
+    columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.of("sketch", sketchHolder)));
+
+    HllSketchHolder[] holders = groupByHelper.runRelocateVerificationTest(
+        new HllSketchMergeAggregatorFactory("sketch", "sketch", null, null, null, true, true),
+        columnSelectorFactory,
+        HllSketchHolder.class
+    );
+    Assert.assertEquals(holders[0].getEstimate(), holders[1].getEstimate(), 0);
+  }
+
   private static String buildParserJson(List<String> dimensions, List<String> columns)
   {
     Map<String, Object> timestampSpec = ImmutableMap.of(


@@ -26,6 +26,7 @@ import org.apache.druid.msq.indexing.MSQSpec;
 import org.apache.druid.msq.indexing.MSQTuningConfig;
 import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination;
 import org.apache.druid.msq.test.MSQTestBase;
+import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
 import org.apache.druid.query.aggregation.datasketches.hll.HllSketchBuildAggregatorFactory;
 import org.apache.druid.query.dimension.DefaultDimensionSpec;
 import org.apache.druid.query.groupby.GroupByQuery;
@@ -95,4 +96,49 @@ public class MSQDataSketchesTest extends MSQTestBase
     )
     .verifyResults();
   }
+
+  @Test
+  public void testEmptyHllSketch()
+  {
+    RowSignature resultSignature =
+        RowSignature.builder()
+                    .add("c", ColumnType.LONG)
+                    .build();
+
+    GroupByQuery query =
+        GroupByQuery.builder()
+                    .setDataSource(CalciteTests.DATASOURCE1)
+                    .setInterval(querySegmentSpec(Filtration.eternity()))
+                    .setGranularity(Granularities.ALL)
+                    .setAggregatorSpecs(
+                        aggregators(
+                            new FilteredAggregatorFactory(
+                                new HllSketchBuildAggregatorFactory("a0", "dim2", 12, "HLL_4", null, true, true),
+                                equality("dim1", "nonexistent", ColumnType.STRING),
+                                "a0"
+                            )
+                        )
+                    )
+                    .setContext(DEFAULT_MSQ_CONTEXT)
+                    .build();
+
+    testSelectQuery()
+        .setSql("SELECT APPROX_COUNT_DISTINCT_DS_HLL(dim2) FILTER(WHERE dim1 = 'nonexistent') AS c FROM druid.foo")
+        .setExpectedMSQSpec(MSQSpec.builder()
+                                   .query(query)
+                                   .columnMappings(new ColumnMappings(ImmutableList.of(
+                                       new ColumnMapping("a0", "c"))
+                                   ))
+                                   .tuningConfig(MSQTuningConfig.defaultConfig())
+                                   .destination(TaskReportMSQDestination.INSTANCE)
+                                   .build())
+        .setQueryContext(DEFAULT_MSQ_CONTEXT)
+        .setExpectedRowSignature(resultSignature)
+        .setExpectedResultRows(
+            ImmutableList.of(
+                new Object[]{0L}
+            )
+        )
+        .verifyResults();
+  }
 }
}