[Backport] Make sketch encoding configurable (#17086) (#17153)

Makes sketch encoding in MSQ configurable by the user. A user can now choose the sketch encoding method for a specific query by setting the `sketchEncoding` query context parameter.

The default is octet stream encoding.
This commit is contained in:
Adarsh Sanjeev 2024-09-30 11:26:17 +05:30 committed by GitHub
parent a16b75a42c
commit e364d84e12
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 99 additions and 44 deletions

View File

@ -631,7 +631,8 @@ public class ControllerImpl implements Controller
this.workerSketchFetcher = new WorkerSketchFetcher(
netClient,
workerManager,
queryKernelConfig.isFaultTolerant()
queryKernelConfig.isFaultTolerant(),
MultiStageQueryContext.getSketchEncoding(querySpec.getQuery().context())
);
closer.register(workerSketchFetcher::close);

View File

@ -32,6 +32,7 @@ import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.rpc.SketchEncoding;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import javax.annotation.Nullable;
@ -60,23 +61,25 @@ public class ExceptionWrappingWorkerClient implements WorkerClient
@Override
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
String workerTaskId,
StageId stageId
StageId stageId,
SketchEncoding sketchEncoding
)
{
return wrap(workerTaskId, client, c -> c.fetchClusterByStatisticsSnapshot(workerTaskId, stageId));
return wrap(workerTaskId, client, c -> c.fetchClusterByStatisticsSnapshot(workerTaskId, stageId, sketchEncoding));
}
@Override
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk(
String workerTaskId,
StageId stageId,
long timeChunk
long timeChunk,
SketchEncoding sketchEncoding
)
{
return wrap(
workerTaskId,
client,
c -> c.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, stageId, timeChunk)
c -> c.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, stageId, timeChunk, sketchEncoding)
);
}

View File

@ -25,6 +25,7 @@ import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.rpc.SketchEncoding;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import java.io.Closeable;
@ -47,7 +48,8 @@ public interface WorkerClient extends Closeable
*/
ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
String workerId,
StageId stageId
StageId stageId,
SketchEncoding sketchEncoding
);
/**
@ -57,7 +59,8 @@ public interface WorkerClient extends Closeable
ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk(
String workerId,
StageId stageId,
long timeChunk
long timeChunk,
SketchEncoding sketchEncoding
);
/**

View File

@ -34,6 +34,7 @@ import org.apache.druid.msq.indexing.error.MSQFault;
import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.controller.ControllerQueryKernel;
import org.apache.druid.msq.rpc.SketchEncoding;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
@ -57,6 +58,7 @@ public class WorkerSketchFetcher implements AutoCloseable
private static final int DEFAULT_THREAD_COUNT = 4;
private final WorkerClient workerClient;
private final SketchEncoding sketchEncoding;
private final WorkerManager workerManager;
private final boolean retryEnabled;
@ -68,10 +70,12 @@ public class WorkerSketchFetcher implements AutoCloseable
public WorkerSketchFetcher(
WorkerClient workerClient,
WorkerManager workerManager,
boolean retryEnabled
boolean retryEnabled,
SketchEncoding sketchEncoding
)
{
this.workerClient = workerClient;
this.sketchEncoding = sketchEncoding;
this.executorService = Execs.multiThreaded(DEFAULT_THREAD_COUNT, "SketchFetcherThreadPool-%d");
this.workerManager = workerManager;
this.retryEnabled = retryEnabled;
@ -96,7 +100,7 @@ public class WorkerSketchFetcher implements AutoCloseable
executorService.submit(() -> {
fetchStatsFromWorker(
kernelActions,
() -> workerClient.fetchClusterByStatisticsSnapshot(taskId, stageId),
() -> workerClient.fetchClusterByStatisticsSnapshot(taskId, stageId, sketchEncoding),
taskId,
(kernel, snapshot) ->
kernel.mergeClusterByStatisticsCollectorForAllTimeChunks(stageId, workerNumber, snapshot),
@ -252,7 +256,8 @@ public class WorkerSketchFetcher implements AutoCloseable
() -> workerClient.fetchClusterByStatisticsSnapshotForTimeChunk(
taskId,
new StageId(stageId.getQueryId(), stageId.getStageNumber()),
timeChunk
timeChunk,
sketchEncoding
),
taskId,
(kernel, snapshot) -> kernel.mergeClusterByStatisticsCollectorForTimeChunk(

View File

@ -90,14 +90,15 @@ public abstract class BaseWorkerClientImpl implements WorkerClient
@Override
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
String workerId,
StageId stageId
StageId stageId,
SketchEncoding sketchEncoding
)
{
String path = StringUtils.format(
"/keyStatistics/%s/%d?sketchEncoding=%s",
StringUtils.urlEncode(stageId.getQueryId()),
stageId.getStageNumber(),
WorkerResource.SketchEncoding.OCTET_STREAM
sketchEncoding
);
return getClient(workerId).asyncRequest(
@ -110,7 +111,8 @@ public abstract class BaseWorkerClientImpl implements WorkerClient
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk(
String workerId,
StageId stageId,
long timeChunk
long timeChunk,
SketchEncoding sketchEncoding
)
{
String path = StringUtils.format(
@ -118,7 +120,7 @@ public abstract class BaseWorkerClientImpl implements WorkerClient
StringUtils.urlEncode(stageId.getQueryId()),
stageId.getStageNumber(),
timeChunk,
WorkerResource.SketchEncoding.OCTET_STREAM
sketchEncoding
);
return getClient(workerId).asyncRequest(

View File

@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.rpc;
import org.apache.druid.msq.statistics.serde.ClusterByStatisticsSnapshotSerde;
/**
 * Wire formats in which key collectors can be returned by {@link WorkerResource#httpFetchKeyStatistics} and
 * {@link WorkerResource#httpFetchKeyStatisticsWithSnapshot}.
 */
