Add octet streaming of sketches in MSQ (#16269)

There are a few issues with using Jackson serialization when sending datasketches between the controller and workers in MSQ. Jackson holds multiple copies of the sketch in memory while serializing, which caused memory blowups.

This PR aims to resolve this by serializing and deserializing the sketch payload with a custom binary format instead of Jackson.

The PR adds a new query parameter, "sketchEncoding", used during communication between the controller and workers while fetching sketches:

    If the value of this parameter is OCTET_STREAM, the sketch is returned as a binary encoding, written by ClusterByStatisticsSnapshotSerde.
    Otherwise, the sketch is encoded by Jackson as JSON, as before.
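As an illustration, the controller-side client builds the fetch path with the new parameter as in the following minimal sketch, which mirrors the BaseWorkerClientImpl change in this diff (stageId is assumed to be in scope):

// Build the keyStatistics path with the new sketchEncoding parameter.
// OCTET_STREAM requests the binary encoding; other values fall back to JSON.
String path = StringUtils.format(
"/keyStatistics/%s/%d?sketchEncoding=%s",
StringUtils.urlEncode(stageId.getQueryId()),
stageId.getStageNumber(),
WorkerChatHandler.SketchEncoding.OCTET_STREAM
);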
Adarsh Sanjeev 2024-05-28 18:12:38 +05:30 committed by GitHub
parent 9d77ef04f4
commit 21f725f33e
21 changed files with 1171 additions and 37 deletions

View File

@@ -211,6 +211,12 @@
<version>${project.parent.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.druid.extensions</groupId>
<artifactId>druid-multi-stage-query</artifactId>
<version>${project.parent.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<properties>

View File

@@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.benchmark;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.frame.key.KeyTestUtils;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollectorImpl;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.serde.ClusterByStatisticsSnapshotSerde;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.TimeUnit;
import java.util.stream.LongStream;
@State(Scope.Benchmark)
@Fork(value = 1)
@Warmup(iterations = 3)
@Measurement(iterations = 5)
public class MsqSketchesBenchmark extends InitializedNullHandlingTest
{
private static final int MAX_BYTES = 1_000_000_000;
private static final int MAX_BUCKETS = 10_000;
private static final RowSignature SIGNATURE = RowSignature.builder()
.add("x", ColumnType.LONG)
.add("y", ColumnType.LONG)
.add("z", ColumnType.STRING)
.build();
private static final ClusterBy CLUSTER_BY_XYZ_BUCKET_BY_X = new ClusterBy(
ImmutableList.of(
new KeyColumn("x", KeyOrder.ASCENDING),
new KeyColumn("y", KeyOrder.ASCENDING),
new KeyColumn("z", KeyOrder.ASCENDING)
),
1
);
@Param({"1", "1000"})
private long numBuckets;
@Param({"100000", "1000000"})
private long numRows;
@Param({"true", "false"})
private boolean aggregate;
private ObjectMapper jsonMapper;
private ClusterByStatisticsSnapshot snapshot;
@Setup(Level.Trial)
public void setup()
{
jsonMapper = TestHelper.makeJsonMapper();
final Iterable<RowKey> keys = () ->
LongStream.range(0, numRows)
.mapToObj(n -> createKey(numBuckets, n))
.iterator();
ClusterByStatisticsCollectorImpl collector = makeCollector(aggregate);
keys.forEach(k -> collector.add(k, 1));
snapshot = collector.snapshot();
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
public void benchmarkJacksonSketch(Blackhole blackhole) throws IOException
{
final byte[] serializedSnapshot = jsonMapper.writeValueAsBytes(snapshot);
final ClusterByStatisticsSnapshot deserializedSnapshot = jsonMapper.readValue(
serializedSnapshot,
ClusterByStatisticsSnapshot.class
);
blackhole.consume(deserializedSnapshot);
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
public void benchmarkOctetSketch(Blackhole blackhole) throws IOException
{
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ClusterByStatisticsSnapshotSerde.serialize(byteArrayOutputStream, snapshot);
final ByteBuffer serializedSnapshot = ByteBuffer.wrap(byteArrayOutputStream.toByteArray());
final ClusterByStatisticsSnapshot deserializedSnapshot = ClusterByStatisticsSnapshotSerde.deserialize(serializedSnapshot);
blackhole.consume(deserializedSnapshot);
}
private ClusterByStatisticsCollectorImpl makeCollector(final boolean aggregate)
{
return (ClusterByStatisticsCollectorImpl) ClusterByStatisticsCollectorImpl.create(MsqSketchesBenchmark.CLUSTER_BY_XYZ_BUCKET_BY_X, SIGNATURE, MAX_BYTES, MAX_BUCKETS, aggregate, false);
}
private static RowKey createKey(final long numBuckets, final long keyNo)
{
final Object[] key = new Object[3];
key[0] = keyNo % numBuckets;
key[1] = keyNo % 5;
key[2] = StringUtils.repeat("*", 67);
return KeyTestUtils.createKey(KeyTestUtils.createKeySignature(MsqSketchesBenchmark.CLUSTER_BY_XYZ_BUCKET_BY_X.getColumns(), SIGNATURE), key);
}
}

View File

@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.indexing.client;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.druid.java.util.http.client.response.BytesFullResponseHolder;
import org.apache.druid.java.util.http.client.response.ClientResponse;
import org.apache.druid.java.util.http.client.response.HttpResponseHandler;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.serde.ClusterByStatisticsSnapshotSerde;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.handler.codec.http.HttpChunk;
import org.jboss.netty.handler.codec.http.HttpResponse;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType;
import java.nio.ByteBuffer;
import java.util.function.Function;
public class SketchResponseHandler implements HttpResponseHandler<BytesFullResponseHolder, ClusterByStatisticsSnapshot>
{
private final ObjectMapper jsonMapper;
private Function<BytesFullResponseHolder, ClusterByStatisticsSnapshot> deserializerFunction;
public SketchResponseHandler(ObjectMapper jsonMapper)
{
this.jsonMapper = jsonMapper;
}
@Override
public ClientResponse<BytesFullResponseHolder> handleResponse(HttpResponse response, HttpResponseHandler.TrafficCop trafficCop)
{
final BytesFullResponseHolder holder = new BytesFullResponseHolder(response);
final String contentType = response.headers().get(HttpHeaders.CONTENT_TYPE);
if (MediaType.APPLICATION_OCTET_STREAM.equals(contentType)) {
deserializerFunction = responseHolder -> ClusterByStatisticsSnapshotSerde.deserialize(ByteBuffer.wrap(responseHolder.getContent()));
} else {
deserializerFunction = responseHolder -> responseHolder.deserialize(jsonMapper, new TypeReference<ClusterByStatisticsSnapshot>()
{
});
}
holder.addChunk(getContentBytes(response.getContent()));
return ClientResponse.unfinished(holder);
}
@Override
public ClientResponse<BytesFullResponseHolder> handleChunk(
ClientResponse<BytesFullResponseHolder> response,
HttpChunk chunk,
long chunkNum
)
{
BytesFullResponseHolder holder = response.getObj();
if (holder == null) {
return ClientResponse.finished(null);
}
holder.addChunk(getContentBytes(chunk.getContent()));
return response;
}
@Override
public ClientResponse<ClusterByStatisticsSnapshot> done(ClientResponse<BytesFullResponseHolder> response)
{
return ClientResponse.finished(deserializerFunction.apply(response.getObj()));
}
@Override
public void exceptionCaught(ClientResponse<BytesFullResponseHolder> clientResponse, Throwable e)
{
}
private byte[] getContentBytes(ChannelBuffer content)
{
byte[] contentBytes = new byte[content.readableBytes()];
content.readBytes(contentBytes);
return contentBytes;
}
}
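The worker client plugs this handler directly into its async HTTP call, as in this minimal usage sketch (names taken from the BaseWorkerClientImpl change further down; workerId, path, and objectMapper are assumed to be in scope):

// The handler picks a deserializer based on the Content-Type header, so the caller
// receives a ClusterByStatisticsSnapshot regardless of the encoding the worker chose.
ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture = getClient(workerId).asyncRequest(
new RequestBuilder(HttpMethod.POST, path),
new SketchResponseHandler(objectMapper)
);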

View File

@@ -32,11 +32,13 @@ import org.apache.druid.msq.indexing.MSQWorkerTask;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.serde.ClusterByStatisticsSnapshotSerde;
import org.apache.druid.segment.realtime.firehose.ChatHandler;
import org.apache.druid.segment.realtime.firehose.ChatHandlers;
import org.apache.druid.server.security.Action;
import org.apache.druid.utils.CloseableUtils;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.Consumes;
import javax.ws.rs.GET;
@@ -185,11 +187,12 @@ public class WorkerChatHandler implements ChatHandler
@POST
@Path("/keyStatistics/{queryId}/{stageNumber}")
@Produces(MediaType.APPLICATION_JSON)
@Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_OCTET_STREAM})
@Consumes(MediaType.APPLICATION_JSON)
public Response httpFetchKeyStatistics(
@PathParam("queryId") final String queryId,
@PathParam("stageNumber") final int stageNumber,
@QueryParam("sketchEncoding") @Nullable final SketchEncoding sketchEncoding,
@Context final HttpServletRequest req
)
{
@@ -198,9 +201,17 @@ public class WorkerChatHandler implements ChatHandler
StageId stageId = new StageId(queryId, stageNumber);
try {
clusterByStatisticsSnapshot = worker.fetchStatisticsSnapshot(stageId);
return Response.status(Response.Status.ACCEPTED)
.entity(clusterByStatisticsSnapshot)
.build();
if (SketchEncoding.OCTET_STREAM.equals(sketchEncoding)) {
return Response.status(Response.Status.ACCEPTED)
.type(MediaType.APPLICATION_OCTET_STREAM)
.entity((StreamingOutput) output -> ClusterByStatisticsSnapshotSerde.serialize(output, clusterByStatisticsSnapshot))
.build();
} else {
return Response.status(Response.Status.ACCEPTED)
.type(MediaType.APPLICATION_JSON)
.entity(clusterByStatisticsSnapshot)
.build();
}
}
catch (Exception e) {
String errorMessage = StringUtils.format(
@@ -217,12 +228,13 @@ public class WorkerChatHandler implements ChatHandler
@POST
@Path("/keyStatisticsForTimeChunk/{queryId}/{stageNumber}/{timeChunk}")
@Produces(MediaType.APPLICATION_JSON)
@Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_OCTET_STREAM})
@Consumes(MediaType.APPLICATION_JSON)
public Response httpFetchKeyStatisticsWithSnapshot(
@PathParam("queryId") final String queryId,
@PathParam("stageNumber") final int stageNumber,
@PathParam("timeChunk") final long timeChunk,
@QueryParam("sketchEncoding") @Nullable final SketchEncoding sketchEncoding,
@Context final HttpServletRequest req
)
{
@@ -231,9 +243,17 @@ public class WorkerChatHandler implements ChatHandler
StageId stageId = new StageId(queryId, stageNumber);
try {
snapshotForTimeChunk = worker.fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk);
return Response.status(Response.Status.ACCEPTED)
.entity(snapshotForTimeChunk)
.build();
if (SketchEncoding.OCTET_STREAM.equals(sketchEncoding)) {
return Response.status(Response.Status.ACCEPTED)
.type(MediaType.APPLICATION_OCTET_STREAM)
.entity((StreamingOutput) output -> ClusterByStatisticsSnapshotSerde.serialize(output, snapshotForTimeChunk))
.build();
} else {
return Response.status(Response.Status.ACCEPTED)
.type(MediaType.APPLICATION_JSON)
.entity(snapshotForTimeChunk)
.build();
}
}
catch (Exception e) {
String errorMessage = StringUtils.format(
@@ -289,4 +309,20 @@ public class WorkerChatHandler implements ChatHandler
ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper());
return Response.status(Response.Status.OK).entity(worker.getCounters()).build();
}
/**
* Determines the encoding of key collectors returned by {@link #httpFetchKeyStatistics} and
* {@link #httpFetchKeyStatisticsWithSnapshot}.
*/
public enum SketchEncoding
{
/**
* The key collector is encoded as a byte stream with {@link ClusterByStatisticsSnapshotSerde}.
*/
OCTET_STREAM,
/**
* The key collector is encoded as JSON.
*/
JSON
}
}

View File

@@ -37,6 +37,8 @@ import org.apache.druid.java.util.http.client.response.BytesFullResponseHandler;
import org.apache.druid.java.util.http.client.response.BytesFullResponseHolder;
import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.exec.WorkerClient;
import org.apache.druid.msq.indexing.client.SketchResponseHandler;
import org.apache.druid.msq.indexing.client.WorkerChatHandler;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
@@ -92,17 +94,15 @@ public abstract class BaseWorkerClientImpl implements WorkerClient
)
{
String path = StringUtils.format(
"/keyStatistics/%s/%d",
"/keyStatistics/%s/%d?sketchEncoding=%s",
StringUtils.urlEncode(stageId.getQueryId()),
stageId.getStageNumber()
stageId.getStageNumber(),
WorkerChatHandler.SketchEncoding.OCTET_STREAM
);
return FutureUtils.transform(
getClient(workerId).asyncRequest(
new RequestBuilder(HttpMethod.POST, path).header(HttpHeaders.ACCEPT, contentType),
new BytesFullResponseHandler()
),
holder -> deserialize(holder, new TypeReference<ClusterByStatisticsSnapshot>() {})
return getClient(workerId).asyncRequest(
new RequestBuilder(HttpMethod.POST, path),
new SketchResponseHandler(objectMapper)
);
}
@@ -114,18 +114,16 @@ public abstract class BaseWorkerClientImpl implements WorkerClient
)
{
String path = StringUtils.format(
"/keyStatisticsForTimeChunk/%s/%d/%d",
"/keyStatisticsForTimeChunk/%s/%d/%d?sketchEncoding=%s",
StringUtils.urlEncode(stageId.getQueryId()),
stageId.getStageNumber(),
timeChunk
timeChunk,
WorkerChatHandler.SketchEncoding.OCTET_STREAM
);
return FutureUtils.transform(
getClient(workerId).asyncRequest(
new RequestBuilder(HttpMethod.POST, path).header(HttpHeaders.ACCEPT, contentType),
new BytesFullResponseHandler()
),
holder -> deserialize(holder, new TypeReference<ClusterByStatisticsSnapshot>() {})
return getClient(workerId).asyncRequest(
new RequestBuilder(HttpMethod.POST, path),
new SketchResponseHandler(objectMapper)
);
}
@@ -150,7 +148,7 @@ public abstract class BaseWorkerClientImpl implements WorkerClient
}
/**
* Client-side method for {@link org.apache.druid.msq.indexing.client.WorkerChatHandler#httpPostCleanupStage}.
* Client-side method for {@link WorkerChatHandler#httpPostCleanupStage}.
*/
@Override
public ListenableFuture<Void> postCleanupStage(

View File

@@ -20,6 +20,7 @@
package org.apache.druid.msq.statistics;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
@@ -39,7 +40,7 @@ public class ClusterByStatisticsSnapshot
private final Set<Integer> hasMultipleValues;
@JsonCreator
ClusterByStatisticsSnapshot(
public ClusterByStatisticsSnapshot(
@JsonProperty("buckets") final Map<Long, Bucket> buckets,
@JsonProperty("hasMultipleValues") @Nullable final Set<Integer> hasMultipleValues
)
@@ -54,7 +55,7 @@ public class ClusterByStatisticsSnapshot
}
@JsonProperty("buckets")
Map<Long, Bucket> getBuckets()
public Map<Long, Bucket> getBuckets()
{
return buckets;
}
@@ -70,7 +71,7 @@ public class ClusterByStatisticsSnapshot
@JsonProperty("hasMultipleValues")
@JsonInclude(JsonInclude.Include.NON_EMPTY)
Set<Integer> getHasMultipleValues()
public Set<Integer> getHasMultipleValues()
{
return hasMultipleValues;
}
@@ -103,14 +104,14 @@ public class ClusterByStatisticsSnapshot
return Objects.hash(buckets, hasMultipleValues);
}
static class Bucket
public static class Bucket
{
private final RowKey bucketKey;
private final double bytesRetained;
private final KeyCollectorSnapshot keyCollectorSnapshot;
@JsonCreator
Bucket(
public Bucket(
@JsonProperty("bucketKey") RowKey bucketKey,
@JsonProperty("data") KeyCollectorSnapshot keyCollectorSnapshot,
@JsonProperty("bytesRetained") double bytesRetained
@@ -127,6 +128,12 @@ public class ClusterByStatisticsSnapshot
return bucketKey;
}
@JsonIgnore
public double getBytesRetained()
{
return bytesRetained;
}
@JsonProperty("data")
public KeyCollectorSnapshot getKeyCollectorSnapshot()
{

View File

@@ -25,6 +25,8 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.statistics.serde.DelegateOrMinSerializer;
import org.apache.druid.msq.statistics.serde.KeyCollectorSnapshotSerializer;
import javax.annotation.Nullable;
import java.util.Objects;
@@ -87,4 +89,19 @@ public class DelegateOrMinKeyCollectorSnapshot<T extends KeyCollectorSnapshot> i
{
return Objects.hash(snapshot, minKey);
}
@Override
public KeyCollectorSnapshotSerializer getSerializer()
{
return new DelegateOrMinSerializer();
}
@Override
public String toString()
{
return "DelegateOrMinKeyCollectorSnapshot{" +
"snapshot=" + snapshot +
", minKey=" + minKey +
'}';
}
}

View File

@@ -26,6 +26,8 @@ import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.base.Preconditions;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.msq.statistics.serde.DistinctSnapshotSerializer;
import org.apache.druid.msq.statistics.serde.KeyCollectorSnapshotSerializer;
import java.util.HashMap;
import java.util.List;
@@ -40,7 +42,7 @@ public class DistinctKeySnapshot implements KeyCollectorSnapshot
private final int spaceReductionFactor;
@JsonCreator
DistinctKeySnapshot(
public DistinctKeySnapshot(
@JsonProperty("keys") final List<SerializablePair<RowKey, Long>> keys,
@JsonProperty("spaceReductionFactor") final int spaceReductionFactor
)
@@ -94,4 +96,10 @@ public class DistinctKeySnapshot implements KeyCollectorSnapshot
// Not expected to be called in production, so it's OK that this calls getKeysAsMap() each time.
return Objects.hash(getKeysAsMap(), spaceReductionFactor);
}
@Override
public KeyCollectorSnapshotSerializer getSerializer()
{
return new DistinctSnapshotSerializer();
}
}

View File

@@ -21,6 +21,7 @@ package org.apache.druid.msq.statistics;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.apache.druid.msq.statistics.serde.KeyCollectorSnapshotSerializer;
/**
* Marker interface for deserialization.
@@ -33,4 +34,5 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo;
})
public interface KeyCollectorSnapshot
{
KeyCollectorSnapshotSerializer getSerializer();
}

View File

@@ -22,6 +22,8 @@ package org.apache.druid.msq.statistics;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import org.apache.druid.msq.statistics.serde.KeyCollectorSnapshotSerializer;
import org.apache.druid.msq.statistics.serde.QuantilesSnapshotSerializer;
import java.util.Objects;
@@ -71,4 +73,19 @@ public class QuantilesSketchKeyCollectorSnapshot implements KeyCollectorSnapshot
{
return Objects.hash(encodedSketch, averageKeyLength);
}
@Override
public KeyCollectorSnapshotSerializer getSerializer()
{
return new QuantilesSnapshotSerializer();
}
@Override
public String toString()
{
return "QuantilesSketchKeyCollectorSnapshot{" +
"encodedSketch='" + encodedSketch + '\'' +
", averageKeyLength=" + averageKeyLength +
'}';
}
}

View File

@@ -0,0 +1,194 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.statistics.serde;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.KeyCollectorSnapshot;
import javax.validation.constraints.NotNull;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/**
* Handles the serialization and deserialization of {@link ClusterByStatisticsSnapshot} to and from a byte array.
*/
public class ClusterByStatisticsSnapshotSerde
{
private static final byte EMPTY_HEADER = 0x0;
/**
* Serializes the {@link ClusterByStatisticsSnapshot} and writes it to the {@link OutputStream}.
* <br>
* Format:
* - 1 byte: Header byte, used for holding the version
* - 4 bytes: Number of buckets, N
* - 4 bytes: Number of multivalue entries in {@link ClusterByStatisticsSnapshot#getHasMultipleValues()}
* - 4 bytes per multivalue entry: the entries themselves, as integers
* - N buckets, each written as an 8-byte timeChunk key followed by the bucket serialized by {@link #serializeBucket(OutputStream, ClusterByStatisticsSnapshot.Bucket)}
*/
public static void serialize(OutputStream outputStream, @NotNull ClusterByStatisticsSnapshot snapshot) throws IOException
{
final Map<Long, ClusterByStatisticsSnapshot.Bucket> buckets = snapshot.getBuckets();
final Set<Integer> multipleValueBuckets = snapshot.getHasMultipleValues();
// Write a header byte, to be used to contain any metadata in the future.
outputStream.write(EMPTY_HEADER);
writeIntToStream(outputStream, buckets.size());
// 4 bytes for the set size, plus 4 bytes per entry, matching what deserialize() reads back.
ByteBuffer multivalueBuffer = ByteBuffer.allocate(Integer.BYTES + multipleValueBuckets.size() * Integer.BYTES)
.putInt(multipleValueBuckets.size());
multipleValueBuckets.forEach(multivalueBuffer::putInt);
outputStream.write(multivalueBuffer.array());
// Serialize the buckets
for (Map.Entry<Long, ClusterByStatisticsSnapshot.Bucket> entry : buckets.entrySet()) {
writeLongToStream(outputStream, entry.getKey());
serializeBucket(outputStream, entry.getValue());
}
}
private static final int HEADER_OFFSET = 0;
private static final int BUCKET_COUNT_OFFSET = HEADER_OFFSET + Byte.BYTES;
private static final int MV_SET_SIZE_OFFSET = BUCKET_COUNT_OFFSET + Integer.BYTES;
private static final int MV_VALUES_OFFSET = MV_SET_SIZE_OFFSET + Integer.BYTES;
private static final int TIMECHUNK_OFFSET = 0;
private static final int BUCKET_SIZE_OFFSET = TIMECHUNK_OFFSET + Long.BYTES;
private static final int BUCKET_OFFSET = BUCKET_SIZE_OFFSET + Integer.BYTES;
public static ClusterByStatisticsSnapshot deserialize(ByteBuffer byteBuffer)
{
int position = byteBuffer.position();
final int bucketCount = byteBuffer.getInt(position + BUCKET_COUNT_OFFSET);
final int mvSetSize = byteBuffer.getInt(position + MV_SET_SIZE_OFFSET);
final Set<Integer> hasMultiValues = new HashSet<>();
for (int offset = position + MV_VALUES_OFFSET; offset < position + MV_VALUES_OFFSET + mvSetSize * Integer.BYTES; offset += Integer.BYTES) {
hasMultiValues.add(byteBuffer.getInt(offset));
}
final Map<Long, ClusterByStatisticsSnapshot.Bucket> buckets = new HashMap<>();
// Move the buffer position
int nextBucket = position + MV_VALUES_OFFSET + Integer.BYTES * mvSetSize;
for (int bucketNo = 0; bucketNo < bucketCount; bucketNo++) {
position = byteBuffer.position(nextBucket).position();
final long timeChunk = byteBuffer.getLong(position + TIMECHUNK_OFFSET);
final int snapshotSize = byteBuffer.getInt(position + BUCKET_SIZE_OFFSET);
final ByteBuffer duplicate = (ByteBuffer) byteBuffer.duplicate()
.order(byteBuffer.order())
.position(position + BUCKET_OFFSET)
.limit(position + BUCKET_OFFSET + snapshotSize);
ClusterByStatisticsSnapshot.Bucket bucket = deserializeBucket(duplicate);
buckets.put(timeChunk, bucket);
nextBucket = position + BUCKET_OFFSET + snapshotSize;
}
return new ClusterByStatisticsSnapshot(buckets, hasMultiValues);
}
/**
* Format:
* - 4 bytes: total length of the serialized bucket (excluding this field)
* - 8 bytes: bytesRetained
* - 4 bytes: keyArray length
* - 4 bytes: snapshot length
* - keyArray length bytes: serialized key array
* - snapshot length bytes: serialized snapshot
*/
static void serializeBucket(OutputStream outputStream, ClusterByStatisticsSnapshot.Bucket bucket) throws IOException
{
final byte[] bucketKeyArray = bucket.getBucketKey().array();
final double bytesRetained = bucket.getBytesRetained();
final KeyCollectorSnapshot snapshot = bucket.getKeyCollectorSnapshot();
final byte[] serializedSnapshot = snapshot.getSerializer().serialize(snapshot);
final int length = Double.BYTES // Bytes retained
+ 2 * Integer.BYTES // keyArray length and snapshot length
+ bucketKeyArray.length // serialized key array
+ serializedSnapshot.length; // serialized snapshot
outputStream.write(
ByteBuffer.allocate(Integer.BYTES + length) // Additionally, store length of the serialized array.
.putInt(length)
.putDouble(bytesRetained)
.putInt(bucketKeyArray.length)
.putInt(serializedSnapshot.length)
.put(bucketKeyArray)
.put(serializedSnapshot)
.array()
);
}
private static final int BYTES_RETAINED_OFFSET = 0;
private static final int KEY_LENGTH_OFFSET = BYTES_RETAINED_OFFSET + Double.BYTES;
private static final int SNAPSHOT_LENGTH_OFFSET = KEY_LENGTH_OFFSET + Integer.BYTES;
private static final int KEY_OFFSET = SNAPSHOT_LENGTH_OFFSET + Integer.BYTES;
static ClusterByStatisticsSnapshot.Bucket deserializeBucket(ByteBuffer byteBuffer)
{
int position = byteBuffer.position();
final double bytesRetained = byteBuffer.getDouble(position + BYTES_RETAINED_OFFSET);
final int keyLength = byteBuffer.getInt(position + KEY_LENGTH_OFFSET);
final int snapshotLength = byteBuffer.getInt(position + SNAPSHOT_LENGTH_OFFSET);
final ByteBuffer keyBuffer = (ByteBuffer) byteBuffer.duplicate()
.order(byteBuffer.order())
.position(position + KEY_OFFSET)
.limit(position + KEY_OFFSET + keyLength);
final byte[] byteKey = new byte[keyLength];
keyBuffer.get(byteKey);
final int snapshotOffset = position + KEY_OFFSET + keyLength;
final ByteBuffer snapshotBuffer = (ByteBuffer) byteBuffer.duplicate()
.order(byteBuffer.order())
.position(snapshotOffset)
.limit(snapshotOffset + snapshotLength);
return new ClusterByStatisticsSnapshot.Bucket(
RowKey.wrap(byteKey),
KeyCollectorSnapshotDeserializer.deserialize(snapshotBuffer),
bytesRetained
);
}
private static void writeIntToStream(OutputStream outputStream, int integerToWrite) throws IOException
{
outputStream.write(ByteBuffer.allocate(Integer.BYTES).putInt(integerToWrite).array());
}
private static void writeLongToStream(OutputStream outputStream, long longToWrite) throws IOException
{
outputStream.write(ByteBuffer.allocate(Long.BYTES).putLong(longToWrite).array());
}
}
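A round trip through this serde looks like the following minimal sketch, assuming a snapshot in scope (the tests further down use the same pattern):

// Serialize the snapshot to a byte stream, then deserialize it back from a ByteBuffer.
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ClusterByStatisticsSnapshotSerde.serialize(byteArrayOutputStream, snapshot);
ClusterByStatisticsSnapshot roundTripped =
ClusterByStatisticsSnapshotSerde.deserialize(ByteBuffer.wrap(byteArrayOutputStream.toByteArray()));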

View File

@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.statistics.serde;
import com.google.common.base.Preconditions;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.msq.statistics.DelegateOrMinKeyCollectorSnapshot;
import org.apache.druid.msq.statistics.KeyCollectorSnapshot;
import java.nio.ByteBuffer;
public class DelegateOrMinSerializer extends KeyCollectorSnapshotSerializer
{
public static final byte TYPE = (byte) 0;
public static final byte ROWKEY_SNAPSHOT = (byte) 0;
public static final byte SKETCH_SNAPSHOT = (byte) 1;
@Override
protected byte getType()
{
return TYPE;
}
@Override
protected byte[] serializeKeyCollector(KeyCollectorSnapshot snapshot)
{
final DelegateOrMinKeyCollectorSnapshot<?> delegateOrMinKeySnapshot = (DelegateOrMinKeyCollectorSnapshot<?>) snapshot;
final RowKey minKey = delegateOrMinKeySnapshot.getMinKey();
if (minKey != null) {
// The snapshot contains a minKey, and the delegate snapshot is null.
return ByteBuffer.allocate(1 + Integer.BYTES + minKey.array().length)
.put(ROWKEY_SNAPSHOT)
.putInt(minKey.array().length)
.put(minKey.array())
.array();
} else {
// The snapshot contains a delegate snapshot, and the minKey is null.
final KeyCollectorSnapshot delegateSnapshot = Preconditions.checkNotNull((DelegateOrMinKeyCollectorSnapshot<?>) snapshot).getSnapshot();
byte[] serializedSnapshot = delegateSnapshot.getSerializer().serialize(delegateSnapshot);
return ByteBuffer.allocate(1 + Integer.BYTES + serializedSnapshot.length)
.put(SKETCH_SNAPSHOT)
.putInt(serializedSnapshot.length)
.put(serializedSnapshot)
.array();
}
}
}

View File

@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.statistics.serde;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.msq.statistics.DistinctKeySnapshot;
import org.apache.druid.msq.statistics.KeyCollectorSnapshot;
import java.nio.ByteBuffer;
public class DistinctSnapshotSerializer extends KeyCollectorSnapshotSerializer
{
public static final byte TYPE = (byte) 1;
@Override
protected byte getType()
{
return TYPE;
}
@Override
protected byte[] serializeKeyCollector(KeyCollectorSnapshot collectorSnapshot)
{
final DistinctKeySnapshot snapshot = (DistinctKeySnapshot) collectorSnapshot;
int length = 2 * Integer.BYTES;
for (SerializablePair<RowKey, Long> key : snapshot.getKeys()) {
length += key.lhs.array().length + Integer.BYTES + Long.BYTES;
}
final ByteBuffer buffer = ByteBuffer.allocate(length)
.putInt(snapshot.getSpaceReductionFactor())
.putInt(snapshot.getKeys().size());
for (SerializablePair<RowKey, Long> key : snapshot.getKeys()) {
byte[] array = key.lhs.array();
buffer.putLong(key.rhs)
.putInt(array.length)
.put(array);
}
return buffer.array();
}
}

View File

@@ -0,0 +1,143 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.statistics.serde;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.msq.statistics.DelegateOrMinKeyCollectorSnapshot;
import org.apache.druid.msq.statistics.DistinctKeySnapshot;
import org.apache.druid.msq.statistics.KeyCollectorSnapshot;
import org.apache.druid.msq.statistics.QuantilesSketchKeyCollectorSnapshot;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
public class KeyCollectorSnapshotDeserializer
{
private static final ImmutableMap<Byte, Function<ByteBuffer, KeyCollectorSnapshot>> DESERIALIZERS =
ImmutableMap.<Byte, Function<ByteBuffer, KeyCollectorSnapshot>>builder()
.put(
QuantilesSnapshotSerializer.TYPE,
QuantilesSnapshotDeserializer::deserialize
).put(
DelegateOrMinSerializer.TYPE,
DelegateOrMinDeserializer::deserialize
).put(
DistinctSnapshotSerializer.TYPE,
DistinctSnapshotDeserializer::deserialize
)
.build();
public static KeyCollectorSnapshot deserialize(ByteBuffer byteBuffer)
{
int position = byteBuffer.position();
byte type = byteBuffer.get(position);
byteBuffer.position(position + 1);
return DESERIALIZERS.get(type).apply(byteBuffer);
}
static class DelegateOrMinDeserializer
{
private static final int TYPE_OFFSET = 0;
private static final int ARRAY_LENGTH_OFFSET = TYPE_OFFSET + Byte.BYTES;
private static final int ARRAY_OFFSET = ARRAY_LENGTH_OFFSET + Integer.BYTES;
public static KeyCollectorSnapshot deserialize(ByteBuffer byteBuffer)
{
int position = byteBuffer.position();
byte type = byteBuffer.get(position + TYPE_OFFSET);
int length = byteBuffer.getInt(position + ARRAY_LENGTH_OFFSET);
final ByteBuffer duplicate = (ByteBuffer) byteBuffer.duplicate()
.order(byteBuffer.order())
.position(position + ARRAY_OFFSET)
.limit(position + ARRAY_OFFSET + length);
if (type == DelegateOrMinSerializer.ROWKEY_SNAPSHOT) {
byte[] rowKey = new byte[length];
duplicate.get(rowKey, 0, length);
return new DelegateOrMinKeyCollectorSnapshot<>(null, RowKey.wrap(rowKey));
} else if (type == DelegateOrMinSerializer.SKETCH_SNAPSHOT) {
return new DelegateOrMinKeyCollectorSnapshot<>(KeyCollectorSnapshotDeserializer.deserialize(duplicate), null);
} else {
throw new UnsupportedOperationException();
}
}
}
static class DistinctSnapshotDeserializer
{
private static final int SPACE_REDUCTION_FACTOR_OFFSET = 0;
private static final int LIST_LENGTH_OFFSET = SPACE_REDUCTION_FACTOR_OFFSET + Integer.BYTES;
private static final int LIST_OFFSET = LIST_LENGTH_OFFSET + Integer.BYTES;
private static final int WEIGHT_OFFSET = 0;
private static final int KEY_LENGTH_OFFSET = WEIGHT_OFFSET + Long.BYTES;
private static final int KEY_OFFSET = KEY_LENGTH_OFFSET + Integer.BYTES;
public static KeyCollectorSnapshot deserialize(ByteBuffer byteBuffer)
{
int position = byteBuffer.position();
final int spaceReductionFactor = byteBuffer.getInt(position + SPACE_REDUCTION_FACTOR_OFFSET);
final int listLength = byteBuffer.getInt(position + LIST_LENGTH_OFFSET);
List<SerializablePair<RowKey, Long>> keys = new ArrayList<>();
position = byteBuffer.position(position + LIST_OFFSET).position();
for (int i = 0; i < listLength; i++) {
long weight = byteBuffer.getLong(position + WEIGHT_OFFSET);
int keyLength = byteBuffer.getInt(position + KEY_LENGTH_OFFSET);
final ByteBuffer duplicate = (ByteBuffer) byteBuffer.duplicate()
.order(byteBuffer.order())
.position(position + KEY_OFFSET)
.limit(position + KEY_OFFSET + keyLength);
byte[] key = new byte[keyLength];
duplicate.get(key);
keys.add(new SerializablePair<>(RowKey.wrap(key), weight));
position = byteBuffer.position(position + KEY_OFFSET + keyLength).position();
}
return new DistinctKeySnapshot(keys, spaceReductionFactor);
}
}
static class QuantilesSnapshotDeserializer
{
private static final int AVG_KEY_LENGTH_OFFSET = 0;
private static final int SKETCH_LENGTH_OFFSET = AVG_KEY_LENGTH_OFFSET + Double.BYTES;
private static final int SKETCH_OFFSET = SKETCH_LENGTH_OFFSET + Integer.BYTES;
public static QuantilesSketchKeyCollectorSnapshot deserialize(ByteBuffer byteBuffer)
{
int position = byteBuffer.position();
final double avgKeyLength = byteBuffer.getDouble(position + AVG_KEY_LENGTH_OFFSET);
final int sketchLength = byteBuffer.getInt(position + SKETCH_LENGTH_OFFSET);
final byte[] sketchBytes = new byte[sketchLength];
byteBuffer.position(position + SKETCH_OFFSET);
byteBuffer.get(sketchBytes);
final String sketch = StringUtils.encodeBase64String(sketchBytes);
return new QuantilesSketchKeyCollectorSnapshot(sketch, avgKeyLength);
}
}
}

View File

@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.statistics.serde;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.KeyCollectorSnapshot;
import java.nio.ByteBuffer;
/**
* Serializes a {@link KeyCollectorSnapshot} into a byte array.
*/
public abstract class KeyCollectorSnapshotSerializer
{
/**
* The type of sketch being serialized. The value returned must be unique across
* implementations, since it is used to select the matching deserializer.
*/
protected abstract byte getType();
/**
* Converts the key collector in the argument into a byte array representation.
*/
protected abstract byte[] serializeKeyCollector(KeyCollectorSnapshot collectorSnapshot);
public byte[] serialize(KeyCollectorSnapshot collectorSnapshot)
{
byte type = getType();
byte[] value = serializeKeyCollector(collectorSnapshot);
return ByteBuffer.allocate(Byte.BYTES + value.length)
.put(type)
.put(value)
.array();
}
}
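Each concrete serializer contributes a unique type byte, which KeyCollectorSnapshotDeserializer (above) uses to dispatch to the matching deserializer. As a sketch of the extension point, a hypothetical new snapshot type could be serialized as follows (FooSnapshot and its getCount() accessor are invented for illustration and are not part of this PR):

package org.apache.druid.msq.statistics.serde;
import org.apache.druid.msq.statistics.KeyCollectorSnapshot;
import java.nio.ByteBuffer;
// Hypothetical serializer for an invented FooSnapshot; its type byte must not collide with
// DelegateOrMinSerializer.TYPE (0), DistinctSnapshotSerializer.TYPE (1), or
// QuantilesSnapshotSerializer.TYPE (2).
public class FooSnapshotSerializer extends KeyCollectorSnapshotSerializer
{
public static final byte TYPE = (byte) 3;
@Override
protected byte getType()
{
return TYPE;
}
@Override
protected byte[] serializeKeyCollector(KeyCollectorSnapshot collectorSnapshot)
{
final FooSnapshot snapshot = (FooSnapshot) collectorSnapshot;
// The payload layout is implementation-defined; the base class prepends the type byte.
return ByteBuffer.allocate(Long.BYTES).putLong(snapshot.getCount()).array();
}
}

A matching entry would also need to be registered in the DESERIALIZERS map of KeyCollectorSnapshotDeserializer so the type byte can be decoded.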

View File

@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.statistics.serde;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.msq.statistics.KeyCollectorSnapshot;
import org.apache.druid.msq.statistics.QuantilesSketchKeyCollectorSnapshot;
import java.nio.ByteBuffer;
/**
* Format:
* - 8 bytes: {@link QuantilesSketchKeyCollectorSnapshot#getAverageKeyLength()}
* - 4 bytes: length of quantile sketch snapshot (n)
* - n bytes: the sketch snapshot
*/
public class QuantilesSnapshotSerializer extends KeyCollectorSnapshotSerializer
{
public static final byte TYPE = (byte) 2;
@Override
protected byte getType()
{
return TYPE;
}
@Override
protected byte[] serializeKeyCollector(KeyCollectorSnapshot collectorSnapshot)
{
final QuantilesSketchKeyCollectorSnapshot quantileSnapshot = (QuantilesSketchKeyCollectorSnapshot) collectorSnapshot;
double averageKeyLength = quantileSnapshot.getAverageKeyLength();
final byte[] sketch = StringUtils.decodeBase64String(quantileSnapshot.getEncodedSketch());
return ByteBuffer.allocate(Double.BYTES + Integer.BYTES + sketch.length)
.putDouble(averageKeyLength)
.putInt(sketch.length)
.put(sketch)
.array();
}
}

View File

@@ -92,7 +92,7 @@ public class WorkerChatHandlerTest
WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
Assert.assertEquals(
ClusterByStatisticsSnapshot.empty(),
chatHandler.httpFetchKeyStatistics(TEST_STAGE.getQueryId(), TEST_STAGE.getStageNumber(), req)
chatHandler.httpFetchKeyStatistics(TEST_STAGE.getQueryId(), TEST_STAGE.getStageNumber(), null, req)
.getEntity()
);
}
@@ -103,7 +103,7 @@ public class WorkerChatHandlerTest
WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
Assert.assertEquals(
Response.Status.BAD_REQUEST.getStatusCode(),
chatHandler.httpFetchKeyStatistics("123", 2, req)
chatHandler.httpFetchKeyStatistics("123", 2, null, req)
.getStatus()
);
}
@@ -114,7 +114,7 @@ public class WorkerChatHandlerTest
WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
Assert.assertEquals(
ClusterByStatisticsSnapshot.empty(),
chatHandler.httpFetchKeyStatisticsWithSnapshot(TEST_STAGE.getQueryId(), TEST_STAGE.getStageNumber(), 1, req)
chatHandler.httpFetchKeyStatisticsWithSnapshot(TEST_STAGE.getQueryId(), TEST_STAGE.getStageNumber(), 1, null, req)
.getEntity()
);
}
@@ -125,7 +125,7 @@ public class WorkerChatHandlerTest
WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
Assert.assertEquals(
Response.Status.BAD_REQUEST.getStatusCode(),
chatHandler.httpFetchKeyStatisticsWithSnapshot("123", 2, 1, req)
chatHandler.httpFetchKeyStatisticsWithSnapshot("123", 2, 1, null, req)
.getStatus()
);
}

View File

@@ -19,7 +19,6 @@
package org.apache.druid.msq.statistics;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
@@ -37,6 +36,7 @@ import org.apache.druid.indexing.common.task.batch.TooManyBucketsException;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.msq.statistics.serde.ClusterByStatisticsSnapshotSerde;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
@@ -48,7 +48,10 @@ import org.junit.Assert;
import org.junit.Test;
import org.junit.internal.matchers.ThrowableMessageMatcher;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
@@ -968,6 +971,7 @@ public class ClusterByStatisticsCollectorImplTest extends InitializedNullHandlin
final ClusterByStatisticsCollector collector
)
{
// Verify Jackson serialization
try {
final ObjectMapper jsonMapper = TestHelper.makeJsonMapper();
final ClusterByStatisticsSnapshot snapshot = collector.snapshot();
@@ -977,8 +981,15 @@ public class ClusterByStatisticsCollectorImplTest extends InitializedNullHandlin
);
Assert.assertEquals(StringUtils.format("%s: snapshot is serializable", testName), snapshot, snapshot2);
// Verify octet stream serialization
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ClusterByStatisticsSnapshotSerde.serialize(byteArrayOutputStream, snapshot);
final ClusterByStatisticsSnapshot snapshot3 = ClusterByStatisticsSnapshotSerde.deserialize(ByteBuffer.wrap(byteArrayOutputStream.toByteArray()));
Assert.assertEquals(StringUtils.format("%s: snapshot is serializable", testName), snapshot, snapshot3);
}
catch (JsonProcessingException e) {
catch (IOException e) {
throw new RuntimeException(e);
}
}

View File

@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.statistics.serde;
import com.google.common.collect.ImmutableList;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollectorImpl;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
public class KeyCollectorSnapshotSerializerTest extends InitializedNullHandlingTest
{
private static final int MAX_BYTES = 1_000_000_000;
private static final int MAX_BUCKETS = 10_000;
private static final RowSignature SIGNATURE = RowSignature.builder()
.add("x", ColumnType.LONG)
.add("y", ColumnType.LONG)
.add("z", ColumnType.STRING)
.build();
private static final ClusterBy CLUSTER_BY_XYZ_BUCKET_BY_X = new ClusterBy(
ImmutableList.of(
new KeyColumn("x", KeyOrder.ASCENDING),
new KeyColumn("y", KeyOrder.ASCENDING),
new KeyColumn("z", KeyOrder.ASCENDING)
),
1
);
@Test
public void testEmptyQuantilesSnapshot() throws IOException
{
ClusterByStatisticsCollectorImpl collector = makeCollector(false);
ClusterByStatisticsSnapshot snapshot = collector.snapshot();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ClusterByStatisticsSnapshotSerde.serialize(byteArrayOutputStream, snapshot);
final ByteBuffer serializedSnapshot = ByteBuffer.wrap(byteArrayOutputStream.toByteArray());
final ClusterByStatisticsSnapshot deserializedSnapshot = ClusterByStatisticsSnapshotSerde.deserialize(serializedSnapshot);
Assert.assertEquals(snapshot, deserializedSnapshot);
}
@Test
public void testEmptyDistinctKeySketchSnapshot() throws IOException
{
ClusterByStatisticsCollectorImpl collector = makeCollector(true);
ClusterByStatisticsSnapshot snapshot = collector.snapshot();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ClusterByStatisticsSnapshotSerde.serialize(byteArrayOutputStream, snapshot);
final ByteBuffer serializedSnapshot = ByteBuffer.wrap(byteArrayOutputStream.toByteArray());
final ClusterByStatisticsSnapshot deserializedSnapshot = ClusterByStatisticsSnapshotSerde.deserialize(serializedSnapshot);
Assert.assertEquals(snapshot, deserializedSnapshot);
}
private ClusterByStatisticsCollectorImpl makeCollector(final boolean aggregate)
{
return (ClusterByStatisticsCollectorImpl) ClusterByStatisticsCollectorImpl.create(CLUSTER_BY_XYZ_BUCKET_BY_X, SIGNATURE, MAX_BYTES, MAX_BUCKETS, aggregate, false);
}
}

View File

@@ -19,8 +19,11 @@
package org.apache.druid.java.util.http.client.response;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.jboss.netty.handler.codec.http.HttpResponse;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
@@ -56,4 +59,17 @@ public class BytesFullResponseHolder extends FullResponseHolder<byte[]>
return buf.array();
}
/**
* Deserialize the content of this {@link BytesFullResponseHolder} as JSON.
*/
public <T> T deserialize(final ObjectMapper jsonMapper, final TypeReference<T> typeReference)
{
try {
return jsonMapper.readValue(getContent(), typeReference);
}
catch (IOException e) {
throw new RuntimeException(e);
}
}
}

View File

@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.java.util.http.client.response;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.jboss.netty.handler.codec.http.DefaultHttpResponse;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.jboss.netty.handler.codec.http.HttpVersion;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import java.io.IOException;
import java.util.Objects;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.spy;
public class BytesFullResponseHolderTest
{
ObjectMapper objectMapper = spy(new DefaultObjectMapper());
@Test
public void testDeserialize() throws Exception
{
final ResponseObject payload = new ResponseObject("payload123");
final DefaultHttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK);
final BytesFullResponseHolder target = spy(new BytesFullResponseHolder(response));
target.addChunk(objectMapper.writeValueAsBytes(payload));
final ResponseObject deserialize = target.deserialize(objectMapper, new TypeReference<ResponseObject>() {
});
Assert.assertEquals(payload, deserialize);
Mockito.verify(target, Mockito.times(1)).deserialize(ArgumentMatchers.any(), ArgumentMatchers.any());
}
@Test
public void testDeserializeException() throws IOException
{
final DefaultHttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK);
final BytesFullResponseHolder target = spy(new BytesFullResponseHolder(response));
Mockito.doThrow(IOException.class).when(objectMapper).readValue(isA(byte[].class), isA(TypeReference.class));
Assert.assertThrows(RuntimeException.class, () -> target.deserialize(objectMapper, new TypeReference<ResponseObject>() {}));
Mockito.verify(target, Mockito.times(1)).deserialize(ArgumentMatchers.any(), ArgumentMatchers.any());
}
static class ResponseObject
{
String payload;
@JsonCreator
public ResponseObject(@JsonProperty("payload") String payload)
{
this.payload = payload;
}
@JsonProperty("payload") public String getPayload()
{
return payload;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
ResponseObject that = (ResponseObject) o;
return Objects.equals(payload, that.payload);
}
@Override
public int hashCode()
{
return Objects.hashCode(payload);
}
}
}