Add sequential sketch merging to MSQ (#13205)

* Add sketch fetching framework

* Refactor code to support sequential merge

* Update worker sketch fetcher

* Refactor sketch fetcher

* Refactor sketch fetcher

* Add context parameter and threshold to trigger sequential merge

* Fix test

* Add integration test for non-sequential merge

* Address review comments

* Address review comments

* Address review comments

* Resolve maxRetainedBytes

* Add new classes

* Renamed key statistics information class

* Rename fetchStatisticsSnapshotForTimeChunk function

* Address review comments

* Address review comments

* Update documentation and add comments

* Resolve build issues

* Resolve build issues

* Change worker APIs to async

* Address review comments

* Resolve build issues

* Add null time check

* Update integration tests

* Address review comments

* Add log messages and comments

* Resolve build issues

* Add unit tests

* Add unit tests

* Fix timing issue in tests
Adarsh Sanjeev 2022-11-22 09:56:32 +05:30 committed by GitHub
parent fe34ecc5e3
commit 280a0f7158
39 changed files with 1850 additions and 132 deletions

View File

@ -203,6 +203,28 @@ The following table lists the context parameters for the MSQ task engine:
| `maxParseExceptions`| SELECT, INSERT, REPLACE<br /><br />Maximum number of parse exceptions that are ignored while executing the query before it stops with `TooManyWarningsFault`. To ignore all the parse exceptions, set the value to -1.| 0 |
| `rowsPerSegment` | INSERT or REPLACE<br /><br />The number of rows per segment to target. The actual number of rows per segment may be somewhat higher or lower than this number. In most cases, use the default. For general information about sizing rows per segment, see [Segment Size Optimization](../operations/segment-optimization.md). | 3,000,000 |
| `indexSpec` | INSERT or REPLACE<br /><br />An [`indexSpec`](../ingestion/ingestion-spec.md#indexspec) to use when generating segments. May be a JSON string or object. See [Front coding](../ingestion/ingestion-spec.md#front-coding) for details on configuring an `indexSpec` with front coding. | See [`indexSpec`](../ingestion/ingestion-spec.md#indexspec). |
| `clusterStatisticsMergeMode` | Whether to use parallel or sequential mode for merging of the worker sketches. Can be `PARALLEL`, `SEQUENTIAL` or `AUTO`. See [Sketch Merging Mode](#sketch-merging-mode) for more information. | `AUTO` |
## Sketch Merging Mode
This section describes the cluster statistics merge modes and their performance trade-offs.
If a query requires key statistics to generate partition boundaries, the workers gather those statistics while
reading rows from the datasource. The statistics must then be transferred to the controller and merged together.
`clusterStatisticsMergeMode` configures how this happens.
`PARALLEL` mode fetches the key statistics for all time chunks from all workers at once, and the controller downsamples
the merged sketch if it does not fit in memory. This is faster than `SEQUENTIAL` mode because there is less overhead in
fetching sketches for all time chunks together. It works well when the sketches are small enough that merging them does
not trigger downsampling, or when accuracy in segment sizing for the ingestion is not very important.
`SEQUENTIAL` mode fetches the sketches in ascending order of time and generates the partition boundaries for one time
chunk at a time. This gives the controller more working memory for merging sketches, which results in less
downsampling and therefore more accuracy. There is, however, a time overhead from fetching sketches in sequential
order. Use this mode when accuracy is important.
`AUTO` mode tries to find the best approach based on the number of workers and the size of the input rows. If there
are more than 100 workers, or if the combined sketch size across all workers is more than 1 GB, `SEQUENTIAL` is chosen;
otherwise `PARALLEL` is chosen. If the query does not cluster on time, `PARALLEL` is always chosen, because there are
no time chunks to merge sequentially.
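To make the `AUTO` rule concrete, here is a minimal Java sketch of the decision described above; the method and parameter names are illustrative assumptions, not Druid's actual implementation (which lives in the `WorkerSketchFetcher` added by this change):

```java
// Illustrative sketch only: restates the AUTO rule documented above.
ClusterStatisticsMergeMode chooseAutoMode(boolean clusteredByTime, int workerCount, long combinedSketchBytes)
{
  if (!clusteredByTime) {
    // Without time clustering there are no time chunks to merge one at a time.
    return ClusterStatisticsMergeMode.PARALLEL;
  }
  if (workerCount > 100 || combinedSketchBytes > 1_000_000_000L) {
    return ClusterStatisticsMergeMode.SEQUENTIAL;
  }
  return ClusterStatisticsMergeMode.PARALLEL;
}
```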
## Durable Storage
This section enumerates the advantages and performance implications of enabling durable storage while executing MSQ tasks.

View File

@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.exec;
/**
* Mode which dictates how {@link WorkerSketchFetcher} gets sketches for the partition boundaries from workers.
*/
public enum ClusterStatisticsMergeMode
{
/**
* Fetches sketches in sequential order based on time. Slower due to overhead, but more accurate.
*/
SEQUENTIAL,
/**
* Fetches all sketches from the workers at once. Faster to generate partitions, but less accurate.
*/
PARALLEL,
/**
* Tries to decide between sequential and parallel modes based on the number of workers and the size of the input.
*
* If there are more than 100 workers or if the combined sketch size among all workers is more than
* 1,000,000,000 bytes, SEQUENTIAL mode is chosen, otherwise, PARALLEL mode is chosen.
*/
AUTO
}
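// Hypothetical usage sketch, not part of this change: callers select a mode through the MSQ query
// context using the documented `clusterStatisticsMergeMode` key; the controller later reads it back
// via MultiStageQueryContext.getClusterStatisticsMergeMode(...). The map below is only an illustration.
final Map<String, Object> queryContext =
    ImmutableMap.of("clusterStatisticsMergeMode", ClusterStatisticsMergeMode.SEQUENTIAL.name());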

View File

@ -27,7 +27,7 @@ import org.apache.druid.msq.counters.CounterSnapshots;
import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.indexing.MSQControllerTask;
import org.apache.druid.msq.indexing.error.MSQErrorReport;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
import javax.annotation.Nullable;
import java.util.List;
@ -81,9 +81,11 @@ public interface Controller
// Worker-to-controller messages
/**
* Provide a {@link ClusterByStatisticsSnapshot} for shuffling stages.
* Accepts a {@link PartialKeyStatisticsInformation} and updates the controller key statistics information. If all key
* statistics have been gathered, enqueues the task with the {@link WorkerSketchFetcher} to generate partition boundaries.
* This is intended to be called by the {@link org.apache.druid.msq.indexing.ControllerChatHandler}.
*/
void updateStatus(int stageNumber, int workerNumber, Object keyStatisticsObject);
void updatePartialKeyStatisticsInformation(int stageNumber, int workerNumber, Object partialKeyStatisticsInformationObject);
/**
* System error reported by a subtask. Note that the errors are organized by

View File

@ -22,7 +22,7 @@ package org.apache.druid.msq.exec;
import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.indexing.error.MSQErrorReport;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
import javax.annotation.Nullable;
import java.io.IOException;
@ -34,13 +34,13 @@ import java.util.List;
public interface ControllerClient extends AutoCloseable
{
/**
* Client side method to update the controller with key statistics for a particular stage and worker.
* Controller's implementation collates all the key statistics for a stage to generate the partition boundaries.
* Client side method to update the controller with partial key statistics information for a particular stage and worker.
* Controller's implementation collates all the information for a stage to fetch key statistics from workers.
*/
void postKeyStatistics(
void postPartialKeyStatistics(
StageId stageId,
int workerNumber,
ClusterByStatisticsSnapshot keyStatistics
PartialKeyStatisticsInformation partialKeyStatisticsInformation
) throws IOException;
/**

View File

@ -64,6 +64,7 @@ import org.apache.druid.indexing.common.actions.TaskActionClient;
import org.apache.druid.indexing.overlord.SegmentPublishResult;
import org.apache.druid.indexing.overlord.Segments;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.Either;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Intervals;
@ -107,6 +108,7 @@ import org.apache.druid.msq.indexing.error.MSQFault;
import org.apache.druid.msq.indexing.error.MSQWarningReportLimiterPublisher;
import org.apache.druid.msq.indexing.error.MSQWarnings;
import org.apache.druid.msq.indexing.error.QueryNotSupportedFault;
import org.apache.druid.msq.indexing.error.TooManyPartitionsFault;
import org.apache.druid.msq.indexing.error.TooManyWarningsFault;
import org.apache.druid.msq.indexing.error.UnknownFault;
import org.apache.druid.msq.indexing.report.MSQResultsReport;
@ -149,7 +151,8 @@ import org.apache.druid.msq.querykit.scan.ScanQueryKit;
import org.apache.druid.msq.shuffle.DurableStorageInputChannelFactory;
import org.apache.druid.msq.shuffle.DurableStorageUtils;
import org.apache.druid.msq.shuffle.WorkerInputChannelFactory;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
import org.apache.druid.msq.util.DimensionSchemaUtils;
import org.apache.druid.msq.util.IntervalUtils;
import org.apache.druid.msq.util.MSQFutureUtils;
@ -201,6 +204,7 @@ import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ThreadLocalRandom;
@ -259,6 +263,7 @@ public class ControllerImpl implements Controller
// For live reports. Written by the main controller thread, read by HTTP threads.
private final ConcurrentHashMap<Integer, Integer> stagePartitionCountsForLiveReports = new ConcurrentHashMap<>();
private WorkerSketchFetcher workerSketchFetcher;
// Time at which the query started.
// For live reports. Written by the main controller thread, read by HTTP threads.
private volatile DateTime queryStartTime = null;
@ -519,6 +524,15 @@ public class ControllerImpl implements Controller
context.registerController(this, closer);
this.netClient = new ExceptionWrappingWorkerClient(context.taskClientFor(this));
ClusterStatisticsMergeMode clusterStatisticsMergeMode =
MultiStageQueryContext.getClusterStatisticsMergeMode(task.getQuerySpec().getQuery().context());
log.debug("Query [%s] cluster statistics merge mode is set to %s.", id(), clusterStatisticsMergeMode);
int statisticsMaxRetainedBytes = WorkerMemoryParameters.createProductionInstanceForController(context.injector())
.getPartitionStatisticsMaxRetainedBytes();
this.workerSketchFetcher = new WorkerSketchFetcher(netClient, clusterStatisticsMergeMode, statisticsMaxRetainedBytes);
closer.register(netClient::close);
final boolean isDurableStorageEnabled =
@ -565,10 +579,12 @@ public class ControllerImpl implements Controller
}
/**
* Provide a {@link ClusterByStatisticsSnapshot} for shuffling stages.
* Accepts a {@link PartialKeyStatisticsInformation} and updates the controller key statistics information. If all key
* statistics information has been gathered, enqueues the task with the {@link WorkerSketchFetcher} to generate
* partition boundaries. This is intended to be called by the {@link org.apache.druid.msq.indexing.ControllerChatHandler}.
*/
@Override
public void updateStatus(int stageNumber, int workerNumber, Object keyStatisticsObject)
public void updatePartialKeyStatisticsInformation(int stageNumber, int workerNumber, Object partialKeyStatisticsInformationObject)
{
addToKernelManipulationQueue(
queryKernel -> {
@ -582,9 +598,9 @@ public class ControllerImpl implements Controller
stageDef.getShuffleSpec().get().doesAggregateByClusterKey()
);
final ClusterByStatisticsSnapshot keyStatistics;
final PartialKeyStatisticsInformation partialKeyStatisticsInformation;
try {
keyStatistics = mapper.convertValue(keyStatisticsObject, ClusterByStatisticsSnapshot.class);
partialKeyStatisticsInformation = mapper.convertValue(partialKeyStatisticsInformationObject, PartialKeyStatisticsInformation.class);
}
catch (IllegalArgumentException e) {
throw new IAE(
@ -595,7 +611,36 @@ public class ControllerImpl implements Controller
);
}
queryKernel.addResultKeyStatisticsForStageAndWorker(stageId, workerNumber, keyStatistics);
queryKernel.addPartialKeyStatisticsForStageAndWorker(stageId, workerNumber, partialKeyStatisticsInformation);
if (queryKernel.getStagePhase(stageId).equals(ControllerStagePhase.MERGING_STATISTICS)) {
List<String> workerTaskIds = workerTaskLauncher.getTaskList();
CompleteKeyStatisticsInformation completeKeyStatisticsInformation =
queryKernel.getCompleteKeyStatisticsInformation(stageId);
// Queue the sketch fetching task into the worker sketch fetcher.
CompletableFuture<Either<Long, ClusterByPartitions>> clusterByPartitionsCompletableFuture =
workerSketchFetcher.submitFetcherTask(
completeKeyStatisticsInformation,
workerTaskIds,
stageDef
);
// Add the listener to handle completion.
clusterByPartitionsCompletableFuture.whenComplete((clusterByPartitionsEither, throwable) -> {
addToKernelManipulationQueue(holder -> {
if (throwable != null) {
holder.failStageForReason(stageId, UnknownFault.forException(throwable));
} else if (clusterByPartitionsEither.isError()) {
holder.failStageForReason(stageId, new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
} else {
log.debug("Query [%s] Partition boundaries generated for stage %s", id(), stageId);
holder.setClusterByPartitionBoundaries(stageId, clusterByPartitionsEither.valueOrThrow());
}
holder.transitionStageKernel(stageId, queryKernel.getStagePhase(stageId));
});
});
}
}
);
}
@ -1959,11 +2004,7 @@ public class ControllerImpl implements Controller
this.queryDef = queryDef;
this.inputSpecSlicerFactory = inputSpecSlicerFactory;
this.closer = closer;
this.queryKernel = new ControllerQueryKernel(
queryDef,
WorkerMemoryParameters.createProductionInstanceForController(context.injector())
.getPartitionStatisticsMaxRetainedBytes()
);
this.queryKernel = new ControllerQueryKernel(queryDef);
}
/**

View File

@ -31,6 +31,7 @@ import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import javax.annotation.Nullable;
import java.io.IOException;
@ -55,6 +56,23 @@ public class ExceptionWrappingWorkerClient implements WorkerClient
return wrap(workerTaskId, client, c -> c.postWorkOrder(workerTaskId, workOrder));
}
@Override
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(String workerTaskId, String queryId, int stageNumber)
{
return client.fetchClusterByStatisticsSnapshot(workerTaskId, queryId, stageNumber);
}
@Override
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk(
String workerTaskId,
String queryId,
int stageNumber,
long timeChunk
)
{
return client.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, queryId, stageNumber, timeChunk);
}
@Override
public ListenableFuture<Void> postResultPartitionBoundaries(
final String workerTaskId,

View File

@ -25,6 +25,7 @@ import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.indexing.MSQWorkerTask;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import javax.annotation.Nullable;
import java.io.IOException;
@ -67,6 +68,18 @@ public interface Worker
*/
void postWorkOrder(WorkOrder workOrder);
/**
* Returns the statistics snapshot for the given stageId. This is called from {@link WorkerSketchFetcher} under
* PARALLEL or AUTO modes.
*/
ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId);
/**
* Returns the statistics snapshot for the given stageId which contains only the sketch for the specified timeChunk.
* This is called from {@link WorkerSketchFetcher} under SEQUENTIAL or AUTO modes.
*/
ClusterByStatisticsSnapshot fetchStatisticsSnapshotForTimeChunk(StageId stageId, long timeChunk);
/**
* Called when the worker chat handler receives the result partition boundaries for a particular stageNumber
* and queryId

View File

@ -25,6 +25,7 @@ import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import java.io.IOException;
@ -38,6 +39,27 @@ public interface WorkerClient extends AutoCloseable
*/
ListenableFuture<Void> postWorkOrder(String workerTaskId, WorkOrder workOrder);
/**
* Fetches the {@link ClusterByStatisticsSnapshot} from a worker. This is intended to be used by the
* {@link WorkerSketchFetcher} under PARALLEL or AUTO modes.
*/
ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
String workerTaskId,
String queryId,
int stageNumber
);
/**
* Fetches a {@link ClusterByStatisticsSnapshot} which contains only the sketch of the specified timeChunk.
* This is intended to be used by the {@link WorkerSketchFetcher} under SEQUENTIAL or AUTO modes.
*/
ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk(
String workerTaskId,
String queryId,
int stageNumber,
long timeChunk
);
/**
* Worker's client method to inform it of the partition boundaries for the given stage. This is usually invoked by the
* controller after collating the result statistics from all the workers processing the query

View File

@ -106,6 +106,7 @@ import org.apache.druid.msq.shuffle.DurableStorageUtils;
import org.apache.druid.msq.shuffle.WorkerInputChannelFactory;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
import org.apache.druid.msq.util.DecoratedExecutorService;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.PrioritizedCallable;
@ -159,6 +160,7 @@ public class WorkerImpl implements Worker
private final BlockingQueue<Consumer<KernelHolder>> kernelManipulationQueue = new LinkedBlockingDeque<>();
private final ConcurrentHashMap<StageId, ConcurrentHashMap<Integer, ReadableFrameChannel>> stageOutputs = new ConcurrentHashMap<>();
private final ConcurrentHashMap<StageId, CounterTracker> stageCounters = new ConcurrentHashMap<>();
private final ConcurrentHashMap<StageId, WorkerStageKernel> stageKernelMap = new ConcurrentHashMap<>();
private final boolean durableStageStorageEnabled;
/**
@ -365,10 +367,14 @@ public class WorkerImpl implements Worker
if (kernel.getPhase() == WorkerStagePhase.READING_INPUT && kernel.hasResultKeyStatisticsSnapshot()) {
if (controllerAlive) {
controllerClient.postKeyStatistics(
PartialKeyStatisticsInformation partialKeyStatisticsInformation =
kernel.getResultKeyStatisticsSnapshot()
.partialKeyStatistics();
controllerClient.postPartialKeyStatistics(
stageDefinition.getId(),
kernel.getWorkOrder().getWorkerNumber(),
kernel.getResultKeyStatisticsSnapshot()
partialKeyStatisticsInformation
);
}
kernel.startPreshuffleWaitingForResultPartitionBoundaries();
@ -562,6 +568,19 @@ public class WorkerImpl implements Worker
kernelManipulationQueue.add(KernelHolder::setDone);
}
@Override
public ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId)
{
return stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();
}
@Override
public ClusterByStatisticsSnapshot fetchStatisticsSnapshotForTimeChunk(StageId stageId, long timeChunk)
{
ClusterByStatisticsSnapshot snapshot = stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();
return snapshot.getSnapshotForTimeChunk(timeChunk);
}
@Override
public CounterSnapshotsTree getCounters()
{
@ -1273,9 +1292,8 @@ public class WorkerImpl implements Worker
}
}
private static class KernelHolder
private class KernelHolder
{
private final Map<StageId, WorkerStageKernel> stageKernelMap = new HashMap<>();
private boolean done = false;
public Map<StageId, WorkerStageKernel> getStageKernelMap()

View File

@ -0,0 +1,340 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.exec;
import com.google.common.util.concurrent.ListenableFuture;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.java.util.common.Either;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.IntStream;
/**
* Queues up fetching sketches from workers and progressively generates partition boundaries.
*/
public class WorkerSketchFetcher
{
private static final Logger log = new Logger(WorkerSketchFetcher.class);
private static final int DEFAULT_THREAD_COUNT = 4;
// If the combined size of worker sketches is more than this threshold, SEQUENTIAL merging mode is used.
static final long BYTES_THRESHOLD = 1_000_000_000L;
// If there are more workers than this threshold, SEQUENTIAL merging mode is used.
static final long WORKER_THRESHOLD = 100;
private final ClusterStatisticsMergeMode clusterStatisticsMergeMode;
private final int statisticsMaxRetainedBytes;
private final WorkerClient workerClient;
private final ExecutorService executorService;
public WorkerSketchFetcher(WorkerClient workerClient, ClusterStatisticsMergeMode clusterStatisticsMergeMode, int statisticsMaxRetainedBytes)
{
this.workerClient = workerClient;
this.clusterStatisticsMergeMode = clusterStatisticsMergeMode;
this.executorService = Executors.newFixedThreadPool(DEFAULT_THREAD_COUNT);
this.statisticsMaxRetainedBytes = statisticsMaxRetainedBytes;
}
/**
* Submits a request to fetch sketches and generate partitions for the given worker statistics, and returns a future for
* the result. Based on the statistics, it decides whether to fetch the sketches one time chunk at a time or all together.
*/
public CompletableFuture<Either<Long, ClusterByPartitions>> submitFetcherTask(
CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
List<String> workerTaskIds,
StageDefinition stageDefinition
)
{
ClusterBy clusterBy = stageDefinition.getClusterBy();
switch (clusterStatisticsMergeMode) {
case SEQUENTIAL:
return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
case PARALLEL:
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
case AUTO:
if (clusterBy.getBucketByCount() == 0) {
log.debug("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
// If there is no time clustering, there is no scope for sequential merge
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
} else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD || completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
log.debug("Query [%s] AUTO mode: chose SEQUENTIAL mode to merge key statistics", stageDefinition.getId().getQueryId());
return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
}
log.debug("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
default:
throw new IllegalStateException("No fetching strategy found for mode: " + clusterStatisticsMergeMode);
}
}
/**
* Fetches the full {@link ClusterByStatisticsCollector} from all workers and generates partition boundaries from them.
* This is faster than fetching them time chunk by time chunk, but the collector may be downsampled until it fits
* on the controller, resulting in less accurate partition boundaries.
*/
CompletableFuture<Either<Long, ClusterByPartitions>> inMemoryFullSketchMerging(
StageDefinition stageDefinition,
List<String> workerTaskIds
)
{
CompletableFuture<Either<Long, ClusterByPartitions>> partitionFuture = new CompletableFuture<>();
// Create a new key statistics collector to merge worker sketches into
final ClusterByStatisticsCollector mergedStatisticsCollector =
stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes);
final int workerCount = workerTaskIds.size();
// Guarded by synchronized mergedStatisticsCollector
final Set<Integer> finishedWorkers = new HashSet<>();
// Submit a task for each worker to fetch statistics
IntStream.range(0, workerCount).forEach(workerNo -> {
executorService.submit(() -> {
ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture =
workerClient.fetchClusterByStatisticsSnapshot(
workerTaskIds.get(workerNo),
stageDefinition.getId().getQueryId(),
stageDefinition.getStageNumber()
);
partitionFuture.whenComplete((result, exception) -> {
if (exception != null || (result != null && result.isError())) {
snapshotFuture.cancel(true);
}
});
try {
ClusterByStatisticsSnapshot clusterByStatisticsSnapshot = snapshotFuture.get();
if (clusterByStatisticsSnapshot == null) {
throw new ISE("Worker %s returned null sketch, this should never happen", workerNo);
}
synchronized (mergedStatisticsCollector) {
mergedStatisticsCollector.addAll(clusterByStatisticsSnapshot);
finishedWorkers.add(workerNo);
if (finishedWorkers.size() == workerCount) {
log.debug("Query [%s] Received all statistics, generating partitions", stageDefinition.getId().getQueryId());
partitionFuture.complete(stageDefinition.generatePartitionsForShuffle(mergedStatisticsCollector));
}
}
}
catch (Exception e) {
synchronized (mergedStatisticsCollector) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
}
}
});
});
return partitionFuture;
}
/**
* Fetches cluster statistics from all workers and generates partition boundaries from them one time chunk at a time.
* This takes longer due to the overhead of fetching sketches; however, it prevents any loss of accuracy from
* downsampling on the controller.
*/
CompletableFuture<Either<Long, ClusterByPartitions>> sequentialTimeChunkMerging(
CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
StageDefinition stageDefinition,
List<String> workerTaskIds
)
{
SequentialFetchStage sequentialFetchStage = new SequentialFetchStage(
stageDefinition,
workerTaskIds,
completeKeyStatisticsInformation.getTimeSegmentVsWorkerMap().entrySet().iterator()
);
sequentialFetchStage.submitFetchingTasksForNextTimeChunk();
return sequentialFetchStage.getPartitionFuture();
}
private class SequentialFetchStage
{
private final StageDefinition stageDefinition;
private final List<String> workerTaskIds;
private final Iterator<Map.Entry<Long, Set<Integer>>> timeSegmentVsWorkerIdIterator;
private final CompletableFuture<Either<Long, ClusterByPartitions>> partitionFuture;
// Final sorted list of partition boundaries. This is appended to after statistics for each time chunk are gathered.
private final List<ClusterByPartition> finalPartitionBoundries;
public SequentialFetchStage(
StageDefinition stageDefinition,
List<String> workerTaskIds,
Iterator<Map.Entry<Long, Set<Integer>>> timeSegmentVsWorkerIdIterator
)
{
this.finalPartitionBoundries = new ArrayList<>();
this.stageDefinition = stageDefinition;
this.workerTaskIds = workerTaskIds;
this.timeSegmentVsWorkerIdIterator = timeSegmentVsWorkerIdIterator;
this.partitionFuture = new CompletableFuture<>();
}
/**
* Submits the tasks to fetch key statistics for the time chunk pointed to by {@link #timeSegmentVsWorkerIdIterator}.
* Once the statistics have been gathered from all workers which have them, generates partitions and adds them to
* {@link #finalPartitionBoundries}, stitching the partitions between time chunks using
* {@link #abutAndAppendPartitionBoundries(List, List)} to make them continuous.
*
* The time chunks returned by {@link #timeSegmentVsWorkerIdIterator} should be in ascending order for the partitions
* to be generated correctly.
*
* If {@link #timeSegmentVsWorkerIdIterator} doesn't have any more values, assumes that partition boundaries have
* been successfully generated and completes {@link #partitionFuture} with the result.
*
* Completes the future with an error as soon as the number of partitions exceeds the max partition count for the stage
* definition.
*/
public void submitFetchingTasksForNextTimeChunk()
{
if (!timeSegmentVsWorkerIdIterator.hasNext()) {
partitionFuture.complete(Either.value(new ClusterByPartitions(finalPartitionBoundries)));
} else {
Map.Entry<Long, Set<Integer>> entry = timeSegmentVsWorkerIdIterator.next();
// Time chunk for which partition boundaries will be generated
Long timeChunk = entry.getKey();
Set<Integer> workerIdsWithTimeChunk = entry.getValue();
// Create a new key statistics collector to merge worker sketches into
ClusterByStatisticsCollector mergedStatisticsCollector =
stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes);
// Guarded by synchronized mergedStatisticsCollector
Set<Integer> finishedWorkers = new HashSet<>();
log.debug("Query [%s]. Submitting request for statistics for time chunk %s to %s workers",
stageDefinition.getId().getQueryId(),
timeChunk,
workerIdsWithTimeChunk.size());
// Submits a task for every worker that has this time chunk
for (int workerNo : workerIdsWithTimeChunk) {
executorService.submit(() -> {
ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture =
workerClient.fetchClusterByStatisticsSnapshotForTimeChunk(
workerTaskIds.get(workerNo),
stageDefinition.getId().getQueryId(),
stageDefinition.getStageNumber(),
timeChunk
);
partitionFuture.whenComplete((result, exception) -> {
if (exception != null || (result != null && result.isError())) {
snapshotFuture.cancel(true);
}
});
try {
ClusterByStatisticsSnapshot snapshotForTimeChunk = snapshotFuture.get();
if (snapshotForTimeChunk == null) {
throw new ISE("Worker %s returned null sketch for %s, this should never happen", workerNo, timeChunk);
}
synchronized (mergedStatisticsCollector) {
mergedStatisticsCollector.addAll(snapshotForTimeChunk);
finishedWorkers.add(workerNo);
if (finishedWorkers.size() == workerIdsWithTimeChunk.size()) {
Either<Long, ClusterByPartitions> longClusterByPartitionsEither =
stageDefinition.generatePartitionsForShuffle(mergedStatisticsCollector);
log.debug("Query [%s]. Received all statistics for time chunk %s, generating partitions",
stageDefinition.getId().getQueryId(),
timeChunk);
long totalPartitionCount = finalPartitionBoundries.size() + getPartitionCountFromEither(longClusterByPartitionsEither);
if (totalPartitionCount > stageDefinition.getMaxPartitionCount()) {
// Fail fast if more partitions than the maximum have been reached.
partitionFuture.complete(Either.error(totalPartitionCount));
mergedStatisticsCollector.clear();
} else {
List<ClusterByPartition> timeSketchPartitions = longClusterByPartitionsEither.valueOrThrow().ranges();
abutAndAppendPartitionBoundries(finalPartitionBoundries, timeSketchPartitions);
log.debug("Query [%s]. Finished generating partitions for time chunk %s, total count so far %s",
stageDefinition.getId().getQueryId(),
timeChunk,
finalPartitionBoundries.size());
submitFetchingTasksForNextTimeChunk();
}
}
}
}
catch (Exception e) {
synchronized (mergedStatisticsCollector) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
}
}
});
}
}
}
/**
* Takes a list of sorted {@link ClusterByPartitions} {@param timeSketchPartitions} and adds it to a sorted list
* {@param finalPartitionBoundries}. If {@param finalPartitionBoundries} is not empty, the end time of the last
* partition of {@param finalPartitionBoundries} is changed to abut with the starting time of the first partition
* of {@param timeSketchPartitions}.
*
* This is used to make the partitions generated continuous.
*/
private void abutAndAppendPartitionBoundries(
List<ClusterByPartition> finalPartitionBoundries,
List<ClusterByPartition> timeSketchPartitions
)
{
if (!finalPartitionBoundries.isEmpty()) {
// Stitch up the end time of the last partition with the start time of the first partition.
ClusterByPartition clusterByPartition = finalPartitionBoundries.remove(finalPartitionBoundries.size() - 1);
finalPartitionBoundries.add(new ClusterByPartition(clusterByPartition.getStart(), timeSketchPartitions.get(0).getStart()));
}
finalPartitionBoundries.addAll(timeSketchPartitions);
}
public CompletableFuture<Either<Long, ClusterByPartitions>> getPartitionFuture()
{
return partitionFuture;
}
}
/**
* Gets the partition count from an {@link Either}. If it is an error, the long denotes the number of partitions
* (in the case of too many partitions being created); otherwise, it is the size of the list.
*/
private static long getPartitionCountFromEither(Either<Long, ClusterByPartitions> either)
{
if (either.isError()) {
return either.error();
} else {
return either.valueOrThrow().size();
}
}
}
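// A minimal caller-side sketch (assumed, not part of this class) of consuming the future returned by
// submitFetcherTask; it mirrors the listener that ControllerImpl registers earlier in this change.
// The fetcher, statistics, worker task IDs, stage definition, and the handle* callbacks are assumed
// to exist in the surrounding code.
CompletableFuture<Either<Long, ClusterByPartitions>> future =
    workerSketchFetcher.submitFetcherTask(completeKeyStatisticsInformation, workerTaskIds, stageDefinition);
future.whenComplete((either, throwable) -> {
  if (throwable != null) {
    handleFailure(throwable);                // fetching or merging failed outright
  } else if (either.isError()) {
    handleTooManyPartitions(either.error()); // the error side carries the offending partition count
  } else {
    handleBoundaries(either.valueOrThrow()); // the value side carries the generated ClusterByPartitions
  }
});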

View File

@ -26,7 +26,8 @@ import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.exec.Controller;
import org.apache.druid.msq.exec.ControllerClient;
import org.apache.druid.msq.indexing.error.MSQErrorReport;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
import org.apache.druid.segment.realtime.firehose.ChatHandler;
import org.apache.druid.segment.realtime.firehose.ChatHandlers;
import org.apache.druid.server.security.Action;
@ -58,16 +59,17 @@ public class ControllerChatHandler implements ChatHandler
}
/**
* Used by subtasks to post {@link ClusterByStatisticsSnapshot} for shuffling stages.
* Used by subtasks to post {@link PartialKeyStatisticsInformation} for shuffling stages.
*
* See {@link ControllerClient#postKeyStatistics} for the client-side code that calls this API.
* See {@link ControllerClient#postPartialKeyStatistics(StageId, int, PartialKeyStatisticsInformation)}
* for the client-side code that calls this API.
*/
@POST
@Path("/keyStatistics/{queryId}/{stageNumber}/{workerNumber}")
@Path("/partialKeyStatisticsInformation/{queryId}/{stageNumber}/{workerNumber}")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public Response httpPostKeyStatistics(
final Object keyStatisticsObject,
public Response httpPostPartialKeyStatistics(
final Object partialKeyStatisticsObject,
@PathParam("queryId") final String queryId,
@PathParam("stageNumber") final int stageNumber,
@PathParam("workerNumber") final int workerNumber,
@ -75,7 +77,7 @@ public class ControllerChatHandler implements ChatHandler
)
{
ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper());
controller.updateStatus(stageNumber, workerNumber, keyStatisticsObject);
controller.updatePartialKeyStatisticsInformation(stageNumber, workerNumber, partialKeyStatisticsObject);
return Response.status(Response.Status.ACCEPTED).build();
}

View File

@ -29,7 +29,7 @@ import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.exec.ControllerClient;
import org.apache.druid.msq.indexing.error.MSQErrorReport;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
import org.apache.druid.rpc.IgnoreHttpResponseHandler;
import org.apache.druid.rpc.RequestBuilder;
import org.apache.druid.rpc.ServiceClient;
@ -59,14 +59,14 @@ public class IndexerControllerClient implements ControllerClient
}
@Override
public void postKeyStatistics(
public void postPartialKeyStatistics(
StageId stageId,
int workerNumber,
ClusterByStatisticsSnapshot keyStatistics
PartialKeyStatisticsInformation partialKeyStatisticsInformation
) throws IOException
{
final String path = StringUtils.format(
"/keyStatistics/%s/%s/%d",
"/partialKeyStatisticsInformation/%s/%d/%d",
StringUtils.urlEncode(stageId.getQueryId()),
stageId.getStageNumber(),
workerNumber
@ -74,7 +74,7 @@ public class IndexerControllerClient implements ControllerClient
doRequest(
new RequestBuilder(HttpMethod.POST, path)
.jsonContent(jsonMapper, keyStatistics),
.jsonContent(jsonMapper, partialKeyStatisticsInformation),
IgnoreHttpResponseHandler.INSTANCE
);
}

View File

@ -41,6 +41,7 @@ import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.exec.WorkerClient;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.rpc.IgnoreHttpResponseHandler;
import org.apache.druid.rpc.RequestBuilder;
import org.apache.druid.rpc.ServiceClient;
@ -103,6 +104,48 @@ public class IndexerWorkerClient implements WorkerClient
);
}
@Override
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
String workerTaskId,
String queryId,
int stageNumber
)
{
String path = StringUtils.format("/keyStatistics/%s/%d",
StringUtils.urlEncode(queryId),
stageNumber);
return FutureUtils.transform(
getClient(workerTaskId).asyncRequest(
new RequestBuilder(HttpMethod.POST, path),
new BytesFullResponseHandler()
),
holder -> deserialize(holder, new TypeReference<ClusterByStatisticsSnapshot>() {})
);
}
@Override
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk(
String workerTaskId,
String queryId,
int stageNumber,
long timeChunk
)
{
String path = StringUtils.format("/keyStatisticsForTimeChunk/%s/%d/%d",
StringUtils.urlEncode(queryId),
stageNumber,
timeChunk);
return FutureUtils.transform(
getClient(workerTaskId).asyncRequest(
new RequestBuilder(HttpMethod.POST, path),
new BytesFullResponseHandler()
),
holder -> deserialize(holder, new TypeReference<ClusterByStatisticsSnapshot>() {})
);
}
@Override
public ListenableFuture<Void> postResultPartitionBoundaries(
String workerTaskId,

View File

@ -28,6 +28,7 @@ import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.exec.Worker;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.segment.realtime.firehose.ChatHandler;
import org.apache.druid.segment.realtime.firehose.ChatHandlers;
import org.apache.druid.server.security.Action;
@ -179,6 +180,45 @@ public class WorkerChatHandler implements ChatHandler
}
}
@POST
@Path("/keyStatistics/{queryId}/{stageNumber}")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public Response httpFetchKeyStatistics(
@PathParam("queryId") final String queryId,
@PathParam("stageNumber") final int stageNumber,
@Context final HttpServletRequest req
)
{
ChatHandlers.authorizationCheck(req, Action.READ, task.getDataSource(), toolbox.getAuthorizerMapper());
ClusterByStatisticsSnapshot clusterByStatisticsSnapshot;
StageId stageId = new StageId(queryId, stageNumber);
clusterByStatisticsSnapshot = worker.fetchStatisticsSnapshot(stageId);
return Response.status(Response.Status.ACCEPTED)
.entity(clusterByStatisticsSnapshot)
.build();
}
@POST
@Path("/keyStatisticsForTimeChunk/{queryId}/{stageNumber}/{timeChunk}")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public Response httpSketch(
@PathParam("queryId") final String queryId,
@PathParam("stageNumber") final int stageNumber,
@PathParam("timeChunk") final long timeChunk,
@Context final HttpServletRequest req
)
{
ChatHandlers.authorizationCheck(req, Action.READ, task.getDataSource(), toolbox.getAuthorizerMapper());
ClusterByStatisticsSnapshot snapshotForTimeChunk;
StageId stageId = new StageId(queryId, stageNumber);
snapshotForTimeChunk = worker.fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk);
return Response.status(Response.Status.ACCEPTED)
.entity(snapshotForTimeChunk)
.build();
}
/**
* See {@link org.apache.druid.msq.exec.WorkerClient#postCleanupStage} for the client-side code that calls this API.
*/

View File

@ -41,7 +41,8 @@ import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
import javax.annotation.Nullable;
import java.util.HashMap;
@ -65,7 +66,6 @@ import java.util.stream.Collectors;
public class ControllerQueryKernel
{
private final QueryDefinition queryDef;
private final int partitionStatisticsMaxRetainedBytes;
/**
* Stage ID -> tracker for that stage. An extension of the state of this kernel.
@ -107,10 +107,9 @@ public class ControllerQueryKernel
*/
private final Set<StageId> effectivelyFinishedStages = new HashSet<>();
public ControllerQueryKernel(final QueryDefinition queryDef, final int partitionStatisticsMaxRetainedBytes)
public ControllerQueryKernel(final QueryDefinition queryDef)
{
this.queryDef = queryDef;
this.partitionStatisticsMaxRetainedBytes = partitionStatisticsMaxRetainedBytes;
this.inflowMap = ImmutableMap.copyOf(computeStageInflowMap(queryDef));
this.outflowMap = ImmutableMap.copyOf(computeStageOutflowMap(queryDef));
@ -266,8 +265,7 @@ public class ControllerQueryKernel
stageDef,
stageWorkerCountMap,
slicer,
assignmentStrategy,
partitionStatisticsMaxRetainedBytes
assignmentStrategy
);
stageTracker.put(nextStage, stageKernel);
}
@ -334,6 +332,22 @@ public class ControllerQueryKernel
return getStageKernelOrThrow(stageId).getResultPartitionBoundaries();
}
/**
* Delegates call to {@link ControllerStageTracker#getCompleteKeyStatisticsInformation()}
*/
public CompleteKeyStatisticsInformation getCompleteKeyStatisticsInformation(final StageId stageId)
{
return getStageKernelOrThrow(stageId).getCompleteKeyStatisticsInformation();
}
/**
* Delegates call to {@link ControllerStageTracker#setClusterByPartitionBoundaries(ClusterByPartitions)}
*/
public void setClusterByPartitionBoundaries(final StageId stageId, ClusterByPartitions clusterByPartitions)
{
getStageKernelOrThrow(stageId).setClusterByPartitionBoundaries(clusterByPartitions);
}
/**
* Delegates call to {@link ControllerStageTracker#collectorEncounteredAnyMultiValueField()}
*/
@ -390,22 +404,24 @@ public class ControllerQueryKernel
}
/**
* Delegates call to {@link ControllerStageTracker#addResultKeyStatisticsForWorker(int, ClusterByStatisticsSnapshot)}.
* Delegates call to {@link ControllerStageTracker#addPartialKeyStatisticsForWorker(int, PartialKeyStatisticsInformation)}.
* If calling this causes transition for the stage kernel, then this gets registered in this query kernel
*/
public void addResultKeyStatisticsForStageAndWorker(
public void addPartialKeyStatisticsForStageAndWorker(
final StageId stageId,
final int workerNumber,
final ClusterByStatisticsSnapshot snapshot
final PartialKeyStatisticsInformation partialKeyStatisticsInformation
)
{
ControllerStagePhase newPhase = getStageKernelOrThrow(stageId).addResultKeyStatisticsForWorker(
ControllerStageTracker stageKernel = getStageKernelOrThrow(stageId);
ControllerStagePhase newPhase = stageKernel.addPartialKeyStatisticsForWorker(
workerNumber,
snapshot
partialKeyStatisticsInformation
);
// If the phase is POST_READING or FAILED, that implies the kernel has transitioned. We need to account for that
// If the kernel phase has transitioned, we need to account for that.
switch (newPhase) {
case MERGING_STATISTICS:
case POST_READING:
case FAILED:
transitionStageKernel(stageId, newPhase);
@ -436,6 +452,12 @@ public class ControllerQueryKernel
return getStageKernelOrThrow(stageId).getFailureReason();
}
public void failStageForReason(final StageId stageId, MSQFault fault)
{
getStageKernelOrThrow(stageId).failForReason(fault);
transitionStageKernel(stageId, ControllerStagePhase.FAILED);
}
/**
* Delegates call to {@link ControllerStageTracker#fail()} and registers this transition to FAILED in this query kernel
*/

View File

@ -48,6 +48,17 @@ public enum ControllerStagePhase
}
},
// Waiting to fetch key statistics in the background from the workers and incrementally generate partitions.
// This phase is transitioned to only after all partial key statistics information has been received from the workers.
// Transitioning to this phase should also enqueue a task with the WorkerSketchFetcher to fetch key statistics.
MERGING_STATISTICS {
@Override
public boolean canTransitionFrom(final ControllerStagePhase priorPhase)
{
return priorPhase == READING_INPUT;
}
},
// Post the inputs have been read and mapped to frames, in the `POST_READING` stage, we pre-shuffle and determine the partition boundaries.
// This step for a stage spits out the statistics of the data as a whole (and not just the individual records). This
// phase is not required in non-pre shuffle contexts.
@ -55,7 +66,7 @@ public enum ControllerStagePhase
@Override
public boolean canTransitionFrom(final ControllerStagePhase priorPhase)
{
return priorPhase == READING_INPUT;
return priorPhase == MERGING_STATISTICS;
}
},

View File

@ -28,6 +28,7 @@ import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.java.util.common.Either;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.indexing.error.InsertTimeNullFault;
import org.apache.druid.msq.indexing.error.MSQFault;
import org.apache.druid.msq.indexing.error.TooManyPartitionsFault;
import org.apache.druid.msq.indexing.error.UnknownFault;
@ -38,11 +39,12 @@ import org.apache.druid.msq.input.stage.ReadablePartitions;
import org.apache.druid.msq.input.stage.StageInputSlice;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
import javax.annotation.Nullable;
import java.util.List;
import java.util.TreeMap;
/**
* Controller-side state machine for each stage. Used by {@link ControllerQueryKernel} to form the overall state
@ -57,13 +59,13 @@ class ControllerStageTracker
private final int workerCount;
private final WorkerInputs workerInputs;
private final IntSet workersWithResultKeyStatistics = new IntAVLTreeSet();
private final IntSet workersWithReportedKeyStatistics = new IntAVLTreeSet();
private final IntSet workersWithResultsComplete = new IntAVLTreeSet();
private ControllerStagePhase phase = ControllerStagePhase.NEW;
@Nullable
private final ClusterByStatisticsCollector resultKeyStatisticsCollector;
public final CompleteKeyStatisticsInformation completeKeyStatisticsInformation;
// Result partitions and where they can be read from.
@Nullable
@ -81,8 +83,7 @@ class ControllerStageTracker
private ControllerStageTracker(
final StageDefinition stageDef,
final WorkerInputs workerInputs,
final int partitionStatisticsMaxRetainedBytes
final WorkerInputs workerInputs
)
{
this.stageDef = stageDef;
@ -90,11 +91,11 @@ class ControllerStageTracker
this.workerInputs = workerInputs;
if (stageDef.mustGatherResultKeyStatistics()) {
this.resultKeyStatisticsCollector =
stageDef.createResultKeyStatisticsCollector(partitionStatisticsMaxRetainedBytes);
this.completeKeyStatisticsInformation =
new CompleteKeyStatisticsInformation(new TreeMap<>(), false, 0);
} else {
this.resultKeyStatisticsCollector = null;
generateResultPartitionsAndBoundaries();
this.completeKeyStatisticsInformation = null;
generateResultPartitionsAndBoundariesWithoutKeyStatistics();
}
}
@ -107,12 +108,11 @@ class ControllerStageTracker
final StageDefinition stageDef,
final Int2IntMap stageWorkerCountMap,
final InputSpecSlicer slicer,
final WorkerAssignmentStrategy assignmentStrategy,
final int partitionStatisticsMaxRetainedBytes
final WorkerAssignmentStrategy assignmentStrategy
)
{
final WorkerInputs workerInputs = WorkerInputs.create(stageDef, stageWorkerCountMap, slicer, assignmentStrategy);
return new ControllerStageTracker(stageDef, workerInputs, partitionStatisticsMaxRetainedBytes);
return new ControllerStageTracker(stageDef, workerInputs);
}
/**
@ -175,18 +175,12 @@ class ControllerStageTracker
*/
boolean collectorEncounteredAnyMultiValueField()
{
if (resultKeyStatisticsCollector == null) {
if (completeKeyStatisticsInformation == null) {
throw new ISE("Stage does not gather result key statistics");
} else if (resultPartitions == null) {
} else if (workersWithReportedKeyStatistics.size() != workerCount) {
throw new ISE("Result key statistics are not ready");
} else {
for (int i = 0; i < resultKeyStatisticsCollector.getClusterBy().getColumns().size(); i++) {
if (resultKeyStatisticsCollector.hasMultipleValues(i)) {
return true;
}
}
return false;
return completeKeyStatisticsInformation.hasMultipleValues();
}
}
@ -219,10 +213,6 @@ class ControllerStageTracker
*/
void finish()
{
if (resultKeyStatisticsCollector != null) {
resultKeyStatisticsCollector.clear();
}
transitionTo(ControllerStagePhase.FINISHED);
}
@ -234,23 +224,31 @@ class ControllerStageTracker
return workerInputs;
}
/**
* Returns the merged key statistics information.
*/
@Nullable
public CompleteKeyStatisticsInformation getCompleteKeyStatisticsInformation()
{
return completeKeyStatisticsInformation;
}
/**
* Adds partial key statistics information for a particular worker number. If information has already been added for
* this worker, then this call ignores the new information and does nothing.
*
* @param workerNumber the worker
* @param snapshot worker statistics
* @param partialKeyStatisticsInformation partial key statistics
*/
ControllerStagePhase addResultKeyStatisticsForWorker(
ControllerStagePhase addPartialKeyStatisticsForWorker(
final int workerNumber,
final ClusterByStatisticsSnapshot snapshot
final PartialKeyStatisticsInformation partialKeyStatisticsInformation
)
{
if (phase != ControllerStagePhase.READING_INPUT) {
throw new ISE("Cannot add result key statistics from stage [%s]", phase);
}
if (resultKeyStatisticsCollector == null) {
if (!stageDef.mustGatherResultKeyStatistics() || !stageDef.doesShuffle() || completeKeyStatisticsInformation == null) {
throw new ISE("Stage does not gather result key statistics");
}
@ -259,16 +257,21 @@ class ControllerStageTracker
}
try {
if (workersWithResultKeyStatistics.add(workerNumber)) {
resultKeyStatisticsCollector.addAll(snapshot);
if (workersWithReportedKeyStatistics.add(workerNumber)) {
if (workersWithResultKeyStatistics.size() == workerCount) {
generateResultPartitionsAndBoundaries();
if (partialKeyStatisticsInformation.getTimeSegments().contains(null)) {
// The time column should not contain null values.
failForReason(InsertTimeNullFault.instance());
return getPhase();
}
completeKeyStatisticsInformation.mergePartialInformation(workerNumber, partialKeyStatisticsInformation);
if (workersWithReportedKeyStatistics.size() == workerCount) {
// All workers have sent the partial key statistics information.
// Transition to MERGING_STATISTICS state to queue fetching the cluster statistics from the workers.
transitionTo(ControllerStagePhase.MERGING_STATISTICS);
// Phase can become FAILED after generateResultPartitionsAndBoundaries, if there were too many partitions.
if (phase != ControllerStagePhase.FAILED) {
transitionTo(ControllerStagePhase.POST_READING);
}
}
}
}
@ -280,6 +283,33 @@ class ControllerStageTracker
return getPhase();
}
/**
* Sets the {@link #resultPartitions} and {@link #resultPartitionBoundaries} and transitions the phase to POST_READING.
*/
void setClusterByPartitionBoundaries(ClusterByPartitions clusterByPartitions)
{
if (resultPartitions != null) {
throw new ISE("Result partitions have already been generated");
}
if (!stageDef.mustGatherResultKeyStatistics()) {
throw new ISE("Result partitions does not require key statistics, should not have set partition boundries here");
}
if (!ControllerStagePhase.MERGING_STATISTICS.equals(getPhase())) {
throw new ISE("Cannot set partition boundires from key statistics from stage [%s]", getPhase());
}
this.resultPartitionBoundaries = clusterByPartitions;
this.resultPartitions = ReadablePartitions.striped(
stageDef.getStageNumber(),
workerCount,
clusterByPartitions.size()
);
transitionTo(ControllerStagePhase.POST_READING);
}
/**
* Accepts and sets the results that each worker produces for this particular stage
*
@ -339,12 +369,11 @@ class ControllerStageTracker
}
/**
* Sets {@link #resultPartitions} (always) and {@link #resultPartitionBoundaries}.
* Sets {@link #resultPartitions} (always) and {@link #resultPartitionBoundaries} without using key statistics.
*
* If {@link StageDefinition#mustGatherResultKeyStatistics()} is true, this method cannot be called until after
* statistics have been provided to {@link #addResultKeyStatisticsForWorker} for all workers.
* If {@link StageDefinition#mustGatherResultKeyStatistics()} is true, this method should not be called.
*/
private void generateResultPartitionsAndBoundaries()
private void generateResultPartitionsAndBoundariesWithoutKeyStatistics()
{
if (resultPartitions != null) {
throw new ISE("Result partitions have already been generated");
@ -353,12 +382,12 @@ class ControllerStageTracker
final int stageNumber = stageDef.getStageNumber();
if (stageDef.doesShuffle()) {
if (stageDef.mustGatherResultKeyStatistics() && workersWithResultKeyStatistics.size() != workerCount) {
throw new ISE("Cannot generate result partitions without all worker statistics");
if (stageDef.mustGatherResultKeyStatistics()) {
throw new ISE("Cannot generate result partitions without key statistics");
}
final Either<Long, ClusterByPartitions> maybeResultPartitionBoundaries =
stageDef.generatePartitionsForShuffle(resultKeyStatisticsCollector);
stageDef.generatePartitionsForShuffle(null);
if (maybeResultPartitionBoundaries.isError()) {
failForReason(new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
@ -397,15 +426,11 @@ class ControllerStageTracker
*
* @param fault reason why this stage has failed
*/
private void failForReason(final MSQFault fault)
void failForReason(final MSQFault fault)
{
transitionTo(ControllerStagePhase.FAILED);
this.failureReason = fault;
if (resultKeyStatisticsCollector != null) {
resultKeyStatisticsCollector.clear();
}
}
void transitionTo(final ControllerStagePhase newPhase)

View File

@ -35,6 +35,7 @@ import org.apache.druid.segment.column.RowSignature;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
@ -56,7 +57,7 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
private final boolean[] hasMultipleValues;
private final int maxRetainedBytes;
private final long maxRetainedBytes;
private final int maxBuckets;
private long totalRetainedBytes;
@ -64,7 +65,7 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
final ClusterBy clusterBy,
final RowKeyReader keyReader,
final KeyCollectorFactory<?, ?> keyCollectorFactory,
final int maxRetainedBytes,
final long maxRetainedBytes,
final int maxBuckets,
final boolean checkHasMultipleValues
)
@ -86,7 +87,7 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
public static ClusterByStatisticsCollector create(
final ClusterBy clusterBy,
final RowSignature signature,
final int maxRetainedBytes,
final long maxRetainedBytes,
final int maxBuckets,
final boolean aggregate,
final boolean checkHasMultipleValues
@ -167,7 +168,7 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
public ClusterByStatisticsCollector addAll(final ClusterByStatisticsSnapshot snapshot)
{
// Add all key collectors from the other collector.
for (ClusterByStatisticsSnapshot.Bucket otherBucket : snapshot.getBuckets()) {
for (ClusterByStatisticsSnapshot.Bucket otherBucket : snapshot.getBuckets().values()) {
//noinspection rawtypes, unchecked
final KeyCollector<?> otherKeyCollector =
((KeyCollectorFactory) keyCollectorFactory).fromSnapshot(otherBucket.getKeyCollectorSnapshot());
@ -315,13 +316,20 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
{
assertRetainedByteCountsAreTrackedCorrectly();
final List<ClusterByStatisticsSnapshot.Bucket> bucketSnapshots = new ArrayList<>();
final Map<Long, ClusterByStatisticsSnapshot.Bucket> bucketSnapshots = new HashMap<>();
final RowKeyReader trimmedRowReader = keyReader.trimmedKeyReader(clusterBy.getBucketByCount());
for (final Map.Entry<RowKey, BucketHolder> bucketEntry : buckets.entrySet()) {
//noinspection rawtypes, unchecked
final KeyCollectorSnapshot keyCollectorSnapshot =
((KeyCollectorFactory) keyCollectorFactory).toSnapshot(bucketEntry.getValue().keyCollector);
bucketSnapshots.add(new ClusterByStatisticsSnapshot.Bucket(bucketEntry.getKey(), keyCollectorSnapshot));
Long bucketKey = Long.MIN_VALUE;
// If the clusterBy has a bucket-by (time) column, read it from each bucket key and use it as the map key for the snapshot.
if (clusterBy.getBucketByCount() == 1) {
bucketKey = (Long) trimmedRowReader.read(bucketEntry.getKey(), 0);
}
bucketSnapshots.put(bucketKey, new ClusterByStatisticsSnapshot.Bucket(bucketEntry.getKey(), keyCollectorSnapshot, totalRetainedBytes));
}
final IntSet hasMultipleValuesSet;
View File
@ -23,22 +23,23 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.frame.key.RowKey;
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
public class ClusterByStatisticsSnapshot
{
private final List<Bucket> buckets;
private final Map<Long, Bucket> buckets;
private final Set<Integer> hasMultipleValues;
@JsonCreator
ClusterByStatisticsSnapshot(
@JsonProperty("buckets") final List<Bucket> buckets,
@JsonProperty("buckets") final Map<Long, Bucket> buckets,
@JsonProperty("hasMultipleValues") @Nullable final Set<Integer> hasMultipleValues
)
{
@ -48,15 +49,21 @@ public class ClusterByStatisticsSnapshot
public static ClusterByStatisticsSnapshot empty()
{
return new ClusterByStatisticsSnapshot(Collections.emptyList(), null);
return new ClusterByStatisticsSnapshot(Collections.emptyMap(), null);
}
@JsonProperty("buckets")
List<Bucket> getBuckets()
Map<Long, Bucket> getBuckets()
{
return buckets;
}
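/**
 * Returns a new snapshot containing only the bucket for the given time chunk.
 */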
public ClusterByStatisticsSnapshot getSnapshotForTimeChunk(long timeChunk)
{
Bucket bucket = buckets.get(timeChunk);
return new ClusterByStatisticsSnapshot(ImmutableMap.of(timeChunk, bucket), null);
}
@JsonProperty("hasMultipleValues")
@JsonInclude(JsonInclude.Include.NON_EMPTY)
Set<Integer> getHasMultipleValues()
@ -64,6 +71,15 @@ public class ClusterByStatisticsSnapshot
return hasMultipleValues;
}
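/**
 * Summarizes this snapshot as a {@link PartialKeyStatisticsInformation}: the time chunks present, whether any
 * key column has multiple values, and the total bytes retained across buckets.
 */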
public PartialKeyStatisticsInformation partialKeyStatistics()
{
double bytesRetained = 0;
for (ClusterByStatisticsSnapshot.Bucket bucket : buckets.values()) {
bytesRetained += bucket.bytesRetained;
}
return new PartialKeyStatisticsInformation(buckets.keySet(), !getHasMultipleValues().isEmpty(), bytesRetained);
}
@Override
public boolean equals(Object o)
{
@ -86,16 +102,19 @@ public class ClusterByStatisticsSnapshot
static class Bucket
{
private final RowKey bucketKey;
private final double bytesRetained;
private final KeyCollectorSnapshot keyCollectorSnapshot;
@JsonCreator
Bucket(
@JsonProperty("bucketKey") RowKey bucketKey,
@JsonProperty("data") KeyCollectorSnapshot keyCollectorSnapshot
@JsonProperty("data") KeyCollectorSnapshot keyCollectorSnapshot,
@JsonProperty("bytesRetained") double bytesRetained
)
{
this.bucketKey = Preconditions.checkNotNull(bucketKey, "bucketKey");
this.keyCollectorSnapshot = Preconditions.checkNotNull(keyCollectorSnapshot, "data");
this.bytesRetained = bytesRetained;
}
@JsonProperty
View File
@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.statistics;
import com.google.common.collect.ImmutableSortedMap;
import java.util.HashSet;
import java.util.Set;
import java.util.SortedMap;
/**
* Class maintained by the controller to merge the {@link PartialKeyStatisticsInformation} objects sent by the workers.
*/
public class CompleteKeyStatisticsInformation
{
private final SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap;
private boolean multipleValues;
private double bytesRetained;
public CompleteKeyStatisticsInformation(
final SortedMap<Long, Set<Integer>> timeChunks,
boolean multipleValues,
double bytesRetained
)
{
this.timeSegmentVsWorkerMap = timeChunks;
this.multipleValues = multipleValues;
this.bytesRetained = bytesRetained;
}
/**
* Merges a worker's {@link PartialKeyStatisticsInformation} into this complete key statistics object.
* {@link #timeSegmentVsWorkerMap} is updated, in sorted order, with the time chunks from
* {@code partialKeyStatisticsInformation}; {@link #multipleValues} is set to true if
* {@code partialKeyStatisticsInformation} reports multiple values; and the bytes retained by the partial sketch
* are added to {@link #bytesRetained}.
*/
public void mergePartialInformation(int workerNumber, PartialKeyStatisticsInformation partialKeyStatisticsInformation)
{
for (Long timeSegment : partialKeyStatisticsInformation.getTimeSegments()) {
this.timeSegmentVsWorkerMap
.computeIfAbsent(timeSegment, key -> new HashSet<>())
.add(workerNumber);
}
this.multipleValues = this.multipleValues || partialKeyStatisticsInformation.hasMultipleValues();
this.bytesRetained += partialKeyStatisticsInformation.getBytesRetained();
}
public SortedMap<Long, Set<Integer>> getTimeSegmentVsWorkerMap()
{
return ImmutableSortedMap.copyOfSorted(timeSegmentVsWorkerMap);
}
public boolean hasMultipleValues()
{
return multipleValues;
}
public double getBytesRetained()
{
return bytesRetained;
}
}
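To make the bookkeeping above concrete, here is a minimal, hypothetical usage sketch (not part of the patch). The example class name, worker numbers, time chunks, and byte counts are invented for illustration.

```java
import com.google.common.collect.ImmutableSet;
import java.util.TreeMap;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;

public class CompleteKeyStatisticsInformationExample
{
  public static void main(String[] args)
  {
    // Start with an empty view: no time chunks, no multi-value columns, zero retained bytes.
    final CompleteKeyStatisticsInformation complete =
        new CompleteKeyStatisticsInformation(new TreeMap<>(), false, 0);

    // Worker 0 reports sketches for time chunks 1 and 2, retaining roughly 128 bytes.
    complete.mergePartialInformation(
        0,
        new PartialKeyStatisticsInformation(ImmutableSet.of(1L, 2L), false, 128.0)
    );

    // Worker 1 reports a sketch only for time chunk 2 and saw a multi-value key column.
    complete.mergePartialInformation(
        1,
        new PartialKeyStatisticsInformation(ImmutableSet.of(2L), true, 64.0)
    );

    // Time chunks are kept in ascending order, which is the order a sequential merge visits them.
    System.out.println(complete.getTimeSegmentVsWorkerMap()); // {1=[0], 2=[0, 1]}
    System.out.println(complete.hasMultipleValues());         // true
    System.out.println(complete.getBytesRetained());          // 192.0
  }
}
```

As the WorkerSketchFetcher tests later in this diff suggest, sequential merging walks this sorted map in key order and fetches per-time-chunk sketches only from the workers recorded for each chunk.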
View File
@ -22,16 +22,19 @@ package org.apache.druid.msq.statistics;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.java.util.common.ISE;
import javax.annotation.Nullable;
import java.util.Objects;
@JsonTypeName(DelegateOrMinKeyCollectorSnapshot.TYPE)
public class DelegateOrMinKeyCollectorSnapshot<T extends KeyCollectorSnapshot> implements KeyCollectorSnapshot
{
static final String FIELD_SNAPSHOT = "snapshot";
static final String FIELD_MIN_KEY = "minKey";
static final String TYPE = "delegate";
private final T snapshot;
private final RowKey minKey;
View File
@ -22,6 +22,7 @@ package org.apache.druid.msq.statistics;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.base.Preconditions;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.frame.key.RowKey;
@ -31,8 +32,10 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
@JsonTypeName(DistinctKeySnapshot.TYPE)
public class DistinctKeySnapshot implements KeyCollectorSnapshot
{
static final String TYPE = "distinct";
private final List<SerializablePair<RowKey, Long>> keys;
private final int spaceReductionFactor;
View File
@ -19,9 +19,18 @@
package org.apache.druid.msq.statistics;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
/**
* Marker interface for deserialization.
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "collectorType")
@JsonSubTypes(value = {
@JsonSubTypes.Type(name = DelegateOrMinKeyCollectorSnapshot.TYPE, value = DelegateOrMinKeyCollectorSnapshot.class),
@JsonSubTypes.Type(name = QuantilesSketchKeyCollectorSnapshot.TYPE, value = QuantilesSketchKeyCollectorSnapshot.class),
@JsonSubTypes.Type(name = DistinctKeySnapshot.TYPE, value = DistinctKeySnapshot.class),
})
public interface KeyCollectorSnapshot
{
}
View File
@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.statistics;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Set;
/**
* Class sent by the worker to the controller after reading input, used to generate partition boundaries.
*/
public class PartialKeyStatisticsInformation
{
private final Set<Long> timeSegments;
private final boolean multipleValues;
private final double bytesRetained;
@JsonCreator
public PartialKeyStatisticsInformation(
@JsonProperty("timeSegments") Set<Long> timeSegments,
@JsonProperty("multipleValues") boolean hasMultipleValues,
@JsonProperty("bytesRetained") double bytesRetained
)
{
this.timeSegments = timeSegments;
this.multipleValues = hasMultipleValues;
this.bytesRetained = bytesRetained;
}
@JsonProperty("timeSegments")
public Set<Long> getTimeSegments()
{
return timeSegments;
}
@JsonProperty("multipleValues")
public boolean hasMultipleValues()
{
return multipleValues;
}
@JsonProperty("bytesRetained")
public double getBytesRetained()
{
return bytesRetained;
}
}
View File
@ -21,11 +21,14 @@ package org.apache.druid.msq.statistics;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import java.util.Objects;
@JsonTypeName(QuantilesSketchKeyCollectorSnapshot.TYPE)
public class QuantilesSketchKeyCollectorSnapshot implements KeyCollectorSnapshot
{
static final String TYPE = "quantile";
private final String encodedSketch;
private final double averageKeyLength;
View File
@ -25,6 +25,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.opencsv.RFC4180Parser;
import com.opencsv.RFC4180ParserBuilder;
import org.apache.druid.msq.exec.ClusterStatisticsMergeMode;
import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
import org.apache.druid.msq.sql.MSQMode;
import org.apache.druid.query.QueryContext;
@ -58,6 +59,8 @@ public class MultiStageQueryContext
private static final boolean DEFAULT_FINALIZE_AGGREGATIONS = true;
public static final String CTX_ENABLE_DURABLE_SHUFFLE_STORAGE = "durableShuffleStorage";
public static final String CTX_CLUSTER_STATISTICS_MERGE_MODE = "clusterStatisticsMergeMode";
public static final String DEFAULT_CLUSTER_STATISTICS_MERGE_MODE = ClusterStatisticsMergeMode.AUTO.toString();
private static final boolean DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE = false;
public static final String CTX_DESTINATION = "destination";
@ -93,6 +96,18 @@ public class MultiStageQueryContext
);
}
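/**
 * Resolves {@link #CTX_CLUSTER_STATISTICS_MERGE_MODE} from the query context, falling back to
 * {@link ClusterStatisticsMergeMode#AUTO} when the key is not set.
 */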
public static ClusterStatisticsMergeMode getClusterStatisticsMergeMode(QueryContext queryContext)
{
return ClusterStatisticsMergeMode.valueOf(
String.valueOf(
queryContext.getString(
CTX_CLUSTER_STATISTICS_MERGE_MODE,
DEFAULT_CLUSTER_STATISTICS_MERGE_MODE
)
)
);
}
public static boolean isFinalizeAggregations(final QueryContext queryContext)
{
return queryContext.getBoolean(
View File
@ -0,0 +1,139 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.exec;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import java.util.Collections;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
public class WorkerSketchFetcherAutoModeTest
{
@Mock
private CompleteKeyStatisticsInformation completeKeyStatisticsInformation;
@Mock
private StageDefinition stageDefinition;
@Mock
private ClusterBy clusterBy;
private AutoCloseable mocks;
private WorkerSketchFetcher target;
@Before
public void setUp()
{
mocks = MockitoAnnotations.openMocks(this);
target = spy(new WorkerSketchFetcher(mock(WorkerClient.class), ClusterStatisticsMergeMode.AUTO, 300_000_000));
// Don't actually try to fetch sketches
doReturn(null).when(target).inMemoryFullSketchMerging(any(), any());
doReturn(null).when(target).sequentialTimeChunkMerging(any(), any(), any());
doReturn(StageId.fromString("1_1")).when(stageDefinition).getId();
doReturn(clusterBy).when(stageDefinition).getClusterBy();
}
@After
public void tearDown() throws Exception
{
mocks.close();
}
@Test
public void test_submitFetcherTask_belowThresholds_ShouldBeParallel()
{
// Bytes below threshold
doReturn(10.0).when(completeKeyStatisticsInformation).getBytesRetained();
// Cluster by bucket count not 0
doReturn(1).when(clusterBy).getBucketByCount();
// Worker count below threshold
doReturn(1).when(stageDefinition).getMaxWorkerCount();
target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
verify(target, times(1)).inMemoryFullSketchMerging(any(), any());
verify(target, times(0)).sequentialTimeChunkMerging(any(), any(), any());
}
@Test
public void test_submitFetcherTask_workerCountAboveThreshold_shouldBeSequential()
{
// Bytes below threshold
doReturn(10.0).when(completeKeyStatisticsInformation).getBytesRetained();
// Cluster by bucket count not 0
doReturn(1).when(clusterBy).getBucketByCount();
// Worker count above threshold
doReturn((int) WorkerSketchFetcher.WORKER_THRESHOLD + 1).when(stageDefinition).getMaxWorkerCount();
target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
verify(target, times(0)).inMemoryFullSketchMerging(any(), any());
verify(target, times(1)).sequentialTimeChunkMerging(any(), any(), any());
}
@Test
public void test_submitFetcherTask_noClusterByColumns_shouldBeParallel()
{
// Bytes above threshold
doReturn(WorkerSketchFetcher.BYTES_THRESHOLD + 10.0).when(completeKeyStatisticsInformation).getBytesRetained();
// Cluster by bucket count 0
doReturn(ClusterBy.none()).when(stageDefinition).getClusterBy();
// Worker count above threshold
doReturn((int) WorkerSketchFetcher.WORKER_THRESHOLD + 1).when(stageDefinition).getMaxWorkerCount();
target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
verify(target, times(1)).inMemoryFullSketchMerging(any(), any());
verify(target, times(0)).sequentialTimeChunkMerging(any(), any(), any());
}
@Test
public void test_submitFetcherTask_bytesRetainedAboveThreshold_shouldBeSequential()
{
// Bytes above threshold
doReturn(WorkerSketchFetcher.BYTES_THRESHOLD + 10.0).when(completeKeyStatisticsInformation).getBytesRetained();
// Cluster by bucket count not 0
doReturn(1).when(clusterBy).getBucketByCount();
// Worker count below threshold
doReturn(1).when(stageDefinition).getMaxWorkerCount();
target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
verify(target, times(0)).inMemoryFullSketchMerging(any(), any());
verify(target, times(1)).sequentialTimeChunkMerging(any(), any(), any());
}
}
View File
@ -0,0 +1,295 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.exec;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.java.util.common.Either;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.SortedMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import static org.easymock.EasyMock.mock;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
public class WorkerSketchFetcherTest
{
@Mock
private CompleteKeyStatisticsInformation completeKeyStatisticsInformation;
@Mock
private StageDefinition stageDefinition;
@Mock
private ClusterBy clusterBy;
@Mock
private ClusterByStatisticsCollector mergedClusterByStatisticsCollector1;
@Mock
private ClusterByStatisticsCollector mergedClusterByStatisticsCollector2;
@Mock
private WorkerClient workerClient;
private ClusterByPartitions expectedPartitions1;
private ClusterByPartitions expectedPartitions2;
private AutoCloseable mocks;
private WorkerSketchFetcher target;
@Before
public void setUp()
{
mocks = MockitoAnnotations.openMocks(this);
doReturn(StageId.fromString("1_1")).when(stageDefinition).getId();
doReturn(clusterBy).when(stageDefinition).getClusterBy();
doReturn(25_000).when(stageDefinition).getMaxPartitionCount();
expectedPartitions1 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(mock(RowKey.class), mock(RowKey.class))));
expectedPartitions2 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(mock(RowKey.class), mock(RowKey.class))));
doReturn(Either.value(expectedPartitions1)).when(stageDefinition).generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector1));
doReturn(Either.value(expectedPartitions2)).when(stageDefinition).generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector2));
doReturn(
mergedClusterByStatisticsCollector1,
mergedClusterByStatisticsCollector2
).when(stageDefinition).createResultKeyStatisticsCollector(anyInt());
}
@After
public void tearDown() throws Exception
{
mocks.close();
}
@Test
public void test_submitFetcherTask_parallelFetch_workerThrowsException_shouldCancelOtherTasks() throws Exception
{
// Store futures in a queue
final Queue<ListenableFuture<ClusterByStatisticsSnapshot>> futureQueue = new ConcurrentLinkedQueue<>();
final List<String> workerIds = ImmutableList.of("0", "1", "2", "3");
final CountDownLatch latch = new CountDownLatch(workerIds.size());
target = spy(new WorkerSketchFetcher(workerClient, ClusterStatisticsMergeMode.PARALLEL, 300_000_000));
// When fetching snapshots, return a mock and add future to queue
doAnswer(invocation -> {
ListenableFuture<ClusterByStatisticsSnapshot> snapshotListenableFuture =
spy(Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)));
futureQueue.add(snapshotListenableFuture);
latch.countDown();
latch.await();
return snapshotListenableFuture;
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt());
// Cause a worker to fail instead of returning the result
doAnswer(invocation -> {
latch.countDown();
latch.await();
return Futures.immediateFailedFuture(new InterruptedException("interrupted"));
}).when(workerClient).fetchClusterByStatisticsSnapshot(eq("2"), any(), anyInt());
CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
completeKeyStatisticsInformation,
workerIds,
stageDefinition
);
// Assert that the final result has failed and all other task futures are also cancelled.
Assert.assertThrows(CompletionException.class, eitherCompletableFuture::join);
Thread.sleep(1000);
Assert.assertTrue(eitherCompletableFuture.isCompletedExceptionally());
// Verify that the statistics collector was cleared due to the error.
verify(mergedClusterByStatisticsCollector1, times(1)).clear();
// Verify that other task futures were requested to be cancelled.
Assert.assertFalse(futureQueue.isEmpty());
for (ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture : futureQueue) {
verify(snapshotFuture, times(1)).cancel(eq(true));
}
}
@Test
public void test_submitFetcherTask_parallelFetch_mergePerformedCorrectly()
throws ExecutionException, InterruptedException
{
// Store snapshots in a queue
final Queue<ClusterByStatisticsSnapshot> snapshotQueue = new ConcurrentLinkedQueue<>();
final List<String> workerIds = ImmutableList.of("0", "1", "2", "3", "4");
final CountDownLatch latch = new CountDownLatch(workerIds.size());
target = spy(new WorkerSketchFetcher(workerClient, ClusterStatisticsMergeMode.PARALLEL, 300_000_000));
// When fetching snapshots, return a mock and add it to queue
doAnswer(invocation -> {
ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class);
snapshotQueue.add(snapshot);
latch.countDown();
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt());
CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
completeKeyStatisticsInformation,
workerIds,
stageDefinition
);
// Assert that the final result is complete and all other sketches returned have been merged.
eitherCompletableFuture.join();
Thread.sleep(1000);
Assert.assertTrue(eitherCompletableFuture.isDone() && !eitherCompletableFuture.isCompletedExceptionally());
Assert.assertFalse(snapshotQueue.isEmpty());
// Verify that all statistics were added to controller.
for (ClusterByStatisticsSnapshot snapshot : snapshotQueue) {
verify(mergedClusterByStatisticsCollector1, times(1)).addAll(eq(snapshot));
}
// Check that the partitions returned by the merged collector are returned by the final future.
Assert.assertEquals(expectedPartitions1, eitherCompletableFuture.get().valueOrThrow());
}
@Test
public void test_submitFetcherTask_sequentialFetch_workerThrowsException_shouldCancelOtherTasks() throws Exception
{
// Store futures in a queue
final Queue<ListenableFuture<ClusterByStatisticsSnapshot>> futureQueue = new ConcurrentLinkedQueue<>();
SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap = ImmutableSortedMap.of(1L, ImmutableSet.of(0, 1, 2), 2L, ImmutableSet.of(0, 1, 4));
doReturn(timeSegmentVsWorkerMap).when(completeKeyStatisticsInformation).getTimeSegmentVsWorkerMap();
final CyclicBarrier barrier = new CyclicBarrier(3);
target = spy(new WorkerSketchFetcher(workerClient, ClusterStatisticsMergeMode.SEQUENTIAL, 300_000_000));
// When fetching snapshots, return a mock and add future to queue
doAnswer(invocation -> {
ListenableFuture<ClusterByStatisticsSnapshot> snapshotListenableFuture =
spy(Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)));
futureQueue.add(snapshotListenableFuture);
barrier.await();
return snapshotListenableFuture;
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(anyString(), anyString(), anyInt(), anyLong());
// Cause a worker in the second time chunk to fail instead of returning the result
doAnswer(invocation -> {
barrier.await();
return Futures.immediateFailedFuture(new InterruptedException("interrupted"));
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(eq("4"), any(), anyInt(), eq(2L));
CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
completeKeyStatisticsInformation,
ImmutableList.of("0", "1", "2", "3", "4"),
stageDefinition
);
// Assert that the final result has failed and all other task futures are also cancelled.
Assert.assertThrows(CompletionException.class, eitherCompletableFuture::join);
Thread.sleep(1000);
Assert.assertTrue(eitherCompletableFuture.isCompletedExceptionally());
// Verify that the correct statistics collector was cleared due to the error.
verify(mergedClusterByStatisticsCollector1, times(0)).clear();
verify(mergedClusterByStatisticsCollector2, times(1)).clear();
// Verify that other task futures were requested to be cancelled.
Assert.assertFalse(futureQueue.isEmpty());
for (ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture : futureQueue) {
verify(snapshotFuture, times(1)).cancel(eq(true));
}
}
@Test
public void test_submitFetcherTask_sequentialFetch_mergePerformedCorrectly()
throws ExecutionException, InterruptedException
{
// Store snapshots in a queue
final Queue<ClusterByStatisticsSnapshot> snapshotQueue = new ConcurrentLinkedQueue<>();
SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap = ImmutableSortedMap.of(1L, ImmutableSet.of(0, 1, 2), 2L, ImmutableSet.of(0, 1, 4));
doReturn(timeSegmentVsWorkerMap).when(completeKeyStatisticsInformation).getTimeSegmentVsWorkerMap();
final CyclicBarrier barrier = new CyclicBarrier(3);
target = spy(new WorkerSketchFetcher(workerClient, ClusterStatisticsMergeMode.SEQUENTIAL, 300_000_000));
// When fetching snapshots, return a mock and add it to queue
doAnswer(invocation -> {
ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class);
snapshotQueue.add(snapshot);
barrier.await();
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyInt(), anyLong());
CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
completeKeyStatisticsInformation,
ImmutableList.of("0", "1", "2", "3", "4"),
stageDefinition
);
// Assert that the final result is complete and all other sketches returned have been merged.
eitherCompletableFuture.join();
Thread.sleep(1000);
Assert.assertTrue(eitherCompletableFuture.isDone() && !eitherCompletableFuture.isCompletedExceptionally());
Assert.assertFalse(snapshotQueue.isEmpty());
// Verify that all statistics were added to controller.
snapshotQueue.stream().limit(3).forEach(snapshot -> {
verify(mergedClusterByStatisticsCollector1, times(1)).addAll(eq(snapshot));
});
snapshotQueue.stream().skip(3).limit(3).forEach(snapshot -> {
verify(mergedClusterByStatisticsCollector2, times(1)).addAll(eq(snapshot));
});
ClusterByPartitions expectedResult =
new ClusterByPartitions(
ImmutableList.of(
new ClusterByPartition(expectedPartitions1.get(0).getStart(), expectedPartitions2.get(0).getStart()),
new ClusterByPartition(expectedPartitions2.get(0).getStart(), expectedPartitions2.get(0).getEnd())
)
);
// Check that the partitions returned by the merged collector are returned by the final future.
Assert.assertEquals(expectedResult, eitherCompletableFuture.get().valueOrThrow());
}
}
View File
@ -21,6 +21,7 @@ package org.apache.druid.msq.kernel.controller;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.key.KeyTestUtils;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.java.util.common.IAE;
@ -31,6 +32,7 @@ import org.apache.druid.msq.input.MapInputSpecSlicer;
import org.apache.druid.msq.input.stage.StageInputSpec;
import org.apache.druid.msq.input.stage.StageInputSpecSlicer;
import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
@ -80,7 +82,7 @@ public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest
public ControllerQueryKernelTester queryDefinition(QueryDefinition queryDefinition)
{
this.queryDefinition = Preconditions.checkNotNull(queryDefinition);
this.controllerQueryKernel = new ControllerQueryKernel(queryDefinition, 10_000_000);
this.controllerQueryKernel = new ControllerQueryKernel(queryDefinition);
return this;
}
@ -121,10 +123,10 @@ public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest
if (queryDefinition.getStageDefinition(stageNumber).mustGatherResultKeyStatistics()) {
for (int i = 0; i < numWorkers; ++i) {
controllerQueryKernel.addResultKeyStatisticsForStageAndWorker(
controllerQueryKernel.addPartialKeyStatisticsForStageAndWorker(
new StageId(queryDefinition.getQueryId(), stageNumber),
i,
ClusterByStatisticsSnapshot.empty()
ClusterByStatisticsSnapshot.empty().partialKeyStatistics()
);
}
} else {
@ -238,7 +240,7 @@ public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest
controllerQueryKernel.finishStage(new StageId(queryDefinition.getQueryId(), stageNumber), strict);
}
public void addResultKeyStatisticsForStageAndWorker(int stageNumber, int workerNumber)
public ClusterByStatisticsCollector addResultKeyStatisticsForStageAndWorker(int stageNumber, int workerNumber)
{
Preconditions.checkArgument(initialized);
@ -254,11 +256,12 @@ public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest
keyStatsCollector.add(key, 1);
}
controllerQueryKernel.addResultKeyStatisticsForStageAndWorker(
controllerQueryKernel.addPartialKeyStatisticsForStageAndWorker(
new StageId(queryDefinition.getQueryId(), stageNumber),
workerNumber,
keyStatsCollector.snapshot()
keyStatsCollector.snapshot().partialKeyStatistics()
);
return keyStatsCollector;
}
public void setResultsCompleteForStageAndWorker(int stageNumber, int workerNumber)
@ -271,6 +274,18 @@ public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest
);
}
public void setPartitionBoundaries(int stageNumber, ClusterByStatisticsCollector clusterByStatisticsCollector)
{
Preconditions.checkArgument(initialized);
StageId stageId = new StageId(queryDefinition.getQueryId(), stageNumber);
StageDefinition stageDefinition = controllerQueryKernel.getStageDefinition(stageId);
ClusterByPartitions clusterByPartitions =
stageDefinition
.generatePartitionsForShuffle(clusterByStatisticsCollector)
.valueOrThrow();
controllerQueryKernel.setClusterByPartitionBoundaries(stageId, clusterByPartitions);
}
public void failStage(int stageNumber)
{
Preconditions.checkArgument(initialized);
View File
@ -20,6 +20,7 @@
package org.apache.druid.msq.kernel.controller;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.junit.Assert;
import org.junit.Test;
@ -146,8 +147,13 @@ public class ControllerQueryKernelTests extends BaseControllerQueryKernelTest
Assert.assertEquals(ImmutableSet.of(0), newStageNumbers);
Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers);
controllerQueryKernelTester.startStage(0);
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.POST_READING);
ClusterByStatisticsCollector clusterByStatisticsCollector =
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(
0,
0
);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.MERGING_STATISTICS);
controllerQueryKernelTester.setPartitionBoundaries(0, clusterByStatisticsCollector);
controllerQueryKernelTester.setResultsCompleteForStageAndWorker(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RESULTS_READY);
@ -156,9 +162,20 @@ public class ControllerQueryKernelTests extends BaseControllerQueryKernelTest
Assert.assertEquals(ImmutableSet.of(1), newStageNumbers);
Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers);
controllerQueryKernelTester.startStage(1);
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(1, 0);
clusterByStatisticsCollector =
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(
1,
0
);
controllerQueryKernelTester.assertStagePhase(1, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(1, 1);
clusterByStatisticsCollector.addAll(
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(
1,
1
)
);
controllerQueryKernelTester.assertStagePhase(1, ControllerStagePhase.MERGING_STATISTICS);
controllerQueryKernelTester.setPartitionBoundaries(1, clusterByStatisticsCollector);
controllerQueryKernelTester.assertStagePhase(1, ControllerStagePhase.POST_READING);
controllerQueryKernelTester.setResultsCompleteForStageAndWorker(1, 0);
controllerQueryKernelTester.assertStagePhase(1, ControllerStagePhase.POST_READING);
@ -182,9 +199,19 @@ public class ControllerQueryKernelTests extends BaseControllerQueryKernelTest
Assert.assertEquals(ImmutableSet.of(1), effectivelyFinishedStageNumbers);
controllerQueryKernelTester.startStage(3);
controllerQueryKernelTester.assertStagePhase(3, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(3, 0);
ClusterByStatisticsCollector clusterByStatisticsCollector3 =
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(
3,
0
);
controllerQueryKernelTester.assertStagePhase(3, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(3, 1);
ClusterByStatisticsCollector clusterByStatisticsCollector4 =
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(
3,
1
);
controllerQueryKernelTester.assertStagePhase(3, ControllerStagePhase.MERGING_STATISTICS);
controllerQueryKernelTester.setPartitionBoundaries(3, clusterByStatisticsCollector3.addAll(clusterByStatisticsCollector4));
controllerQueryKernelTester.assertStagePhase(3, ControllerStagePhase.POST_READING);
controllerQueryKernelTester.setResultsCompleteForStageAndWorker(3, 0);
controllerQueryKernelTester.assertStagePhase(3, ControllerStagePhase.POST_READING);
@ -217,11 +244,21 @@ public class ControllerQueryKernelTests extends BaseControllerQueryKernelTest
controllerQueryKernelTester.createAndGetNewStageNumbers();
controllerQueryKernelTester.startStage(0);
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(0, 0);
ClusterByStatisticsCollector clusterByStatisticsCollector =
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(
0,
0
);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.POST_READING);
clusterByStatisticsCollector.addAll(
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(
0,
1
)
);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.MERGING_STATISTICS);
controllerQueryKernelTester.setPartitionBoundaries(0, clusterByStatisticsCollector);
controllerQueryKernelTester.setResultsCompleteForStageAndWorker(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.POST_READING);
View File
@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.statistics;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.msq.guice.MSQIndexingModule;
import org.apache.druid.segment.TestHelper;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
public class PartialKeyStatisticsInformationSerdeTest
{
private ObjectMapper objectMapper;
@Before
public void setUp()
{
objectMapper = TestHelper.makeJsonMapper();
objectMapper.registerModules(new MSQIndexingModule().getJacksonModules());
objectMapper.enable(JsonParser.Feature.STRICT_DUPLICATE_DETECTION);
}
@Test
public void testSerde() throws JsonProcessingException
{
PartialKeyStatisticsInformation partialInformation = new PartialKeyStatisticsInformation(
ImmutableSet.of(2L, 3L),
false,
0.0
);
final String json = objectMapper.writeValueAsString(partialInformation);
final PartialKeyStatisticsInformation deserializedKeyStatistics = objectMapper.readValue(
json,
PartialKeyStatisticsInformation.class
);
Assert.assertEquals(json, partialInformation.getTimeSegments(), deserializedKeyStatistics.getTimeSegments());
Assert.assertEquals(json, partialInformation.hasMultipleValues(), deserializedKeyStatistics.hasMultipleValues());
Assert.assertEquals(json, partialInformation.getBytesRetained(), deserializedKeyStatistics.getBytesRetained(), 0);
}
}
View File
@ -25,7 +25,7 @@ import org.apache.druid.msq.exec.Controller;
import org.apache.druid.msq.exec.ControllerClient;
import org.apache.druid.msq.indexing.error.MSQErrorReport;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
import javax.annotation.Nullable;
import java.util.List;
@ -40,17 +40,17 @@ public class MSQTestControllerClient implements ControllerClient
}
@Override
public void postKeyStatistics(
public void postPartialKeyStatistics(
StageId stageId,
int workerNumber,
ClusterByStatisticsSnapshot keyStatistics
PartialKeyStatisticsInformation partialKeyStatisticsInformation
)
{
try {
controller.updateStatus(stageId.getStageNumber(), workerNumber, keyStatistics);
controller.updatePartialKeyStatisticsInformation(stageId.getStageNumber(), workerNumber, partialKeyStatisticsInformation);
}
catch (Exception e) {
throw new ISE(e, "unable to post key statistics");
throw new ISE(e, "unable to post partial key statistics");
}
}
View File
@ -29,6 +29,7 @@ import org.apache.druid.msq.exec.Worker;
import org.apache.druid.msq.exec.WorkerClient;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import java.io.InputStream;
import java.util.Arrays;
@ -50,6 +51,29 @@ public class MSQTestWorkerClient implements WorkerClient
return Futures.immediateFuture(null);
}
@Override
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
String workerTaskId,
String queryId,
int stageNumber
)
{
StageId stageId = new StageId(queryId, stageNumber);
return Futures.immediateFuture(inMemoryWorkers.get(workerTaskId).fetchStatisticsSnapshot(stageId));
}
@Override
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk(
String workerTaskId,
String queryId,
int stageNumber,
long timeChunk
)
{
StageId stageId = new StageId(queryId, stageNumber);
return Futures.immediateFuture(inMemoryWorkers.get(workerTaskId).fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk));
}
@Override
public ListenableFuture<Void> postResultPartitionBoundaries(
String workerTaskId,
View File
@ -71,7 +71,7 @@ Start the cluster:
```bash
cd $DRUID_DEV/integration-tests-ex/cases
./cluster.sh <category> up
./cluster.sh up <category>
```
Where `<category>` is one of the test categories. Then launch the
View File
@ -178,6 +178,11 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-sql</artifactId>
<version>25.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.druid.extensions</groupId>
<artifactId>druid-multi-stage-query</artifactId>
View File
@ -0,0 +1,206 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.testsEx.msq;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableMap;
import com.google.inject.Inject;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.msq.exec.ClusterStatisticsMergeMode;
import org.apache.druid.msq.sql.SqlTaskStatus;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.sql.http.SqlQuery;
import org.apache.druid.testing.IntegrationTestingConfig;
import org.apache.druid.testing.clients.CoordinatorResourceTestClient;
import org.apache.druid.testing.clients.SqlResourceTestClient;
import org.apache.druid.testing.utils.DataLoaderHelper;
import org.apache.druid.testing.utils.MsqTestQueryHelper;
import org.apache.druid.testsEx.categories.MultiStageQuery;
import org.apache.druid.testsEx.config.DruidTestRunner;
import org.junit.Assert;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
@RunWith(DruidTestRunner.class)
@Category(MultiStageQuery.class)
public class ITKeyStatisticsSketchMergeMode
{
@Inject
private MsqTestQueryHelper msqHelper;
@Inject
private SqlResourceTestClient msqClient;
@Inject
private IntegrationTestingConfig config;
@Inject
private ObjectMapper jsonMapper;
@Inject
private DataLoaderHelper dataLoaderHelper;
@Inject
private CoordinatorResourceTestClient coordinatorClient;
private static final String QUERY_FILE = "/multi-stage-query/wikipedia_msq_select_query1.json";
@Test
public void testMsqIngestionParallelMerging() throws Exception
{
String datasource = "dst";
// Clean up the datasource from previous runs
coordinatorClient.unloadSegmentsForDataSource(datasource);
String queryLocal =
StringUtils.format(
"INSERT INTO %s\n"
+ "SELECT\n"
+ " TIME_PARSE(\"timestamp\") AS __time,\n"
+ " isRobot,\n"
+ " diffUrl,\n"
+ " added,\n"
+ " countryIsoCode,\n"
+ " regionName,\n"
+ " channel,\n"
+ " flags,\n"
+ " delta,\n"
+ " isUnpatrolled,\n"
+ " isNew,\n"
+ " deltaBucket,\n"
+ " isMinor,\n"
+ " isAnonymous,\n"
+ " deleted,\n"
+ " cityName,\n"
+ " metroCode,\n"
+ " namespace,\n"
+ " comment,\n"
+ " page,\n"
+ " commentLength,\n"
+ " countryName,\n"
+ " user,\n"
+ " regionIsoCode\n"
+ "FROM TABLE(\n"
+ " EXTERN(\n"
+ " '{\"type\":\"local\",\"files\":[\"/resources/data/batch_index/json/wikipedia_index_data1.json\"]}',\n"
+ " '{\"type\":\"json\"}',\n"
+ " '[{\"type\":\"string\",\"name\":\"timestamp\"},{\"type\":\"string\",\"name\":\"isRobot\"},{\"type\":\"string\",\"name\":\"diffUrl\"},{\"type\":\"long\",\"name\":\"added\"},{\"type\":\"string\",\"name\":\"countryIsoCode\"},{\"type\":\"string\",\"name\":\"regionName\"},{\"type\":\"string\",\"name\":\"channel\"},{\"type\":\"string\",\"name\":\"flags\"},{\"type\":\"long\",\"name\":\"delta\"},{\"type\":\"string\",\"name\":\"isUnpatrolled\"},{\"type\":\"string\",\"name\":\"isNew\"},{\"type\":\"double\",\"name\":\"deltaBucket\"},{\"type\":\"string\",\"name\":\"isMinor\"},{\"type\":\"string\",\"name\":\"isAnonymous\"},{\"type\":\"long\",\"name\":\"deleted\"},{\"type\":\"string\",\"name\":\"cityName\"},{\"type\":\"long\",\"name\":\"metroCode\"},{\"type\":\"string\",\"name\":\"namespace\"},{\"type\":\"string\",\"name\":\"comment\"},{\"type\":\"string\",\"name\":\"page\"},{\"type\":\"long\",\"name\":\"commentLength\"},{\"type\":\"string\",\"name\":\"countryName\"},{\"type\":\"string\",\"name\":\"user\"},{\"type\":\"string\",\"name\":\"regionIsoCode\"}]'\n"
+ " )\n"
+ ")\n"
+ "PARTITIONED BY DAY\n"
+ "CLUSTERED BY \"__time\"",
datasource
);
ImmutableMap<String, Object> context = ImmutableMap.of(
MultiStageQueryContext.CTX_CLUSTER_STATISTICS_MERGE_MODE,
ClusterStatisticsMergeMode.PARALLEL
);
// Submit the task and wait for the datasource to get loaded
SqlQuery sqlQuery = new SqlQuery(queryLocal, null, false, false, false, context, null);
SqlTaskStatus sqlTaskStatus = msqHelper.submitMsqTask(sqlQuery);
if (sqlTaskStatus.getState().isFailure()) {
Assert.fail(StringUtils.format(
"Unable to start the task successfully.\nPossible exception: %s",
sqlTaskStatus.getError()
));
}
msqHelper.pollTaskIdForCompletion(sqlTaskStatus.getTaskId());
dataLoaderHelper.waitUntilDatasourceIsReady(datasource);
msqHelper.testQueriesFromFile(QUERY_FILE, datasource);
}
@Test
public void testMsqIngestionSequentialMerging() throws Exception
{
String datasource = "dst";
// Clean up the datasource from previous runs
coordinatorClient.unloadSegmentsForDataSource(datasource);
String queryLocal =
StringUtils.format(
"INSERT INTO %s\n"
+ "SELECT\n"
+ " TIME_PARSE(\"timestamp\") AS __time,\n"
+ " isRobot,\n"
+ " diffUrl,\n"
+ " added,\n"
+ " countryIsoCode,\n"
+ " regionName,\n"
+ " channel,\n"
+ " flags,\n"
+ " delta,\n"
+ " isUnpatrolled,\n"
+ " isNew,\n"
+ " deltaBucket,\n"
+ " isMinor,\n"
+ " isAnonymous,\n"
+ " deleted,\n"
+ " cityName,\n"
+ " metroCode,\n"
+ " namespace,\n"
+ " comment,\n"
+ " page,\n"
+ " commentLength,\n"
+ " countryName,\n"
+ " user,\n"
+ " regionIsoCode\n"
+ "FROM TABLE(\n"
+ " EXTERN(\n"
+ " '{\"type\":\"local\",\"files\":[\"/resources/data/batch_index/json/wikipedia_index_data1.json\"]}',\n"
+ " '{\"type\":\"json\"}',\n"
+ " '[{\"type\":\"string\",\"name\":\"timestamp\"},{\"type\":\"string\",\"name\":\"isRobot\"},{\"type\":\"string\",\"name\":\"diffUrl\"},{\"type\":\"long\",\"name\":\"added\"},{\"type\":\"string\",\"name\":\"countryIsoCode\"},{\"type\":\"string\",\"name\":\"regionName\"},{\"type\":\"string\",\"name\":\"channel\"},{\"type\":\"string\",\"name\":\"flags\"},{\"type\":\"long\",\"name\":\"delta\"},{\"type\":\"string\",\"name\":\"isUnpatrolled\"},{\"type\":\"string\",\"name\":\"isNew\"},{\"type\":\"double\",\"name\":\"deltaBucket\"},{\"type\":\"string\",\"name\":\"isMinor\"},{\"type\":\"string\",\"name\":\"isAnonymous\"},{\"type\":\"long\",\"name\":\"deleted\"},{\"type\":\"string\",\"name\":\"cityName\"},{\"type\":\"long\",\"name\":\"metroCode\"},{\"type\":\"string\",\"name\":\"namespace\"},{\"type\":\"string\",\"name\":\"comment\"},{\"type\":\"string\",\"name\":\"page\"},{\"type\":\"long\",\"name\":\"commentLength\"},{\"type\":\"string\",\"name\":\"countryName\"},{\"type\":\"string\",\"name\":\"user\"},{\"type\":\"string\",\"name\":\"regionIsoCode\"}]'\n"
+ " )\n"
+ ")\n"
+ "PARTITIONED BY DAY\n"
+ "CLUSTERED BY \"__time\"",
datasource
);
ImmutableMap<String, Object> context = ImmutableMap.of(
MultiStageQueryContext.CTX_CLUSTER_STATISTICS_MERGE_MODE,
ClusterStatisticsMergeMode.SEQUENTIAL
);
// Submit the task and wait for the datasource to get loaded
SqlQuery sqlQuery = new SqlQuery(queryLocal, null, false, false, false, context, null);
SqlTaskStatus sqlTaskStatus = msqHelper.submitMsqTask(sqlQuery);
if (sqlTaskStatus.getState().isFailure()) {
Assert.fail(StringUtils.format(
"Unable to start the task successfully.\nPossible exception: %s",
sqlTaskStatus.getError()
));
}
msqHelper.pollTaskIdForCompletion(sqlTaskStatus.getTaskId());
dataLoaderHelper.waitUntilDatasourceIsReady(datasource);
msqHelper.testQueriesFromFile(QUERY_FILE, datasource);
}
}
View File
@ -130,6 +130,28 @@ public class RowKeyReader
}
}
/**
* Returns a key reader trimmed to a particular fieldCount. Used to read keys trimmed by {@link #trim(RowKey, int)}.
*/
public RowKeyReader trimmedKeyReader(int trimmedFieldCount)
{
final RowSignature.Builder newSignature = RowSignature.builder();
if (trimmedFieldCount > signature.size()) {
throw new IAE("Cannot trim to [%,d] fields, only have [%,d] fields", trimmedFieldCount, signature.size());
}
for (int i = 0; i < trimmedFieldCount; i++) {
final String columnName = signature.getColumnName(i);
final ColumnType columnType =
Preconditions.checkNotNull(signature.getColumnType(i).orElse(null), "Type for column [%s]", columnName);
newSignature.add(columnName, columnType);
}
return RowKeyReader.create(newSignature.build());
}
/**
* Trim a key to a particular fieldCount. The returned key may be a copy, but is not guaranteed to be.
*/
View File
@ -30,6 +30,7 @@ import org.junit.Test;
import org.junit.internal.matchers.ThrowableMessageMatcher;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.IntStream;
@ -144,4 +145,41 @@ public class RowKeyReaderTest extends InitializedNullHandlingTest
MatcherAssert.assertThat(e, ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString("Cannot trim")));
}
@Test
public void test_trimmedKeyReader_zero()
{
RowKey trimmedKey = keyReader.trim(key, 0);
RowKeyReader trimmedKeyReader = keyReader.trimmedKeyReader(0);
Assert.assertEquals(
Collections.emptyList(),
trimmedKeyReader.read(trimmedKey)
);
}
@Test
public void test_trimmedKeyReader_one()
{
RowKey trimmedKey = keyReader.trim(key, 1);
RowKeyReader trimmedKeyReader = keyReader.trimmedKeyReader(1);
Assert.assertEquals(
objects.subList(0, 1),
trimmedKeyReader.read(trimmedKey)
);
}
@Test
public void test_trimmedKeyReader_oneLessThanFullLength()
{
final int numFields = signature.size() - 1;
RowKey trimmedKey = keyReader.trim(key, numFields);
RowKeyReader trimmedKeyReader = keyReader.trimmedKeyReader(numFields);
Assert.assertEquals(
objects.subList(0, numFields),
trimmedKeyReader.read(trimmedKey)
);
}
}
View File
@ -68,6 +68,9 @@ Double.NEGATIVE_INFINITY
Double.NEGATIVE_INFINITY.
Double.POSITIVE_INFINITY
Double.POSITIVE_INFINITY.
downsampled
downsamples
downsampling
Dropwizard
dropwizard
DruidInputSource