public enum SketchEncoding
{
  /**
   * Binary form: the key collector is written as a byte stream using {@link ClusterByStatisticsSnapshotSerde}.
   */
  OCTET_STREAM,

  /**
   * Textual form: the key collector is written as JSON.
   */
  JSON
}

View File

@ -373,19 +373,4 @@ public class WorkerResource
return Response.status(Response.Status.OK).entity(worker.getCounters()).build();
}
/**
* Determines the encoding of key collectors returned by {@link #httpFetchKeyStatistics} and
* {@link #httpFetchKeyStatisticsWithSnapshot}.
*/
public enum SketchEncoding
{
/**
* The key collector is encoded as a byte stream with {@link ClusterByStatisticsSnapshotSerde}.
*/
OCTET_STREAM,
/**
* The key collector is encoded as json
*/
JSON
}
}

View File

@ -38,6 +38,7 @@ import org.apache.druid.msq.indexing.destination.MSQSelectDestination;
import org.apache.druid.msq.indexing.error.MSQWarnings;
import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
import org.apache.druid.msq.rpc.ControllerResource;
import org.apache.druid.msq.rpc.SketchEncoding;
import org.apache.druid.msq.sql.MSQMode;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
@ -138,6 +139,9 @@ public class MultiStageQueryContext
public static final String CTX_CLUSTER_STATISTICS_MERGE_MODE = "clusterStatisticsMergeMode";
public static final String DEFAULT_CLUSTER_STATISTICS_MERGE_MODE = ClusterStatisticsMergeMode.SEQUENTIAL.toString();
public static final String CTX_SKETCH_ENCODING_MODE = "sketchEncoding";
public static final String DEFAULT_CTX_SKETCH_ENCODING_MODE = SketchEncoding.OCTET_STREAM.toString();
public static final String CTX_ROWS_PER_SEGMENT = "rowsPerSegment";
public static final int DEFAULT_ROWS_PER_SEGMENT = 3000000;
@ -265,6 +269,15 @@ public class MultiStageQueryContext
);
}
/**
 * Resolves the {@link SketchEncoding} requested by a query's context.
 *
 * The {@link #CTX_SKETCH_ENCODING_MODE} entry is read as a string, with
 * {@link #DEFAULT_CTX_SKETCH_ENCODING_MODE} passed as the fallback argument, and the resulting string is
 * converted to an enum constant by {@link QueryContexts#getAsEnum}.
 *
 * @param queryContext context of the query whose sketch encoding is wanted
 * @return the sketch encoding produced by the enum conversion
 */
public static SketchEncoding getSketchEncoding(QueryContext queryContext)
{
  final String configuredEncoding =
      queryContext.getString(CTX_SKETCH_ENCODING_MODE, DEFAULT_CTX_SKETCH_ENCODING_MODE);
  return QueryContexts.getAsEnum(CTX_SKETCH_ENCODING_MODE, configuredEncoding, SketchEncoding.class);
}
public static boolean isFinalizeAggregations(final QueryContext queryContext)
{
return queryContext.getBoolean(

View File

@ -27,6 +27,7 @@ import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.controller.ControllerQueryKernel;
import org.apache.druid.msq.rpc.SketchEncoding;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import org.junit.After;
@ -101,13 +102,13 @@ public class WorkerSketchFetcherTest
final CountDownLatch latch = new CountDownLatch(TASK_IDS.size());
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true, SketchEncoding.OCTET_STREAM));
// When fetching snapshots, return a mock and add it to queue
doAnswer(invocation -> {
ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class);
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any());
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), any());
target.inMemoryFullSketchMerging((kernelConsumer) -> {
kernelConsumer.accept(kernel);
@ -124,13 +125,13 @@ public class WorkerSketchFetcherTest
doReturn(true).when(completeKeyStatisticsInformation).isComplete();
final CountDownLatch latch = new CountDownLatch(TASK_IDS.size());
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true, SketchEncoding.OCTET_STREAM));
// When fetching snapshots, return a mock and add it to queue
doAnswer(invocation -> {
ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class);
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong());
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong(), any());
target.sequentialTimeChunkMerging(
(kernelConsumer) -> {
@ -152,7 +153,7 @@ public class WorkerSketchFetcherTest
{
doReturn(false).when(completeKeyStatisticsInformation).isComplete();
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true, SketchEncoding.OCTET_STREAM));
Assert.assertThrows(ISE.class, () -> target.sequentialTimeChunkMerging(
(ignore) -> {},
completeKeyStatisticsInformation,
@ -167,7 +168,7 @@ public class WorkerSketchFetcherTest
{
final CountDownLatch latch = new CountDownLatch(TASK_IDS.size());
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true, SketchEncoding.OCTET_STREAM));
workersWithFailedFetchParallel(ImmutableSet.of(TASK_1));
@ -196,7 +197,7 @@ public class WorkerSketchFetcherTest
doReturn(true).when(completeKeyStatisticsInformation).isComplete();
final CountDownLatch latch = new CountDownLatch(TASK_IDS.size());
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true, SketchEncoding.OCTET_STREAM));
workersWithFailedFetchSequential(ImmutableSet.of(TASK_1));
CountDownLatch retryLatch = new CountDownLatch(1);
@ -223,7 +224,7 @@ public class WorkerSketchFetcherTest
public void test_InMemoryRetryDisabled_multipleFailures() throws InterruptedException
{
target = spy(new WorkerSketchFetcher(workerClient, workerManager, false));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, false, SketchEncoding.OCTET_STREAM));
workersWithFailedFetchParallel(ImmutableSet.of(TASK_1, TASK_0));
@ -252,7 +253,7 @@ public class WorkerSketchFetcherTest
public void test_InMemoryRetryDisabled_singleFailure() throws InterruptedException
{
target = spy(new WorkerSketchFetcher(workerClient, workerManager, false));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, false, SketchEncoding.OCTET_STREAM));
workersWithFailedFetchParallel(ImmutableSet.of(TASK_1));
@ -283,7 +284,7 @@ public class WorkerSketchFetcherTest
{
doReturn(true).when(completeKeyStatisticsInformation).isComplete();
target = spy(new WorkerSketchFetcher(workerClient, workerManager, false));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, false, SketchEncoding.OCTET_STREAM));
workersWithFailedFetchSequential(ImmutableSet.of(TASK_1, TASK_0));
@ -315,7 +316,7 @@ public class WorkerSketchFetcherTest
public void test_SequentialRetryDisabled_singleFailure() throws InterruptedException
{
doReturn(true).when(completeKeyStatisticsInformation).isComplete();
target = spy(new WorkerSketchFetcher(workerClient, workerManager, false));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, false, SketchEncoding.OCTET_STREAM));
workersWithFailedFetchSequential(ImmutableSet.of(TASK_1));
@ -352,7 +353,7 @@ public class WorkerSketchFetcherTest
return Futures.immediateFailedFuture(new Exception("Task fetch failed :" + invocation.getArgument(0)));
}
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong());
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong(), any());
}
private void workersWithFailedFetchParallel(Set<String> failedTasks)
@ -363,7 +364,7 @@ public class WorkerSketchFetcherTest
return Futures.immediateFailedFuture(new Exception("Task fetch failed :" + invocation.getArgument(0)));
}
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any());
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), any());
}
}

View File

@ -29,6 +29,7 @@ import org.apache.druid.msq.exec.Worker;
import org.apache.druid.msq.exec.WorkerClient;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.rpc.SketchEncoding;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import java.io.InputStream;
@ -54,7 +55,8 @@ public class MSQTestWorkerClient implements WorkerClient
@Override
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
String workerTaskId,
StageId stageId
StageId stageId,
SketchEncoding sketchEncoding
)
{
return Futures.immediateFuture(inMemoryWorkers.get(workerTaskId).fetchStatisticsSnapshot(stageId));
@ -64,7 +66,8 @@ public class MSQTestWorkerClient implements WorkerClient
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk(
String workerTaskId,
StageId stageId,
long timeChunk
long timeChunk,
SketchEncoding sketchEncoding
)
{
return Futures.immediateFuture(