Worker retry for MSQ task (#13353)

* Initial commit.

* Fixing error message in retry exceeded exception

* Cleaning up some code

* Adding some test cases.

* Adding java docs.

* Finishing up state test cases.

* Adding some more java docs and fixing spot bugs, intellij inspections

* Fixing intellij inspections and added tests

* Documenting error codes

* Migrate current integration batch tests to equivalent MSQ tests (#13374)

* Migrate current integration batch tests to equivalent MSQ tests using new IT framework

* Fix build issues

* Trigger Build

* Adding more tests and addressing comments

* fixBuildIssues

* fix dependency issues

* Parameterized the test and addressed comments

* Addressing comments

* fixing checkstyle errors

* Adressing comments

* Adding ITTest which kills the worker abruptly

* Review comments phase one

* Adding doc changes

* Adjusting for single threaded execution.

* Adding Sequential Merge PR state handling

* Merge things

* Fixing checkstyle.

* Adding new context param for fault tolerance.
Adding stale task handling in sketchFetcher.
Adding UT's.

* Merge things

* Merge things

* Adding parameterized tests
Created separate module for faultToleranceTests

* Adding missed files

* Review comments and fixing tests.

* Documentation things.

* Fixing IT

* Controller impl fix.

* Fixing racy WorkerSketchFetcherTest.java exception handling.

Co-authored-by: abhagraw <99210446+abhagraw@users.noreply.github.com>
Co-authored-by: Karan Kumar <cryptoe@karans-mbp.lan>
This commit is contained in:
Karan Kumar 2023-01-11 07:38:29 +05:30 committed by GitHub
parent 17936e2920
commit 56076d33fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
70 changed files with 5609 additions and 1050 deletions

View File

@ -525,6 +525,16 @@ jobs:
env: JVM_RUNTIME='-Djvm.runtime=8' USE_INDEXER='middleManager'
script: ./it.sh travis MultiStageQuery
- &integration_tests_ex
name: "(Compile=openjdk8, Run=openjdk8) multi stage query tests with MM"
stage: Tests - phase 2
jdk: openjdk8
services: *integration_test_services
env: JVM_RUNTIME='-Djvm.runtime=8' USE_INDEXER='middleManager'
script: ./it.sh travis MultiStageQueryWithMM
- &integration_tests_ex
name: "(Compile=openjdk8, Run=openjdk8) catalog integration tests"
stage: Tests - phase 2

View File

@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.java.util.common.function;
import java.util.Objects;
/**
* Based on {@link java.util.function.BiConsumer}
*/
@FunctionalInterface
public interface TriConsumer<T, U, V>
{
/**
* Performs this operation on the given arguments.
*
* @param t the first input argument
* @param u the second input argument
* @param v the third input argument
*/
void accept(T t, U u, V v);
/**
* Returns a composed {@code TriConsumer} that performs, in sequence, this
* operation followed by the {@code after} operation. If performing either
* operation throws an exception, it is relayed to the caller of the
* composed operation. If performing this operation throws an exception,
* the {@code after} operation will not be performed.
*
* @param after the operation to perform after this operation
* @return a composed {@code TriConsumer} that performs in sequence this
* operation followed by the {@code after} operation
* @throws NullPointerException if {@code after} is null
*/
default TriConsumer<T, U, V> andThen(TriConsumer<? super T, ? super U, ? super V> after)
{
Objects.requireNonNull(after);
return (t, u, v) -> {
accept(t, u, v);
after.accept(t, u, v);
};
}
}

View File

@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.java.util.common.function;
import org.junit.Assert;
import org.junit.Test;
import java.util.HashSet;
import java.util.Set;
public class TriConsumerTest
{
@Test
public void sanityTest()
{
Set<Integer> sumSet = new HashSet<>();
TriConsumer<Integer, Integer, Integer> consumerA = (arg1, arg2, arg3) -> {
sumSet.add(arg1 + arg2 + arg3);
};
TriConsumer<Integer, Integer, Integer> consumerB = (arg1, arg2, arg3) -> {
sumSet.remove(arg1 + arg2 + arg3);
};
consumerA.andThen(consumerB).accept(1, 2, 3);
Assert.assertTrue(sumSet.isEmpty());
}
}

View File

@ -29,7 +29,7 @@ sidebar_label: Known issues
## Multi-stage query task runtime
- Fault tolerance is not implemented. If any task fails, the entire query fails.
- Fault tolerance is partially implemented. Workers get relaunched when they are killed unexpectedly. The controller does not get relaunched if it is killed unexpectedly.
- Worker task stage outputs are stored in the working directory given by `druid.indexer.task.baseDir`. Stages that
generate a large amount of output data may exhaust all available disk space. In this case, the query fails with

View File

@ -322,6 +322,8 @@ The following table lists the context parameters for the MSQ task engine:
| `rowsPerSegment` | INSERT or REPLACE<br /><br />The number of rows per segment to target. The actual number of rows per segment may be somewhat higher or lower than this number. In most cases, use the default. For general information about sizing rows per segment, see [Segment Size Optimization](../operations/segment-optimization.md). | 3,000,000 |
| `indexSpec` | INSERT or REPLACE<br /><br />An [`indexSpec`](../ingestion/ingestion-spec.md#indexspec) to use when generating segments. May be a JSON string or object. See [Front coding](../ingestion/ingestion-spec.md#front-coding) for details on configuring an `indexSpec` with front coding. | See [`indexSpec`](../ingestion/ingestion-spec.md#indexspec). |
| `clusterStatisticsMergeMode` | Whether to use parallel or sequential mode for merging of the worker sketches. Can be `PARALLEL`, `SEQUENTIAL` or `AUTO`. See [Sketch Merging Mode](#sketch-merging-mode) for more information. | `PARALLEL` |
| `durableShuffleStorage` | SELECT, INSERT, REPLACE <br /><br />Whether to use durable storage for shuffle mesh. To use this feature, configure the durable storage at the server level using `druid.msq.intermediate.storage.enable=true`). If these properties are not configured, any query with the context variable `durableShuffleStorage=true` fails with a configuration error. <br /><br /> | `false` |
| `faultTolerance` | SELECT, INSERT, REPLACE<br /><br /> Whether to turn on fault tolerance mode or not. Failed workers are retried based on [Limits](#limits). Cannot be used when `durableShuffleStorage` is explicitly set to false. | `false` |
## Sketch Merging Mode
This section details the advantages and performance of various Cluster By Statistics Merge Modes.
@ -332,17 +334,16 @@ reading rows from the datasource. These statistics must be transferred to the co
`PARALLEL` mode fetches the key statistics for all time chunks from all workers together and the controller then downsamples
the sketch if it does not fit in memory. This is faster than `SEQUENTIAL` mode as there is less over head in fetching sketches
for all time chunks together. This is good for small sketches which won't be downsampled even if merged together or if
for all time chunks together. This is good for small sketches which won't be down sampled even if merged together or if
accuracy in segment sizing for the ingestion is not very important.
`SEQUENTIAL` mode fetches the sketches in ascending order of time and generates the partition boundaries for one time
chunk at a time. This gives more working memory to the controller for merging sketches, which results in less
downsampling and thus, more accuracy. There is, however, a time overhead on fetching sketches in sequential order. This is
down sampling and thus, more accuracy. There is, however, a time overhead on fetching sketches in sequential order. This is
good for cases where accuracy is important.
`AUTO` mode tries to find the best approach based on number of workers and size of input rows. If there are more
than 100 workers or if the combined sketch size among all workers is more than 1GB, `SEQUENTIAL` is chosen, otherwise,
`PARALLEL` is chosen.
`AUTO` mode tries to find the best approach based on number of workers. If there are more
than 100 workers, `SEQUENTIAL` is chosen, otherwise, `PARALLEL` is chosen.
## Durable Storage
@ -376,7 +377,8 @@ The following table lists query limits:
| Number of cluster by columns that can appear in a stage | 1,500 | [`TooManyClusteredByColumns`](#error_TooManyClusteredByColumns) |
| Number of workers for any one stage. | Hard limit is 1,000. Memory-dependent soft limit may be lower. | [`TooManyWorkers`](#error_TooManyWorkers) |
| Maximum memory occupied by broadcasted tables. | 30% of each [processor memory bundle](concepts.md#memory-usage). | [`BroadcastTablesTooLarge`](#error_BroadcastTablesTooLarge) |
| Maximum relaunch attempts per worker. Initial run is not a relaunch. The worker will be spawned 1 + `workerRelaunchLimit` times before the job fails. | 2 | `TooManyAttemptsForWorker` |
| Maximum relaunch attempts for a job across all workers. | 100 | `TooManyAttemptsForJob` |
<a name="errors"></a>
## Error codes
@ -401,6 +403,8 @@ The following table describes error codes you may encounter in the `multiStageQu
| <a name="error_QueryNotSupported">`QueryNotSupported`</a> | QueryKit could not translate the provided native query to a multi-stage query.<br /> <br />This can happen if the query uses features that aren't supported, like GROUPING SETS. | |
| <a name="error_RowTooLarge">`RowTooLarge`</a> | The query tried to process a row that was too large to write to a single frame. See the [Limits](#limits) table for specific limits on frame size. Note that the effective maximum row size is smaller than the maximum frame size due to alignment considerations during frame writing. | `maxFrameSize`: The limit on the frame size. |
| <a name="error_TaskStartTimeout">`TaskStartTimeout`</a> | Unable to launch all the worker tasks in time. <br /> <br />There might be insufficient available slots to start all the worker tasks simultaneously.<br /> <br /> Try splitting up the query into smaller chunks with lesser `maxNumTasks` number. Another option is to increase capacity. | `numTasks`: The number of tasks attempted to launch. |
| <a name="error_TooManyAttemptsForJob">`TooManyAttemptsForJob`</a> | Total relaunch attempt count across all workers exceeded max relaunch attempt limit. See the [Limits](#limits) table for the specific limit. | `maxRelaunchCount`: Max number of relaunches across all the workers defined in the [Limits](#limits) section. <br /><br /> `currentRelaunchCount`: current relaunch counter for the job across all workers. <br /><br /> `taskId`: Latest task id which failed <br /> <br /> `rootErrorMessage`: Error message of the latest failed task.|
| <a name="error_TooManyAttemptsForWorker">`TooManyAttemptsForWorker`</a> | Worker exceeded maximum relaunch attempt count as defined in the [Limits](#limits) section. |`maxPerWorkerRelaunchCount`: Max number of relaunches allowed per worker as defined in the [Limits](#limits) section. <br /><br /> `workerNumber`: the worker number for which the task failed <br /><br /> `taskId`: Latest task id which failed <br /> <br /> `rootErrorMessage`: Error message of the latest failed task.|
| <a name="error_TooManyBuckets">`TooManyBuckets`</a> | Exceeded the maximum number of partition buckets for a stage (5,000 partition buckets).<br />< br />Partition buckets are created for each [`PARTITIONED BY`](#partitioned-by) time chunk for INSERT and REPLACE queries. The most common reason for this error is that your `PARTITIONED BY` is too narrow relative to your data. | `maxBuckets`: The limit on partition buckets. |
| <a name="error_TooManyInputFiles">`TooManyInputFiles`</a> | Exceeded the maximum number of input files or segments per worker (10,000 files or segments).<br /><br />If you encounter this limit, consider adding more workers, or breaking up your query into smaller queries that process fewer files or segments per query. | `numInputFiles`: The total number of input files/segments for the stage.<br /><br />`maxInputFiles`: The maximum number of input files/segments per worker per stage.<br /><br />`minNumWorker`: The minimum number of workers required for a successful run. |
| <a name="error_TooManyPartitions">`TooManyPartitions`</a> | Exceeded the maximum number of partitions for a stage (25,000 partitions).<br /><br />This can occur with INSERT or REPLACE statements that generate large numbers of segments, since each segment is associated with a partition. If you encounter this limit, consider breaking up your INSERT or REPLACE statement into smaller statements that process less data per statement. | `maxPartitions`: The limit on partitions which was exceeded |

View File

@ -104,7 +104,7 @@ public interface Controller
/**
* Periodic update of {@link CounterSnapshots} from subtasks.
*/
void updateCounters(CounterSnapshotsTree snapshotsTree);
void updateCounters(String taskId, CounterSnapshotsTree snapshotsTree);
/**
* Reports that results are ready for a subtask.

View File

@ -47,7 +47,7 @@ public interface ControllerClient extends AutoCloseable
* Client-side method to update the controller with counters for a particular stage and worker. The controller uses
* this to compile live reports, track warnings generated etc.
*/
void postCounters(CounterSnapshotsTree snapshotsTree) throws IOException;
void postCounters(String workerId, CounterSnapshotsTree snapshotsTree) throws IOException;
/**
* Client side method to update the controller with the result object for a particular stage and worker. This also

View File

@ -27,12 +27,16 @@ import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;
import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntArraySet;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.common.guava.FutureUtils;
import org.apache.druid.data.input.StringTuple;
@ -64,7 +68,6 @@ import org.apache.druid.indexing.common.actions.TaskActionClient;
import org.apache.druid.indexing.overlord.SegmentPublishResult;
import org.apache.druid.indexing.overlord.Segments;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.Either;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Intervals;
@ -105,12 +108,13 @@ import org.apache.druid.msq.indexing.error.InsertTimeOutOfBoundsFault;
import org.apache.druid.msq.indexing.error.MSQErrorReport;
import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.indexing.error.MSQFault;
import org.apache.druid.msq.indexing.error.MSQFaultUtils;
import org.apache.druid.msq.indexing.error.MSQWarningReportLimiterPublisher;
import org.apache.druid.msq.indexing.error.MSQWarnings;
import org.apache.druid.msq.indexing.error.QueryNotSupportedFault;
import org.apache.druid.msq.indexing.error.TooManyPartitionsFault;
import org.apache.druid.msq.indexing.error.TooManyWarningsFault;
import org.apache.druid.msq.indexing.error.UnknownFault;
import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault;
import org.apache.druid.msq.indexing.report.MSQResultsReport;
import org.apache.druid.msq.indexing.report.MSQStagesReport;
import org.apache.druid.msq.indexing.report.MSQStatusReport;
@ -151,7 +155,6 @@ import org.apache.druid.msq.querykit.scan.ScanQueryKit;
import org.apache.druid.msq.shuffle.DurableStorageInputChannelFactory;
import org.apache.druid.msq.shuffle.DurableStorageUtils;
import org.apache.druid.msq.shuffle.WorkerInputChannelFactory;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
import org.apache.druid.msq.util.DimensionSchemaUtils;
import org.apache.druid.msq.util.IntervalUtils;
@ -159,6 +162,7 @@ import org.apache.druid.msq.util.MSQFutureUtils;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.msq.util.PassthroughAggregatorFactory;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.GroupByQueryConfig;
@ -204,7 +208,6 @@ import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ThreadLocalRandom;
@ -226,7 +229,7 @@ public class ControllerImpl implements Controller
* Queue of "commands" to run on the {@link ControllerQueryKernel}. Various threads insert into the queue
* using {@link #addToKernelManipulationQueue}. The main thread running {@link RunQueryUntilDone#run()} reads
* from the queue and executes the commands.
*
* <p>
* This ensures that all manipulations on {@link ControllerQueryKernel}, and all core logic, are run in
* a single-threaded manner.
*/
@ -263,10 +266,13 @@ public class ControllerImpl implements Controller
// For live reports. Written by the main controller thread, read by HTTP threads.
private final ConcurrentHashMap<Integer, Integer> stagePartitionCountsForLiveReports = new ConcurrentHashMap<>();
private WorkerSketchFetcher workerSketchFetcher;
// Time at which the query started.
// For live reports. Written by the main controller thread, read by HTTP threads.
// WorkerNumber -> WorkOrders which need to be retried and our determined by the controller.
// Map is always populated in the main controller thread by addToRetryQueue, and pruned in retryFailedTasks.
private final Map<Integer, Set<WorkOrder>> workOrdersToRetry = new HashMap<>();
private volatile DateTime queryStartTime = null;
private volatile DruidNode selfDruidNode;
@ -275,6 +281,11 @@ public class ControllerImpl implements Controller
private volatile FaultsExceededChecker faultsExceededChecker = null;
private Map<Integer, ClusterStatisticsMergeMode> stageToStatsMergingMode;
private WorkerMemoryParameters workerMemoryParameters;
private boolean isDurableStorageEnabled;
private boolean isFaultToleranceEnabled;
public ControllerImpl(
final MSQControllerTask task,
final ControllerContext context
@ -500,17 +511,17 @@ public class ControllerImpl implements Controller
return TaskStatus.success(id());
} else {
// errorForReport is nonnull when taskStateForReport != SUCCESS. Use that message.
return TaskStatus.failure(id(), errorForReport.getFault().getCodeWithMessage());
return TaskStatus.failure(id(), MSQFaultUtils.generateMessageWithErrorCode(errorForReport.getFault()));
}
}
/**
* Adds some logic to {@link #kernelManipulationQueue}, where it will, in due time, be executed by the main
* controller loop in {@link RunQueryUntilDone#run()}.
*
* <p>
* If the consumer throws an exception, the query fails.
*/
private void addToKernelManipulationQueue(Consumer<ControllerQueryKernel> kernelConsumer)
public void addToKernelManipulationQueue(Consumer<ControllerQueryKernel> kernelConsumer)
{
if (!kernelManipulationQueue.offer(kernelConsumer)) {
final String message = "Controller kernel queue is full. Main controller loop may be delayed or stuck.";
@ -521,25 +532,42 @@ public class ControllerImpl implements Controller
private QueryDefinition initializeQueryDefAndState(final Closer closer)
{
final QueryContext queryContext = task.getQuerySpec().getQuery().context();
isFaultToleranceEnabled = MultiStageQueryContext.isFaultToleranceEnabled(queryContext);
if (isFaultToleranceEnabled) {
if (!queryContext.containsKey(MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE)) {
// if context key not set, enable durableStorage automatically.
isDurableStorageEnabled = true;
} else {
// if context key is set, and durableStorage is turned on.
if (MultiStageQueryContext.isDurableStorageEnabled(queryContext)) {
isDurableStorageEnabled = true;
} else {
throw new MSQException(
UnknownFault.forMessage(
StringUtils.format(
"Context param[%s] cannot be explicitly set to false when context param[%s] is"
+ " set to true. Either remove the context param[%s] or explicitly set it to true.",
MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE,
MultiStageQueryContext.CTX_FAULT_TOLERANCE,
MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE
)));
}
}
} else {
isDurableStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled(queryContext);
}
log.debug("Task [%s] durable storage mode is set to %s.", task.getId(), isDurableStorageEnabled);
log.debug("Task [%s] fault tolerance mode is set to %s.", task.getId(), isFaultToleranceEnabled);
this.selfDruidNode = context.selfNode();
context.registerController(this, closer);
this.netClient = new ExceptionWrappingWorkerClient(context.taskClientFor(this));
closer.register(netClient::close);
ClusterStatisticsMergeMode clusterStatisticsMergeMode =
MultiStageQueryContext.getClusterStatisticsMergeMode(task.getQuerySpec().getQuery().context());
log.debug("Query [%s] cluster statistics merge mode is set to %s.", id(), clusterStatisticsMergeMode);
int statisticsMaxRetainedBytes = WorkerMemoryParameters.createProductionInstanceForController(context.injector())
.getPartitionStatisticsMaxRetainedBytes();
this.workerSketchFetcher = new WorkerSketchFetcher(netClient, clusterStatisticsMergeMode, statisticsMaxRetainedBytes);
closer.register(workerSketchFetcher::close);
final boolean isDurableStorageEnabled =
MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().context());
final QueryDefinition queryDef = makeQueryDefinition(
id(),
makeQueryControllerToolKit(),
@ -550,9 +578,6 @@ public class ControllerImpl implements Controller
QueryValidator.validateQueryDef(queryDef);
queryDefRef.set(queryDef);
log.debug("Query [%s] durable storage mode is set to %s.", queryDef.getQueryId(), isDurableStorageEnabled);
long maxParseExceptions = -1;
if (task.getSqlQueryContext() != null) {
@ -562,11 +587,19 @@ public class ControllerImpl implements Controller
.orElse(MSQWarnings.DEFAULT_MAX_PARSE_EXCEPTIONS_ALLOWED);
}
this.workerTaskLauncher = new MSQWorkerTaskLauncher(
id(),
task.getDataSource(),
context,
(failedTask, fault) -> {
addToKernelManipulationQueue((kernel) -> {
if (isFaultToleranceEnabled) {
addToRetryQueue(kernel, failedTask.getWorkerNumber(), fault);
} else {
throw new MSQException(fault);
}
});
},
isDurableStorageEnabled,
maxParseExceptions,
// 10 minutes +- 2 minutes jitter
@ -577,16 +610,67 @@ public class ControllerImpl implements Controller
ImmutableMap.of(CannotParseExternalDataFault.CODE, maxParseExceptions)
);
stageToStatsMergingMode = new HashMap<>();
queryDef.getStageDefinitions().forEach(
stageDefinition ->
stageToStatsMergingMode.put(
stageDefinition.getId().getStageNumber(),
finalizeClusterStatisticsMergeMode(
stageDefinition,
MultiStageQueryContext.getClusterStatisticsMergeMode(queryContext)
)
)
);
this.workerMemoryParameters = WorkerMemoryParameters.createProductionInstanceForController(context.injector());
this.workerSketchFetcher = new WorkerSketchFetcher(
netClient,
workerTaskLauncher,
isFaultToleranceEnabled
);
closer.register(workerSketchFetcher::close);
return queryDef;
}
/**
* Adds the work orders for worker to {@link ControllerImpl#workOrdersToRetry} if the {@link ControllerQueryKernel} determines that there
* are work orders which needs reprocessing.
* <br></br>
* This method is not thread safe, so it should always be called inside the main controller thread.
*/
private void addToRetryQueue(ControllerQueryKernel kernel, int worker, MSQFault fault)
{
List<WorkOrder> retriableWorkOrders = kernel.getWorkInCaseWorkerEligibleForRetryElseThrow(worker, fault);
if (retriableWorkOrders.size() != 0) {
log.info("Submitting worker[%s] for relaunch because of fault[%s]", worker, fault);
workerTaskLauncher.submitForRelaunch(worker);
workOrdersToRetry.compute(worker, (workerNumber, workOrders) -> {
if (workOrders == null) {
return new HashSet<>(retriableWorkOrders);
} else {
workOrders.addAll(retriableWorkOrders);
return workOrders;
}
});
} else {
log.info(
"Worker[%d] has no active workOrders that need relaunch therefore not relaunching",
worker
);
}
}
/**
* Accepts a {@link PartialKeyStatisticsInformation} and updates the controller key statistics information. If all key
* statistics information has been gathered, enqueues the task with the {@link WorkerSketchFetcher} to generate
* partiton boundaries. This is intended to be called by the {@link org.apache.druid.msq.indexing.ControllerChatHandler}.
*/
@Override
public void updatePartialKeyStatisticsInformation(int stageNumber, int workerNumber, Object partialKeyStatisticsInformationObject)
public void updatePartialKeyStatisticsInformation(
int stageNumber,
int workerNumber,
Object partialKeyStatisticsInformationObject
)
{
addToKernelManipulationQueue(
queryKernel -> {
@ -602,7 +686,10 @@ public class ControllerImpl implements Controller
final PartialKeyStatisticsInformation partialKeyStatisticsInformation;
try {
partialKeyStatisticsInformation = mapper.convertValue(partialKeyStatisticsInformationObject, PartialKeyStatisticsInformation.class);
partialKeyStatisticsInformation = mapper.convertValue(
partialKeyStatisticsInformationObject,
PartialKeyStatisticsInformation.class
);
}
catch (IllegalArgumentException e) {
throw new IAE(
@ -614,50 +701,18 @@ public class ControllerImpl implements Controller
}
queryKernel.addPartialKeyStatisticsForStageAndWorker(stageId, workerNumber, partialKeyStatisticsInformation);
if (queryKernel.getStagePhase(stageId).equals(ControllerStagePhase.MERGING_STATISTICS)) {
List<String> workerTaskIds = workerTaskLauncher.getTaskList();
CompleteKeyStatisticsInformation completeKeyStatisticsInformation =
queryKernel.getCompleteKeyStatisticsInformation(stageId);
// Queue the sketch fetching task into the worker sketch fetcher.
CompletableFuture<Either<Long, ClusterByPartitions>> clusterByPartitionsCompletableFuture =
workerSketchFetcher.submitFetcherTask(
completeKeyStatisticsInformation,
workerTaskIds,
stageDef,
queryKernel.getWorkerInputsForStage(stageId).workers()
// we only need tasks which are active for this stage.
);
// Add the listener to handle completion.
clusterByPartitionsCompletableFuture.whenComplete((clusterByPartitionsEither, throwable) -> {
addToKernelManipulationQueue(holder -> {
if (throwable != null) {
log.error("Error while fetching stats for stageId[%s]", stageId);
if (throwable instanceof MSQException) {
holder.failStageForReason(stageId, ((MSQException) throwable).getFault());
} else {
holder.failStageForReason(stageId, UnknownFault.forException(throwable));
}
} else if (clusterByPartitionsEither.isError()) {
holder.failStageForReason(stageId, new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
} else {
log.debug("Query [%s] Partition boundaries generated for stage %s", id(), stageId);
holder.setClusterByPartitionBoundaries(stageId, clusterByPartitionsEither.valueOrThrow());
}
holder.transitionStageKernel(stageId, queryKernel.getStagePhase(stageId));
});
});
}
}
);
}
@Override
public void workerError(MSQErrorReport errorReport)
{
if (!workerTaskLauncher.isTaskCanceledByController(errorReport.getTaskId())) {
if (workerTaskLauncher.isTaskCanceledByController(errorReport.getTaskId()) ||
!workerTaskLauncher.isTaskLatest(errorReport.getTaskId())) {
log.info("Ignoring task %s", errorReport.getTaskId());
} else {
workerErrorRef.compareAndSet(null, errorReport);
}
}
@ -693,7 +748,7 @@ public class ControllerImpl implements Controller
* Periodic update of {@link CounterSnapshots} from subtasks.
*/
@Override
public void updateCounters(CounterSnapshotsTree snapshotsTree)
public void updateCounters(String taskId, CounterSnapshotsTree snapshotsTree)
{
taskCountersForLiveReports.putAll(snapshotsTree);
Optional<Pair<String, Long>> warningsExceeded =
@ -705,7 +760,7 @@ public class ControllerImpl implements Controller
Long limit = warningsExceeded.get().rhs;
workerError(MSQErrorReport.fromFault(
id(),
taskId,
selfDruidNode.getHost(),
null,
new TooManyWarningsFault(limit.intValue(), errorCode)
@ -960,7 +1015,7 @@ public class ControllerImpl implements Controller
/**
* Returns a complete list of task ids, ordered by worker number. The Nth task has worker number N.
*
* <p>
* If the currently-running set of tasks is incomplete, returns an absent Optional.
*/
@Override
@ -970,7 +1025,7 @@ public class ControllerImpl implements Controller
return Collections.emptyList();
}
return workerTaskLauncher.getTaskList();
return workerTaskLauncher.getActiveTasks();
}
@SuppressWarnings({"unchecked", "rawtypes"})
@ -1040,17 +1095,82 @@ public class ControllerImpl implements Controller
return retVal;
}
private void contactWorkersForStage(final TaskContactFn contactFn, final IntSet workers)
/**
* A blocking function used to contact multiple workers. Checks if all the workers are running before contacting them.
*
* @param queryKernel
* @param contactFn
* @param workers set of workers to contact
* @param successCallBack After contacting all the tasks, a custom callback is invoked in the main thread for each successfully contacted task.
* @param retryOnFailure If true, after contacting all the tasks, adds this worker to retry queue in the main thread.
* If false, cancel all the futures and propagate the exception to the caller.
*/
private void contactWorkersForStage(
final ControllerQueryKernel queryKernel,
final TaskContactFn contactFn,
final IntSet workers,
final TaskContactSuccess successCallBack,
final boolean retryOnFailure
)
{
final List<String> taskIds = getTaskIds();
final List<ListenableFuture<Void>> taskFutures = new ArrayList<>(workers.size());
final List<ListenableFuture<Boolean>> taskFutures = new ArrayList<>(workers.size());
try {
workerTaskLauncher.waitUntilWorkersReady(workers);
}
catch (InterruptedException e) {
throw new RuntimeException(e);
}
Set<String> failedCalls = ConcurrentHashMap.newKeySet();
Set<String> successfulCalls = ConcurrentHashMap.newKeySet();
for (int workerNumber : workers) {
final String taskId = taskIds.get(workerNumber);
taskFutures.add(contactFn.contactTask(netClient, taskId, workerNumber));
SettableFuture<Boolean> settableFuture = SettableFuture.create();
ListenableFuture<Void> apiFuture = contactFn.contactTask(netClient, taskId, workerNumber);
Futures.addCallback(apiFuture, new FutureCallback<Void>()
{
@Override
public void onSuccess(@Nullable Void result)
{
successfulCalls.add(taskId);
settableFuture.set(true);
}
@Override
public void onFailure(Throwable t)
{
if (retryOnFailure) {
log.info(
t,
"Detected failure while contacting task[%s]. Initiating relaunch of worker[%d] if applicable",
taskId,
MSQTasks.workerFromTaskId(taskId)
);
failedCalls.add(taskId);
settableFuture.set(false);
} else {
settableFuture.setException(t);
}
}
});
taskFutures.add(settableFuture);
}
FutureUtils.getUnchecked(MSQFutureUtils.allAsList(taskFutures, true), true);
for (String taskId : successfulCalls) {
successCallBack.onSuccess(taskId);
}
if (retryOnFailure) {
for (String taskId : failedCalls) {
addToRetryQueue(queryKernel, MSQTasks.workerFromTaskId(taskId), new WorkerRpcFailedFault(taskId));
}
}
}
private void startWorkForStage(
@ -1068,35 +1188,45 @@ public class ControllerImpl implements Controller
);
final Int2ObjectMap<WorkOrder> workOrders = queryKernel.createWorkOrders(stageNumber, extraInfos);
final StageId stageId = new StageId(queryDef.getQueryId(), stageNumber);
queryKernel.startStage(stageId);
contactWorkersForStage(
(netClient, taskId, workerNumber) -> netClient.postWorkOrder(taskId, workOrders.get(workerNumber)),
workOrders.keySet()
queryKernel,
(netClient, taskId, workerNumber) -> (
netClient.postWorkOrder(taskId, workOrders.get(workerNumber))), workOrders.keySet(),
(taskId) -> queryKernel.workOrdersSentForWorker(stageId, MSQTasks.workerFromTaskId(taskId)),
isFaultToleranceEnabled
);
}
private void postResultPartitionBoundariesForStage(
final ControllerQueryKernel queryKernel,
final QueryDefinition queryDef,
final int stageNumber,
final ClusterByPartitions resultPartitionBoundaries,
final IntSet workers
)
{
final StageId stageId = new StageId(queryDef.getQueryId(), stageNumber);
contactWorkersForStage(
(netClient, taskId, workerNumber) ->
netClient.postResultPartitionBoundaries(
taskId,
new StageId(queryDef.getQueryId(), stageNumber),
resultPartitionBoundaries
),
workers
queryKernel,
(netClient, taskId, workerNumber) -> netClient.postResultPartitionBoundaries(
taskId,
stageId,
resultPartitionBoundaries
),
workers,
(taskId) -> queryKernel.partitionBoundariesSentForWorker(stageId, MSQTasks.workerFromTaskId(taskId)),
isFaultToleranceEnabled
);
}
/**
* Publish the list of segments. Additionally, if {@link DataSourceMSQDestination#isReplaceTimeChunks()},
* also drop all other segments within the replacement intervals.
*
* <p>
* If any existing segments cannot be dropped because their intervals are not wholly contained within the
* replacement parameter, throws a {@link MSQException} with {@link InsertCannotReplaceExistingSegmentFault}.
*/
@ -1178,7 +1308,7 @@ public class ControllerImpl implements Controller
private CounterSnapshotsTree getCountersFromAllTasks()
{
final CounterSnapshotsTree retVal = new CounterSnapshotsTree();
final List<String> taskList = workerTaskLauncher.getTaskList();
final List<String> taskList = getTaskIds();
final List<ListenableFuture<CounterSnapshotsTree>> futures = new ArrayList<>();
@ -1198,7 +1328,7 @@ public class ControllerImpl implements Controller
private void postFinishToAllTasks()
{
final List<String> taskList = workerTaskLauncher.getTaskList();
final List<String> taskList = getTaskIds();
final List<ListenableFuture<Void>> futures = new ArrayList<>();
@ -1241,7 +1371,7 @@ public class ControllerImpl implements Controller
final InputChannelFactory inputChannelFactory;
if (MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().context())) {
if (isDurableStorageEnabled) {
inputChannelFactory = DurableStorageInputChannelFactory.createStandardImplementation(
id(),
MSQTasks.makeStorageConnector(context.injector()),
@ -1336,14 +1466,14 @@ public class ControllerImpl implements Controller
/**
* Clean up durable storage, if used for stage output.
*
* <p>
* Note that this is only called by the controller task itself. It isn't called automatically by anything in
* particular if the controller fails early without being able to run its cleanup routines. This can cause files
* to be left in durable storage beyond their useful life.
*/
private void cleanUpDurableStorageIfNeeded()
{
if (MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().context())) {
if (isDurableStorageEnabled) {
final String controllerDirName = DurableStorageUtils.getControllerDirectory(task.getId());
try {
// Delete all temporary files as a failsafe
@ -1531,7 +1661,7 @@ public class ControllerImpl implements Controller
/**
* Checks that a {@link GroupByQuery} is grouping on the primary time column.
*
* <p>
* The logic here is roundabout. First, we check which column in the {@link GroupByQuery} corresponds to the
* output column {@link ColumnHolder#TIME_COLUMN_NAME}, using our {@link ColumnMappings}. Then, we check for the
* presence of an optimization done in {@link DruidQuery#toGroupByQuery()}, where the context parameter
@ -1551,9 +1681,9 @@ public class ControllerImpl implements Controller
/**
* Whether a native query represents an ingestion with rollup.
*
* <p>
* Checks for three things:
*
* <p>
* - The query must be a {@link GroupByQuery}, because rollup requires columns to be split into dimensions and
* aggregations.
* - The query must not finalize aggregations, because rollup requires inserting the intermediate type of
@ -1907,7 +2037,7 @@ public class ControllerImpl implements Controller
/**
* Performs a particular {@link SegmentTransactionalInsertAction}, publishing segments.
*
* <p>
* Throws {@link MSQException} with {@link InsertLockPreemptedFault} if the action fails due to lock preemption.
*/
static void performSegmentPublish(
@ -1935,7 +2065,7 @@ public class ControllerImpl implements Controller
* Method that determines whether an exception was raised due to the task lock for the controller task being
* preempted. Uses string comparison, because the relevant Overlord APIs do not have a more reliable way of
* discerning the cause of errors.
*
* <p>
* Error strings are taken from {@link org.apache.druid.indexing.common.actions.TaskLocks}
* and {@link SegmentAllocateAction}.
*/
@ -1987,11 +2117,6 @@ public class ControllerImpl implements Controller
private final Closer closer;
private final ControllerQueryKernel queryKernel;
/**
* Set of stages that have got their partition boundaries sent out.
*/
private final Set<StageId> stageResultPartitionBoundariesSent = new HashSet<>();
/**
* Return value of {@link MSQWorkerTaskLauncher#start()}. Set by {@link #startTaskLauncher()}.
*/
@ -2013,7 +2138,11 @@ public class ControllerImpl implements Controller
this.queryDef = queryDef;
this.inputSpecSlicerFactory = inputSpecSlicerFactory;
this.closer = closer;
this.queryKernel = new ControllerQueryKernel(queryDef);
this.queryKernel = new ControllerQueryKernel(
queryDef,
workerMemoryParameters.getPartitionStatisticsMaxRetainedBytes(),
isFaultToleranceEnabled
);
}
/**
@ -2025,9 +2154,12 @@ public class ControllerImpl implements Controller
while (!queryKernel.isDone()) {
startStages();
fetchStatsFromWorkers();
sendPartitionBoundaries();
updateLiveReportMaps();
cleanUpEffectivelyFinishedStages();
retryFailedTasks();
checkForErrorsInSketchFetcher();
runKernelCommands();
}
@ -2039,6 +2171,77 @@ public class ControllerImpl implements Controller
return Pair.of(queryKernel, workerTaskLauncherFuture);
}
private void checkForErrorsInSketchFetcher()
{
Throwable throwable = workerSketchFetcher.getError();
if (throwable != null) {
throw new ISE(throwable, "worker sketch fetch failed");
}
}
private void retryFailedTasks() throws InterruptedException
{
// if no work orders to rety skip
if (workOrdersToRetry.size() == 0) {
return;
}
Set<Integer> workersNeedToBeFullyStarted = new HashSet<>();
// transform work orders from map<Worker,Set<WorkOrders> to Map<StageId,Map<Worker,WorkOrder>>
// since we would want workOrders of processed per stage
Map<StageId, Map<Integer, WorkOrder>> stageWorkerOrders = new HashMap<>();
for (Map.Entry<Integer, Set<WorkOrder>> workerStages : workOrdersToRetry.entrySet()) {
workersNeedToBeFullyStarted.add(workerStages.getKey());
for (WorkOrder workOrder : workerStages.getValue()) {
stageWorkerOrders.compute(
new StageId(queryDef.getQueryId(), workOrder.getStageNumber()),
(stageId, workOrders) -> {
if (workOrders == null) {
workOrders = new HashMap<Integer, WorkOrder>();
}
workOrders.put(workerStages.getKey(), workOrder);
return workOrders;
}
);
}
}
// wait till the workers identified above are fully ready
workerTaskLauncher.waitUntilWorkersReady(workersNeedToBeFullyStarted);
for (Map.Entry<StageId, Map<Integer, WorkOrder>> stageWorkOrders : stageWorkerOrders.entrySet()) {
contactWorkersForStage(
queryKernel,
(netClient, taskId, workerNumber) -> netClient.postWorkOrder(
taskId,
stageWorkOrders.getValue().get(workerNumber)
),
new IntArraySet(stageWorkOrders.getValue().keySet()),
(taskId) -> {
int workerNumber = MSQTasks.workerFromTaskId(taskId);
queryKernel.workOrdersSentForWorker(stageWorkOrders.getKey(), workerNumber);
// remove successfully contacted workOrders from workOrdersToRetry
workOrdersToRetry.compute(workerNumber, (task, workOrderSet) -> {
if (workOrderSet == null || workOrderSet.size() == 0 || !workOrderSet.remove(stageWorkOrders.getValue()
.get(
workerNumber))) {
throw new ISE("Worker[%d] orders not found", workerNumber);
}
if (workOrderSet.size() == 0) {
return null;
}
return workOrderSet;
});
},
isFaultToleranceEnabled
);
}
}
/**
* Run at least one command from {@link #kernelManipulationQueue}, waiting for it if necessary.
*/
@ -2079,6 +2282,66 @@ public class ControllerImpl implements Controller
);
}
/**
* Enqueues the fetching {@link org.apache.druid.msq.statistics.ClusterByStatisticsCollector}
* from each worker via {@link WorkerSketchFetcher}
*/
private void fetchStatsFromWorkers()
{
for (Map.Entry<StageId, Set<Integer>> stageToWorker : queryKernel.getStagesAndWorkersToFetchClusterStats()
.entrySet()) {
List<String> allTasks = workerTaskLauncher.getActiveTasks();
Set<String> tasks = stageToWorker.getValue().stream().map(allTasks::get).collect(Collectors.toSet());
ClusterStatisticsMergeMode clusterStatisticsMergeMode = stageToStatsMergingMode.get(stageToWorker.getKey()
.getStageNumber());
switch (clusterStatisticsMergeMode) {
case SEQUENTIAL:
submitSequentialMergeFetchRequests(stageToWorker.getKey(), tasks);
break;
case PARALLEL:
submitParallelMergeRequests(stageToWorker.getKey(), tasks);
break;
default:
throw new IllegalStateException("No fetching strategy found for mode: " + clusterStatisticsMergeMode);
}
}
}
private void submitParallelMergeRequests(StageId stageId, Set<String> tasks)
{
// eagerly change state of workers whose state is being fetched so that we do not keep on queuing fetch requests.
queryKernel.startFetchingStatsFromWorker(
stageId,
tasks.stream().map(MSQTasks::workerFromTaskId).collect(Collectors.toSet())
);
workerSketchFetcher.inMemoryFullSketchMerging(ControllerImpl.this::addToKernelManipulationQueue,
stageId, tasks,
ControllerImpl.this::addToRetryQueue
);
}
private void submitSequentialMergeFetchRequests(StageId stageId, Set<String> tasks)
{
if (queryKernel.allPartialKeyInformationPresent(stageId)) {
// eagerly change state of workers whose state is being fetched so that we do not keep on queuing fetch requests.
queryKernel.startFetchingStatsFromWorker(
stageId,
tasks.stream()
.map(MSQTasks::workerFromTaskId)
.collect(Collectors.toSet())
);
workerSketchFetcher.sequentialTimeChunkMerging(
ControllerImpl.this::addToKernelManipulationQueue,
queryKernel.getCompleteKeyStatisticsInformation(stageId),
stageId, tasks,
ControllerImpl.this::addToRetryQueue
);
}
}
/**
* Start up any stages that are ready to start.
*/
@ -2091,7 +2354,6 @@ public class ControllerImpl implements Controller
);
for (final StageId stageId : newStageIds) {
queryKernel.startStage(stageId);
// Allocate segments, if this is the final stage of an ingestion.
if (MSQControllerTask.isIngestion(task.getQuerySpec())
@ -2160,31 +2422,39 @@ public class ControllerImpl implements Controller
for (final StageId stageId : queryKernel.getActiveStages()) {
if (queryKernel.getStageDefinition(stageId).mustGatherResultKeyStatistics()
&& queryKernel.doesStageHaveResultPartitions(stageId)
&& stageResultPartitionBoundariesSent.add(stageId)) {
&& queryKernel.doesStageHaveResultPartitions(stageId)) {
IntSet workersToSendPartitionBoundaries = queryKernel.getWorkersToSendPartitionBoundaries(stageId);
if (workersToSendPartitionBoundaries.isEmpty()) {
log.debug("No workers for stage[%s] ready to receive partition boundaries", stageId);
continue;
}
final ClusterByPartitions partitions = queryKernel.getResultPartitionBoundariesForStage(stageId);
if (log.isDebugEnabled()) {
final ClusterByPartitions partitions = queryKernel.getResultPartitionBoundariesForStage(stageId);
log.debug(
"Query [%s] sending out partition boundaries for stage %d: %s",
"Query [%s] sending out partition boundaries for stage %d: %s for workers %s",
stageId.getQueryId(),
stageId.getStageNumber(),
IntStream.range(0, partitions.size())
.mapToObj(i -> StringUtils.format("%s:%s", i, partitions.get(i)))
.collect(Collectors.joining(", "))
.collect(Collectors.joining(", ")),
workersToSendPartitionBoundaries.toString()
);
} else {
log.info(
"Query [%s] sending out partition boundaries for stage %d.",
"Query [%s] sending out partition boundaries for stage %d for workers %s",
stageId.getQueryId(),
stageId.getStageNumber()
stageId.getStageNumber(),
workersToSendPartitionBoundaries.toString()
);
}
postResultPartitionBoundariesForStage(
queryKernel,
queryDef,
stageId.getStageNumber(),
queryKernel.getResultPartitionBoundariesForStage(stageId),
queryKernel.getWorkerInputsForStage(stageId).workers()
partitions,
workersToSendPartitionBoundaries
);
}
}
@ -2240,8 +2510,11 @@ public class ControllerImpl implements Controller
for (final StageId stageId : queryKernel.getEffectivelyFinishedStageIds()) {
log.info("Query [%s] issuing cleanup order for stage %d.", queryDef.getQueryId(), stageId.getStageNumber());
contactWorkersForStage(
queryKernel,
(netClient, taskId, workerNumber) -> netClient.postCleanupStage(taskId, stageId),
queryKernel.getWorkerInputsForStage(stageId).workers()
queryKernel.getWorkerInputsForStage(stageId).workers(),
(ignore1) -> {},
false
);
queryKernel.finishStage(stageId, true);
}
@ -2267,6 +2540,31 @@ public class ControllerImpl implements Controller
}
}
static ClusterStatisticsMergeMode finalizeClusterStatisticsMergeMode(
StageDefinition stageDef,
ClusterStatisticsMergeMode initialMode
)
{
ClusterStatisticsMergeMode mergeMode = initialMode;
if (initialMode == ClusterStatisticsMergeMode.AUTO) {
ClusterBy clusterBy = stageDef.getClusterBy();
if (clusterBy.getBucketByCount() == 0) {
// If there is no time clustering, there is no scope for sequential merge
mergeMode = ClusterStatisticsMergeMode.PARALLEL;
} else if (stageDef.getMaxWorkerCount() > Limits.MAX_WORKERS_FOR_PARALLEL_MERGE) {
mergeMode = ClusterStatisticsMergeMode.SEQUENTIAL;
} else {
mergeMode = ClusterStatisticsMergeMode.PARALLEL;
}
log.info(
"Stage [%d] AUTO mode: chose %s mode to merge key statistics",
stageDef.getStageNumber(),
mergeMode
);
}
return mergeMode;
}
/**
* Interface used by {@link #contactWorkersForStage}.
*/
@ -2274,4 +2572,13 @@ public class ControllerImpl implements Controller
{
ListenableFuture<Void> contactTask(WorkerClient client, String taskId, int workerNumber);
}
/**
* Interface used when {@link TaskContactFn#contactTask(WorkerClient, String, int)} returns a successful future.
*/
private interface TaskContactSuccess
{
void onSuccess(String taskId);
}
}

View File

@ -23,7 +23,7 @@ public class Limits
{
/**
* Maximum number of columns that can appear in a frame signature.
*
* <p>
* Somewhat less than {@link WorkerMemoryParameters#STANDARD_FRAME_SIZE} divided by typical minimum column size:
* {@link org.apache.druid.frame.allocation.AppendableMemory#DEFAULT_INITIAL_ALLOCATION_SIZE}.
*/
@ -68,4 +68,20 @@ public class Limits
* Maximum size of the kernel manipulation queue in {@link org.apache.druid.msq.indexing.MSQControllerTask}.
*/
public static final int MAX_KERNEL_MANIPULATION_QUEUE_SIZE = 100_000;
/**
* Maximum relaunches across all workers.
*/
public static final int TOTAL_RELAUNCH_LIMIT = 100;
/**
* Maximum relaunches per worker. Initial run is not a relaunch. The worker will be spawned 1 + workerRelaunchLimit times before erroring out.
*/
public static final int PER_WORKER_RELAUNCH_LIMIT = 2;
/**
* Max number of workers for {@link ClusterStatisticsMergeMode#PARALLEL}. If the number of workers is more than this,
* {@link ClusterStatisticsMergeMode#SEQUENTIAL} mode is chosen.
*/
public static final long MAX_WORKERS_FOR_PARALLEL_MERGE = 100;
}

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.inject.Injector;
import com.google.inject.Key;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.msq.guice.MultiStageQuery;
import org.apache.druid.msq.indexing.error.CanceledFault;
@ -31,6 +32,7 @@ import org.apache.druid.msq.indexing.error.InsertTimeNullFault;
import org.apache.druid.msq.indexing.error.MSQErrorReport;
import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.indexing.error.MSQFault;
import org.apache.druid.msq.indexing.error.MSQFaultUtils;
import org.apache.druid.msq.indexing.error.UnknownFault;
import org.apache.druid.msq.indexing.error.WorkerFailedFault;
import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault;
@ -44,6 +46,8 @@ import org.apache.druid.storage.StorageConnector;
import javax.annotation.Nullable;
import java.util.UUID;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class MSQTasks
{
@ -54,6 +58,10 @@ public class MSQTasks
private static final String TASK_ID_PREFIX = "query-";
private static final String WORKER_NUMBER = "workerNumber";
// taskids are in the form 12dsa1-worker9_0. see method workerTaskId() for more details.
private static final Pattern WORKER_PATTERN = Pattern.compile(".*-worker(?<" + WORKER_NUMBER + ">[0-9]+)_[0-9]+");
/**
* Returns a controller task ID given a SQL query id.
*/
@ -65,9 +73,32 @@ public class MSQTasks
/**
* Returns a worker task ID given a SQL query id.
*/
public static String workerTaskId(final String controllerTaskId, final int workerNumber)
public static String workerTaskId(final String controllerTaskId, final int workerNumber, final int retryCount)
{
return StringUtils.format("%s-worker%d", controllerTaskId, workerNumber);
return StringUtils.format("%s-worker%d_%d", controllerTaskId, workerNumber, retryCount);
}
/**
* Extract worker from taskId or throw exception if unable to parse out the worker.
*/
public static int workerFromTaskId(final String taskId)
{
final Matcher matcher = WORKER_PATTERN.matcher(taskId);
if (matcher.matches()) {
try {
String worker = matcher.group(WORKER_NUMBER);
return Integer.parseInt(worker);
}
catch (NumberFormatException e) {
throw new ISE(e, "Unable to parse worker out of task %s", taskId);
}
} else {
throw new ISE(
"Desired pattern %s to extract worker from task id %s did not match ",
WORKER_PATTERN.pattern(),
taskId
);
}
}
/**
@ -194,7 +225,7 @@ public class MSQTasks
logMessage.append("; host ").append(errorReport.getHost());
}
logMessage.append(": ").append(errorReport.getFault().getCodeWithMessage());
logMessage.append(": ").append(MSQFaultUtils.generateMessageWithErrorCode(errorReport.getFault()));
if (errorReport.getExceptionStackTrace() != null) {
if (errorReport.getFault() instanceof UnknownFault) {

View File

@ -71,6 +71,7 @@ import org.apache.druid.msq.indexing.error.CanceledFault;
import org.apache.druid.msq.indexing.error.CannotParseExternalDataFault;
import org.apache.druid.msq.indexing.error.MSQErrorReport;
import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.indexing.error.MSQFaultUtils;
import org.apache.druid.msq.indexing.error.MSQWarningReportLimiterPublisher;
import org.apache.druid.msq.indexing.error.MSQWarningReportPublisher;
import org.apache.druid.msq.indexing.error.MSQWarningReportSimplePublisher;
@ -233,7 +234,7 @@ public class WorkerImpl implements Worker
}
});
return TaskStatus.failure(id(), errorReport.getFault().getCodeWithMessage());
return TaskStatus.failure(id(), MSQFaultUtils.generateMessageWithErrorCode(errorReport.getFault()));
} else {
return TaskStatus.success(id());
}
@ -261,7 +262,7 @@ public class WorkerImpl implements Worker
// Delete all the stage outputs
closer.register(() -> {
for (final StageId stageId : stageOutputs.keySet()) {
cleanStageOutput(stageId);
cleanStageOutput(stageId, false);
}
});
@ -516,11 +517,12 @@ public class WorkerImpl implements Worker
throw new ISE("Worker number mismatch: expected [%d]", task.getWorkerNumber());
}
// Do not add to queue if workerOrder already present.
kernelManipulationQueue.add(
kernelHolder ->
kernelHolder.getStageKernelMap().computeIfAbsent(
kernelHolder.getStageKernelMap().putIfAbsent(
workOrder.getStageDefinition().getId(),
ignored -> WorkerStageKernel.create(workOrder)
WorkerStageKernel.create(workOrder)
)
);
}
@ -538,10 +540,18 @@ public class WorkerImpl implements Worker
kernelHolder -> {
final WorkerStageKernel stageKernel = kernelHolder.getStageKernelMap().get(stageId);
// Ignore the update if we don't have a kernel for this stage.
if (stageKernel != null) {
stageKernel.setResultPartitionBoundaries(stagePartitionBoundaries);
if (!stageKernel.hasResultPartitionBoundaries()) {
stageKernel.setResultPartitionBoundaries(stagePartitionBoundaries);
} else {
// Ignore if partition boundaries are already set.
log.warn(
"Stage[%s] already has result partition boundaries set. Ignoring the latest partition boundaries recieved.",
stageId
);
}
} else {
// Ignore the update if we don't have a kernel for this stage.
log.warn("Ignored result partition boundaries call for unknown stage [%s]", stageId);
}
}
@ -555,7 +565,7 @@ public class WorkerImpl implements Worker
log.info("Cleanup order for stage: [%s] received", stageId);
kernelManipulationQueue.add(
holder -> {
cleanStageOutput(stageId);
cleanStageOutput(stageId, true);
// Mark the stage as FINISHED
holder.getStageKernelMap().get(stageId).setStageFinished();
}
@ -726,7 +736,7 @@ public class WorkerImpl implements Worker
final CounterSnapshotsTree snapshotsTree = getCounters();
if (controllerAlive && !snapshotsTree.isEmpty()) {
controllerClient.postCounters(snapshotsTree);
controllerClient.postCounters(id(), snapshotsTree);
}
}
@ -735,7 +745,7 @@ public class WorkerImpl implements Worker
* the readable channels corresponding to all the partitions for that stage, and removes it from the {@code stageOutputs}
* map
*/
private void cleanStageOutput(final StageId stageId)
private void cleanStageOutput(final StageId stageId, boolean removeDurableStorageFiles)
{
// This code is thread-safe because remove() on ConcurrentHashMap will remove and return the removed channel only for
// one thread. For the other threads it will return null, therefore we will call doneReading for a channel only once
@ -755,7 +765,7 @@ public class WorkerImpl implements Worker
// temp directories where intermediate results were stored, it won't be the case for the external storage.
// Therefore, the logic for cleaning the stage output in case of a worker/machine crash has to be external.
// We currently take care of this in the controller.
if (durableStageStorageEnabled) {
if (durableStageStorageEnabled && removeDurableStorageFiles) {
final String folderName = DurableStorageUtils.getTaskIdOutputsFolderName(
task.getControllerTaskId(),
stageId.getStageNumber(),

View File

@ -32,7 +32,7 @@ import java.util.Set;
*/
public interface WorkerManagerClient extends Closeable
{
String run(String controllerId, MSQWorkerTask task);
String run(String taskId, MSQWorkerTask task);
/**
* @param workerId the task ID

View File

@ -19,29 +19,33 @@
package org.apache.druid.msq.exec;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.java.util.common.Either;
import com.google.common.util.concurrent.SettableFuture;
import org.apache.druid.common.guava.FutureUtils;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.function.TriConsumer;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher;
import org.apache.druid.msq.indexing.error.MSQFault;
import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.controller.ControllerQueryKernel;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
/**
* Queues up fetching sketches from workers and progressively generates partitions boundaries.
@ -50,73 +54,26 @@ public class WorkerSketchFetcher implements AutoCloseable
{
private static final Logger log = new Logger(WorkerSketchFetcher.class);
private static final int DEFAULT_THREAD_COUNT = 4;
// If the combined size of worker sketches is more than this threshold, SEQUENTIAL merging mode is used.
static final long BYTES_THRESHOLD = 1_000_000_000L;
// If there are more workers than this threshold, SEQUENTIAL merging mode is used.
static final long WORKER_THRESHOLD = 100;
private final ClusterStatisticsMergeMode clusterStatisticsMergeMode;
private final int statisticsMaxRetainedBytes;
private final WorkerClient workerClient;
private final ExecutorService executorService;
private final MSQWorkerTaskLauncher workerTaskLauncher;
private final boolean retryEnabled;
private AtomicReference<Throwable> isError = new AtomicReference<>();
final ExecutorService executorService;
public WorkerSketchFetcher(
WorkerClient workerClient,
ClusterStatisticsMergeMode clusterStatisticsMergeMode,
int statisticsMaxRetainedBytes
MSQWorkerTaskLauncher workerTaskLauncher,
boolean retryEnabled
)
{
this.workerClient = workerClient;
this.clusterStatisticsMergeMode = clusterStatisticsMergeMode;
this.executorService = Execs.multiThreaded(DEFAULT_THREAD_COUNT, "SketchFetcherThreadPool-%d");
this.statisticsMaxRetainedBytes = statisticsMaxRetainedBytes;
}
/**
* Submits a request to fetch and generate partitions for the given worker statistics and returns a future for it. It
* decides based on the statistics if it should fetch sketches one by one or together.
*/
public CompletableFuture<Either<Long, ClusterByPartitions>> submitFetcherTask(
CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
List<String> workerTaskIds,
StageDefinition stageDefinition,
IntSet workersForStage
)
{
ClusterBy clusterBy = stageDefinition.getClusterBy();
switch (clusterStatisticsMergeMode) {
case SEQUENTIAL:
return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
case PARALLEL:
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
case AUTO:
if (clusterBy.getBucketByCount() == 0) {
log.info(
"Query[%s] stage[%d] for AUTO mode: chose PARALLEL mode to merge key statistics",
stageDefinition.getId().getQueryId(),
stageDefinition.getStageNumber()
);
// If there is no time clustering, there is no scope for sequential merge
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
} else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD
|| completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
log.info(
"Query[%s] stage[%d] for AUTO mode: chose SEQUENTIAL mode to merge key statistics",
stageDefinition.getId().getQueryId(),
stageDefinition.getStageNumber()
);
return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
}
log.info(
"Query[%s] stage[%d] for AUTO mode: chose PARALLEL mode to merge key statistics",
stageDefinition.getId().getQueryId(),
stageDefinition.getStageNumber()
);
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
default:
throw new IllegalStateException("No fetching strategy found for mode: " + clusterStatisticsMergeMode);
}
this.workerTaskLauncher = workerTaskLauncher;
this.retryEnabled = retryEnabled;
}
/**
@ -124,248 +81,209 @@ public class WorkerSketchFetcher implements AutoCloseable
* This is faster than fetching them timechunk by timechunk but the collector will be downsampled till it can fit
* on the controller, resulting in less accurate partition boundries.
*/
CompletableFuture<Either<Long, ClusterByPartitions>> inMemoryFullSketchMerging(
StageDefinition stageDefinition,
List<String> workerTaskIds,
IntSet workersForStage
public void inMemoryFullSketchMerging(
Consumer<Consumer<ControllerQueryKernel>> kernelActions,
StageId stageId,
Set<String> taskIds,
TriConsumer<ControllerQueryKernel, Integer, MSQFault> retryOperation
)
{
CompletableFuture<Either<Long, ClusterByPartitions>> partitionFuture = new CompletableFuture<>();
// Create a new key statistics collector to merge worker sketches into
final ClusterByStatisticsCollector mergedStatisticsCollector =
stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes);
final int workerCount = workersForStage.size();
// Guarded by synchronized mergedStatisticsCollector
final Set<Integer> finishedWorkers = new HashSet<>();
for (String taskId : taskIds) {
try {
int workerNumber = MSQTasks.workerFromTaskId(taskId);
executorService.submit(() -> {
fetchStatsFromWorker(
kernelActions,
() -> workerClient.fetchClusterByStatisticsSnapshot(
taskId,
stageId.getQueryId(),
stageId.getStageNumber()
),
taskId,
(kernel, snapshot) -> kernel.mergeClusterByStatisticsCollectorForAllTimeChunks(
stageId,
workerNumber,
snapshot
),
retryOperation
);
});
}
catch (RejectedExecutionException rejectedExecutionException) {
if (isError.get() == null) {
throw rejectedExecutionException;
} else {
// throw worker error exception
throw new ISE("Unable to fetch partitions %s", isError.get());
}
}
}
}
log.info(
"Fetching stats using %s for stage[%d] for workers[%s] ",
ClusterStatisticsMergeMode.PARALLEL,
stageDefinition.getStageNumber(),
workersForStage.stream().map(Object::toString).collect(Collectors.joining(","))
);
private void fetchStatsFromWorker(
Consumer<Consumer<ControllerQueryKernel>> kernelActions,
Supplier<ListenableFuture<ClusterByStatisticsSnapshot>> fetchStatsSupplier,
String taskId,
BiConsumer<ControllerQueryKernel, ClusterByStatisticsSnapshot> successKernelOperation,
TriConsumer<ControllerQueryKernel, Integer, MSQFault> retryOperation
)
{
if (isError.get() != null) {
executorService.shutdownNow();
return;
}
int worker = MSQTasks.workerFromTaskId(taskId);
try {
workerTaskLauncher.waitUntilWorkersReady(ImmutableSet.of(worker));
}
catch (InterruptedException interruptedException) {
isError.compareAndSet(null, interruptedException);
executorService.shutdownNow();
return;
}
// Submit a task for each worker to fetch statistics
workersForStage.forEach(workerNo -> {
executorService.submit(() -> {
ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture =
workerClient.fetchClusterByStatisticsSnapshot(
workerTaskIds.get(workerNo),
stageDefinition.getId().getQueryId(),
stageDefinition.getStageNumber()
);
// if task is not the latest task. It must have retried.
if (!workerTaskLauncher.isTaskLatest(taskId)) {
log.info(
"Task[%s] is no longer the latest task for worker[%d], hence ignoring fetching stats from this worker",
taskId,
worker
);
return;
}
ListenableFuture<ClusterByStatisticsSnapshot> fetchFuture = fetchStatsSupplier.get();
SettableFuture<Boolean> kernelActionFuture = SettableFuture.create();
Futures.addCallback(fetchFuture, new FutureCallback<ClusterByStatisticsSnapshot>()
{
@Override
public void onSuccess(@Nullable ClusterByStatisticsSnapshot result)
{
try {
ClusterByStatisticsSnapshot clusterByStatisticsSnapshot = snapshotFuture.get();
if (clusterByStatisticsSnapshot == null) {
throw new ISE("Worker %s returned null sketch, this should never happen", workerNo);
}
synchronized (mergedStatisticsCollector) {
mergedStatisticsCollector.addAll(clusterByStatisticsSnapshot);
finishedWorkers.add(workerNo);
if (finishedWorkers.size() == workerCount) {
log.debug("Query [%s] Received all statistics, generating partitions", stageDefinition.getId().getQueryId());
partitionFuture.complete(stageDefinition.generatePartitionsForShuffle(mergedStatisticsCollector));
kernelActions.accept((queryKernel) -> {
try {
successKernelOperation.accept(queryKernel, result);
// we do not want to have too many key collector sketches in the event queue as that cause memory issues
// blocking the executor service thread until the kernel operation is finished.
// so we would have utmost DEFAULT_THREAD_COUNT number of sketches in the queue.
kernelActionFuture.set(true);
}
}
catch (Exception e) {
failFutureAndShutDownExecutorService(e, taskId, kernelActionFuture);
}
});
}
catch (Exception e) {
synchronized (mergedStatisticsCollector) {
if (!partitionFuture.isDone()) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
}
}
failFutureAndShutDownExecutorService(e, taskId, kernelActionFuture);
}
});
}
@Override
public void onFailure(Throwable t)
{
if (retryEnabled) {
//add to retry queue
try {
kernelActions.accept((kernel) -> {
try {
retryOperation.accept(kernel, worker, new WorkerRpcFailedFault(taskId));
kernelActionFuture.set(false);
}
catch (Exception e) {
failFutureAndShutDownExecutorService(e, taskId, kernelActionFuture);
}
});
kernelActionFuture.set(false);
}
catch (Exception e) {
failFutureAndShutDownExecutorService(e, taskId, kernelActionFuture);
}
} else {
failFutureAndShutDownExecutorService(t, taskId, kernelActionFuture);
}
}
});
return partitionFuture;
FutureUtils.getUnchecked(kernelActionFuture, true);
}
private void failFutureAndShutDownExecutorService(
Throwable t,
String taskId,
SettableFuture<Boolean> kernelActionFuture
)
{
if (isError.compareAndSet(null, t)) {
log.error(t, "Failed while fetching stats from task[%s]", taskId);
}
executorService.shutdownNow();
kernelActionFuture.setException(t);
}
/**
* Fetches cluster statistics from all workers and generates partition boundaries from them one time chunk at a time.
* This takes longer due to the overhead of fetching sketches, however, this prevents any loss in accuracy from
* downsampling on the controller.
* down sampling on the controller.
*/
CompletableFuture<Either<Long, ClusterByPartitions>> sequentialTimeChunkMerging(
public void sequentialTimeChunkMerging(
Consumer<Consumer<ControllerQueryKernel>> kernelActions,
CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
StageDefinition stageDefinition,
List<String> workerTaskIds
StageId stageId,
Set<String> tasks,
TriConsumer<ControllerQueryKernel, Integer, MSQFault> retryOperation
)
{
SequentialFetchStage sequentialFetchStage = new SequentialFetchStage(
stageDefinition,
workerTaskIds,
completeKeyStatisticsInformation.getTimeSegmentVsWorkerMap().entrySet().iterator()
);
log.info(
"Fetching stats using %s for stage[%d] for tasks[%s]",
ClusterStatisticsMergeMode.SEQUENTIAL,
stageDefinition.getStageNumber(),
String.join("", workerTaskIds)
);
sequentialFetchStage.submitFetchingTasksForNextTimeChunk();
return sequentialFetchStage.getPartitionFuture();
}
private class SequentialFetchStage
{
private final StageDefinition stageDefinition;
private final List<String> workerTaskIds;
private final Iterator<Map.Entry<Long, Set<Integer>>> timeSegmentVsWorkerIdIterator;
private final CompletableFuture<Either<Long, ClusterByPartitions>> partitionFuture;
// Final sorted list of partition boundaries. This is appended to after statistics for each time chunk are gathered.
private final List<ClusterByPartition> finalPartitionBoundries;
public SequentialFetchStage(
StageDefinition stageDefinition,
List<String> workerTaskIds,
Iterator<Map.Entry<Long, Set<Integer>>> timeSegmentVsWorkerIdIterator
)
{
this.finalPartitionBoundries = new ArrayList<>();
this.stageDefinition = stageDefinition;
this.workerTaskIds = workerTaskIds;
this.timeSegmentVsWorkerIdIterator = timeSegmentVsWorkerIdIterator;
this.partitionFuture = new CompletableFuture<>();
if (!completeKeyStatisticsInformation.isComplete()) {
throw new ISE("All worker partial key information not received for stage[%d]", stageId.getStageNumber());
}
/**
* Submits the tasks to fetch key statistics for the time chunk pointed to by {@link #timeSegmentVsWorkerIdIterator}.
* Once the statistics have been gathered from all workers which have them, generates partitions and adds it to
* {@link #finalPartitionBoundries}, stiching the partitions between time chunks using
* {@link #abutAndAppendPartitionBoundries(List, List)} to make them continuous.
*
* The time chunks returned by {@link #timeSegmentVsWorkerIdIterator} should be in ascending order for the partitions
* to be generated correctly.
*
* If {@link #timeSegmentVsWorkerIdIterator} doesn't have any more values, assumes that partition boundaries have
* been successfully generated and completes {@link #partitionFuture} with the result.
*
* Completes the future with an error as soon as the number of partitions exceed max partition count for the stage
* definition.
*/
public void submitFetchingTasksForNextTimeChunk()
{
if (!timeSegmentVsWorkerIdIterator.hasNext()) {
partitionFuture.complete(Either.value(new ClusterByPartitions(finalPartitionBoundries)));
} else {
Map.Entry<Long, Set<Integer>> entry = timeSegmentVsWorkerIdIterator.next();
// Time chunk for which partition boundries are going to be generated for
Long timeChunk = entry.getKey();
Set<Integer> workerIdsWithTimeChunk = entry.getValue();
// Create a new key statistics collector to merge worker sketches into
ClusterByStatisticsCollector mergedStatisticsCollector =
stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes);
// Guarded by synchronized mergedStatisticsCollector
Set<Integer> finishedWorkers = new HashSet<>();
completeKeyStatisticsInformation.getTimeSegmentVsWorkerMap().forEach((timeChunk, wks) -> {
log.debug("Query [%s]. Submitting request for statistics for time chunk %s to %s workers",
stageDefinition.getId().getQueryId(),
timeChunk,
workerIdsWithTimeChunk.size());
// Submits a task for every worker which has a certain time chunk
for (int workerNo : workerIdsWithTimeChunk) {
for (String taskId : tasks) {
int workerNumber = MSQTasks.workerFromTaskId(taskId);
if (wks.contains(workerNumber)) {
executorService.submit(() -> {
ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture =
workerClient.fetchClusterByStatisticsSnapshotForTimeChunk(
workerTaskIds.get(workerNo),
stageDefinition.getId().getQueryId(),
stageDefinition.getStageNumber(),
fetchStatsFromWorker(
kernelActions,
() -> workerClient.fetchClusterByStatisticsSnapshotForTimeChunk(
taskId,
stageId.getQueryId(),
stageId.getStageNumber(),
timeChunk
);
),
taskId,
(kernel, snapshot) -> kernel.mergeClusterByStatisticsCollectorForTimeChunk(
stageId,
workerNumber,
timeChunk,
snapshot
),
retryOperation
);
try {
ClusterByStatisticsSnapshot snapshotForTimeChunk = snapshotFuture.get();
if (snapshotForTimeChunk == null) {
throw new ISE("Worker %s returned null sketch for %s, this should never happen", workerNo, timeChunk);
}
synchronized (mergedStatisticsCollector) {
mergedStatisticsCollector.addAll(snapshotForTimeChunk);
finishedWorkers.add(workerNo);
if (finishedWorkers.size() == workerIdsWithTimeChunk.size()) {
Either<Long, ClusterByPartitions> longClusterByPartitionsEither =
stageDefinition.generatePartitionsForShuffle(mergedStatisticsCollector);
log.debug("Query [%s]. Received all statistics for time chunk %s, generating partitions",
stageDefinition.getId().getQueryId(),
timeChunk);
long totalPartitionCount = finalPartitionBoundries.size() + getPartitionCountFromEither(longClusterByPartitionsEither);
if (totalPartitionCount > stageDefinition.getMaxPartitionCount()) {
// Fail fast if more partitions than the maximum have been reached.
partitionFuture.complete(Either.error(totalPartitionCount));
mergedStatisticsCollector.clear();
} else {
List<ClusterByPartition> timeSketchPartitions = longClusterByPartitionsEither.valueOrThrow().ranges();
abutAndAppendPartitionBoundries(finalPartitionBoundries, timeSketchPartitions);
log.debug("Query [%s]. Finished generating partitions for time chunk %s, total count so far %s",
stageDefinition.getId().getQueryId(),
timeChunk,
finalPartitionBoundries.size());
submitFetchingTasksForNextTimeChunk();
}
}
}
}
catch (Exception e) {
synchronized (mergedStatisticsCollector) {
if (!partitionFuture.isDone()) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
}
}
}
});
}
}
}
/**
* Takes a list of sorted {@link ClusterByPartitions} {@param timeSketchPartitions} and adds it to a sorted list
* {@param finalPartitionBoundries}. If {@param finalPartitionBoundries} is not empty, the end time of the last
* partition of {@param finalPartitionBoundries} is changed to abut with the starting time of the first partition
* of {@param timeSketchPartitions}.
*
* This is used to make the partitions generated continuous.
*/
private void abutAndAppendPartitionBoundries(
List<ClusterByPartition> finalPartitionBoundries,
List<ClusterByPartition> timeSketchPartitions
)
{
if (!finalPartitionBoundries.isEmpty()) {
// Stitch up the end time of the last partition with the start time of the first partition.
ClusterByPartition clusterByPartition = finalPartitionBoundries.remove(finalPartitionBoundries.size() - 1);
finalPartitionBoundries.add(new ClusterByPartition(clusterByPartition.getStart(), timeSketchPartitions.get(0).getStart()));
}
finalPartitionBoundries.addAll(timeSketchPartitions);
}
public CompletableFuture<Either<Long, ClusterByPartitions>> getPartitionFuture()
{
return partitionFuture;
}
});
}
/**
* Gets the partition size from an {@link Either}. If it is an error, the long denotes the number of partitions
* (in the case of creating too many partitions), otherwise checks the size of the list.
* Returns {@link Throwable} if error, else null
*/
private static long getPartitionCountFromEither(Either<Long, ClusterByPartitions> either)
public Throwable getError()
{
if (either.isError()) {
return either.error();
} else {
return either.valueOrThrow().size();
}
return isError.get();
}
@Override
public void close()
{

View File

@ -55,6 +55,8 @@ import org.apache.druid.msq.indexing.error.NotEnoughMemoryFault;
import org.apache.druid.msq.indexing.error.QueryNotSupportedFault;
import org.apache.druid.msq.indexing.error.RowTooLargeFault;
import org.apache.druid.msq.indexing.error.TaskStartTimeoutFault;
import org.apache.druid.msq.indexing.error.TooManyAttemptsForJob;
import org.apache.druid.msq.indexing.error.TooManyAttemptsForWorker;
import org.apache.druid.msq.indexing.error.TooManyBucketsFault;
import org.apache.druid.msq.indexing.error.TooManyClusteredByColumnsFault;
import org.apache.druid.msq.indexing.error.TooManyColumnsFault;
@ -120,8 +122,10 @@ public class MSQIndexingModule implements DruidModule
TooManyPartitionsFault.class,
TooManyWarningsFault.class,
TooManyWorkersFault.class,
TooManyAttemptsForJob.class,
UnknownFault.class,
WorkerFailedFault.class,
TooManyAttemptsForWorker.class,
WorkerRpcFailedFault.class
);

View File

@ -128,16 +128,17 @@ public class ControllerChatHandler implements ChatHandler
* See {@link ControllerClient#postCounters} for the client-side code that calls this API.
*/
@POST
@Path("/counters")
@Path("/counters/{taskId}")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public Response httpPostCounters(
@PathParam("taskId") final String taskId,
final CounterSnapshotsTree snapshotsTree,
@Context final HttpServletRequest req
)
{
ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper());
controller.updateCounters(snapshotsTree);
controller.updateCounters(taskId, snapshotsTree);
return Response.status(Response.Status.OK).build();
}

View File

@ -80,10 +80,11 @@ public class IndexerControllerClient implements ControllerClient
}
@Override
public void postCounters(CounterSnapshotsTree snapshotsTree) throws IOException
public void postCounters(String workerId, CounterSnapshotsTree snapshotsTree) throws IOException
{
final String path = StringUtils.format("/counters/%s", StringUtils.urlEncode(workerId));
doRequest(
new RequestBuilder(HttpMethod.POST, "/counters")
new RequestBuilder(HttpMethod.POST, path)
.jsonContent(jsonMapper, snapshotsTree),
IgnoreHttpResponseHandler.INSTANCE
);

View File

@ -43,10 +43,10 @@ public class IndexerWorkerManagerClient implements WorkerManagerClient
}
@Override
public String run(String controllerId, MSQWorkerTask task)
public String run(String taskId, MSQWorkerTask task)
{
FutureUtils.getUnchecked(overlordClient.runTask(controllerId, task), true);
return controllerId;
FutureUtils.getUnchecked(overlordClient.runTask(taskId, task), true);
return taskId;
}
@Override

View File

@ -37,6 +37,7 @@ import org.apache.druid.msq.exec.WorkerContext;
import org.apache.druid.msq.exec.WorkerImpl;
import java.util.Map;
import java.util.Objects;
@JsonTypeName(MSQWorkerTask.TYPE)
public class MSQWorkerTask extends AbstractTask
@ -45,8 +46,10 @@ public class MSQWorkerTask extends AbstractTask
private final String controllerTaskId;
private final int workerNumber;
private final int retryCount;
// Using an Injector directly because tasks do not have a way to provide their own Guice modules.
// Not part of equals and hashcode implementation
@JacksonInject
private Injector injector;
@ -58,11 +61,12 @@ public class MSQWorkerTask extends AbstractTask
@JsonProperty("controllerTaskId") final String controllerTaskId,
@JsonProperty("dataSource") final String dataSource,
@JsonProperty("workerNumber") final int workerNumber,
@JsonProperty("context") final Map<String, Object> context
@JsonProperty("context") final Map<String, Object> context,
@JsonProperty(value = "retry", defaultValue = "0") final int retryCount
)
{
super(
MSQTasks.workerTaskId(controllerTaskId, workerNumber),
MSQTasks.workerTaskId(controllerTaskId, workerNumber, retryCount),
controllerTaskId,
null,
dataSource,
@ -71,6 +75,7 @@ public class MSQWorkerTask extends AbstractTask
this.controllerTaskId = controllerTaskId;
this.workerNumber = workerNumber;
this.retryCount = retryCount;
}
@JsonProperty
@ -85,6 +90,21 @@ public class MSQWorkerTask extends AbstractTask
return workerNumber;
}
@JsonProperty("retry")
public int getRetryCount()
{
return retryCount;
}
/**
* Creates a new retry {@link MSQWorkerTask} with the same context as the current task, but with the retry count
* incremented by 1
*/
public MSQWorkerTask getRetryTask()
{
return new MSQWorkerTask(controllerTaskId, getDataSource(), workerNumber, getContext(), retryCount + 1);
}
@Override
public String getType()
{
@ -119,4 +139,29 @@ public class MSQWorkerTask extends AbstractTask
{
return getContextValue(Tasks.PRIORITY_KEY, Tasks.DEFAULT_BATCH_INDEX_TASK_PRIORITY);
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
MSQWorkerTask that = (MSQWorkerTask) o;
return workerNumber == that.workerNumber
&& retryCount == that.retryCount
&& Objects.equals(controllerTaskId, that.controllerTaskId)
&& Objects.equals(worker, that.worker);
}
@Override
public int hashCode()
{
return Objects.hash(super.hashCode(), controllerTaskId, workerNumber, retryCount, worker);
}
}

View File

@ -35,10 +35,14 @@ import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.exec.ControllerContext;
import org.apache.druid.msq.exec.ControllerImpl;
import org.apache.druid.msq.exec.Limits;
import org.apache.druid.msq.exec.MSQTasks;
import org.apache.druid.msq.exec.WorkerManagerClient;
import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.indexing.error.MSQWarnings;
import org.apache.druid.msq.indexing.error.TaskStartTimeoutFault;
import org.apache.druid.msq.indexing.error.TooManyAttemptsForJob;
import org.apache.druid.msq.indexing.error.TooManyAttemptsForWorker;
import org.apache.druid.msq.indexing.error.UnknownFault;
import org.apache.druid.msq.indexing.error.WorkerFailedFault;
import org.apache.druid.msq.util.MultiStageQueryContext;
@ -48,6 +52,7 @@ import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@ -69,6 +74,7 @@ public class MSQWorkerTaskLauncher
private static final long LOW_FREQUENCY_CHECK_MILLIS = 2000;
private static final long SWITCH_TO_LOW_FREQUENCY_CHECK_AFTER_MILLIS = 10000;
private static final long SHUTDOWN_TIMEOUT_MS = Duration.ofMinutes(1).toMillis();
private int currentRelaunchCount = 0;
// States for "state" variable.
private enum State
@ -111,15 +117,27 @@ public class MSQWorkerTaskLauncher
// Mutable state accessible only to the main loop. LinkedHashMap since order of key set matters. Tasks are added
// here once they are submitted for running, but before they are fully started up.
// taskId -> taskTracker
private final Map<String, TaskTracker> taskTrackers = new LinkedHashMap<>();
// Set of tasks which are issued a cancel request by the controller.
private final Set<String> canceledWorkerTasks = ConcurrentHashMap.newKeySet();
// tasks to clean up due to retries
private final Set<String> tasksToCleanup = ConcurrentHashMap.newKeySet();
// workers to relaunch
private final Set<Integer> workersToRelaunch = ConcurrentHashMap.newKeySet();
private final ConcurrentHashMap<Integer, List<String>> workerToTaskIds = new ConcurrentHashMap<>();
private final RetryTask retryTask;
public MSQWorkerTaskLauncher(
final String controllerTaskId,
final String dataSource,
final ControllerContext context,
final RetryTask retryTask,
final boolean durableStageStorageEnabled,
@Nullable final Long maxParseExceptions,
final long maxTaskStartDelayMillis
@ -131,6 +149,8 @@ public class MSQWorkerTaskLauncher
this.exec = Execs.singleThreaded(
"multi-stage-query-task-launcher[" + StringUtils.encodeForFormat(controllerTaskId) + "]-%s"
);
this.retryTask = retryTask;
this.durableStageStorageEnabled = durableStageStorageEnabled;
this.maxParseExceptions = maxParseExceptions;
this.maxTaskStartDelayMillis = maxTaskStartDelayMillis;
@ -197,7 +217,7 @@ public class MSQWorkerTaskLauncher
/**
* Get the list of currently-active tasks.
*/
public List<String> getTaskList()
public List<String> getActiveTasks()
{
synchronized (taskIds) {
return ImmutableList.copyOf(taskIds);
@ -227,6 +247,36 @@ public class MSQWorkerTaskLauncher
}
}
/**
* Queues worker for relaunch. A noop if the worker is already in the queue.
*
* @param workerNumber
*/
public void submitForRelaunch(int workerNumber)
{
workersToRelaunch.add(workerNumber);
}
/**
* Blocks the call untill the worker tasks are ready to be contacted for work.
*
* @param workerSet
* @throws InterruptedException
*/
public void waitUntilWorkersReady(Set<Integer> workerSet) throws InterruptedException
{
synchronized (taskIds) {
while (!fullyStartedTasks.containsAll(workerSet)) {
if (stopFuture.isDone() || stopFuture.isCancelled()) {
FutureUtils.getUnchecked(stopFuture, false);
throw new ISE("Stopped");
}
taskIds.wait();
}
}
}
/**
* Checks if the controller has canceled the input taskId. This method is used in {@link ControllerImpl}
* to figure out if the worker taskId is canceled by the controller. If yes, the errors from that worker taskId
@ -239,6 +289,15 @@ public class MSQWorkerTaskLauncher
return canceledWorkerTasks.contains(taskId);
}
public boolean isTaskLatest(String taskId)
{
int worker = MSQTasks.workerFromTaskId(taskId);
synchronized (taskIds) {
return taskId.equals(taskIds.get(worker));
}
}
private void mainLoop()
{
try {
@ -251,6 +310,8 @@ public class MSQWorkerTaskLauncher
runNewTasks();
updateTaskTrackersAndTaskIds();
checkForErroneousTasks();
relaunchTasks();
cleanFailedTasksWhichAreRelaunched();
}
catch (Throwable e) {
state.set(State.STOPPED);
@ -318,7 +379,7 @@ public class MSQWorkerTaskLauncher
final Map<String, Object> taskContext = new HashMap<>();
if (durableStageStorageEnabled) {
taskContext.put(MultiStageQueryContext.CTX_ENABLE_DURABLE_SHUFFLE_STORAGE, true);
taskContext.put(MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE, true);
}
if (maxParseExceptions != null) {
@ -339,10 +400,19 @@ public class MSQWorkerTaskLauncher
controllerTaskId,
dataSource,
i,
taskContext
taskContext,
0
);
taskTrackers.put(task.getId(), new TaskTracker(i));
taskTrackers.put(task.getId(), new TaskTracker(i, task));
workerToTaskIds.compute(i, (workerId, taskIds) -> {
if (taskIds == null) {
taskIds = new ArrayList<>();
}
taskIds.add(task.getId());
return taskIds;
});
context.workerManager().run(task.getId(), task);
synchronized (taskIds) {
@ -405,33 +475,137 @@ public class MSQWorkerTaskLauncher
/**
* Used by the main loop to generate exceptions if any tasks have failed, have taken too long to start up, or
* have gone inexplicably missing.
*
* <p>
* Throws an exception if some task is erroneous.
*/
private void checkForErroneousTasks()
{
final int numTasks = taskTrackers.size();
for (final Map.Entry<String, TaskTracker> taskEntry : taskTrackers.entrySet()) {
Iterator<Map.Entry<String, TaskTracker>> taskTrackerIterator = taskTrackers.entrySet().iterator();
while (taskTrackerIterator.hasNext()) {
final Map.Entry<String, TaskTracker> taskEntry = taskTrackerIterator.next();
final String taskId = taskEntry.getKey();
final TaskTracker tracker = taskEntry.getValue();
if (tracker.isRetrying()) {
continue;
}
if (tracker.status == null) {
throw new MSQException(UnknownFault.forMessage(StringUtils.format("Task [%s] status missing", taskId)));
}
removeWorkerFromFullyStartedWorkers(tracker);
final String errorMessage = StringUtils.format("Task [%s] status missing", taskId);
log.info(errorMessage + ". Trying to relaunch the worker");
tracker.enableRetrying();
retryTask.retry(
tracker.msqWorkerTask,
UnknownFault.forMessage(errorMessage)
);
if (tracker.didRunTimeOut(maxTaskStartDelayMillis) && !canceledWorkerTasks.contains(taskId)) {
} else if (tracker.didRunTimeOut(maxTaskStartDelayMillis) && !canceledWorkerTasks.contains(taskId)) {
removeWorkerFromFullyStartedWorkers(tracker);
throw new MSQException(new TaskStartTimeoutFault(numTasks + 1));
} else if (tracker.didFail() && !canceledWorkerTasks.contains(taskId)) {
removeWorkerFromFullyStartedWorkers(tracker);
log.info("Task[%s] failed because %s. Trying to relaunch the worker", taskId, tracker.status.getErrorMsg());
tracker.enableRetrying();
retryTask.retry(tracker.msqWorkerTask, new WorkerFailedFault(taskId, tracker.status.getErrorMsg()));
}
}
}
if (tracker.didFail() && !canceledWorkerTasks.contains(taskId)) {
throw new MSQException(new WorkerFailedFault(taskId, tracker.status.getErrorMsg()));
}
private void removeWorkerFromFullyStartedWorkers(TaskTracker tracker)
{
synchronized (taskIds) {
fullyStartedTasks.remove(tracker.msqWorkerTask.getWorkerNumber());
}
}
private void relaunchTasks()
{
Iterator<Integer> iterator = workersToRelaunch.iterator();
while (iterator.hasNext()) {
int worker = iterator.next();
workerToTaskIds.compute(worker, (workerId, taskHistory) -> {
if (taskHistory == null || taskHistory.isEmpty()) {
throw new ISE("TaskHistory cannot by null for worker %d", workerId);
}
String latestTaskId = taskHistory.get(taskHistory.size() - 1);
TaskTracker tracker = taskTrackers.get(latestTaskId);
if (tracker == null) {
throw new ISE("Did not find taskTracker for latest taskId[%s]", latestTaskId);
}
// if task is not failed donot retry
if (!tracker.isComplete()) {
return taskHistory;
}
MSQWorkerTask toRelaunch = tracker.msqWorkerTask;
MSQWorkerTask relaunchedTask = toRelaunch.getRetryTask();
// check relaunch limits
checkRelaunchLimitsOrThrow(tracker, toRelaunch);
// clean up trackers and tasks
tasksToCleanup.add(latestTaskId);
taskTrackers.remove(latestTaskId);
log.info(
"Relaunching worker[%d] with new task id[%s] with worker relaunch count[%d] and job relaunch count[%d]",
relaunchedTask.getWorkerNumber(),
relaunchedTask.getId(),
toRelaunch.getRetryCount(),
currentRelaunchCount
);
currentRelaunchCount += 1;
taskTrackers.put(relaunchedTask.getId(), new TaskTracker(relaunchedTask.getWorkerNumber(), relaunchedTask));
synchronized (taskIds) {
fullyStartedTasks.remove(relaunchedTask.getWorkerNumber());
taskIds.notifyAll();
}
context.workerManager().run(relaunchedTask.getId(), relaunchedTask);
taskHistory.add(relaunchedTask.getId());
synchronized (taskIds) {
// replace taskId with the retry taskID for the same worker number
taskIds.set(toRelaunch.getWorkerNumber(), relaunchedTask.getId());
taskIds.notifyAll();
}
return taskHistory;
});
iterator.remove();
}
}
private void checkRelaunchLimitsOrThrow(TaskTracker tracker, MSQWorkerTask relaunchTask)
{
if (relaunchTask.getRetryCount() > Limits.PER_WORKER_RELAUNCH_LIMIT) {
throw new MSQException(new TooManyAttemptsForWorker(
Limits.PER_WORKER_RELAUNCH_LIMIT,
relaunchTask.getId(),
relaunchTask.getWorkerNumber(),
tracker.status.getErrorMsg()
));
}
if (currentRelaunchCount > Limits.TOTAL_RELAUNCH_LIMIT) {
throw new MSQException(new TooManyAttemptsForJob(
Limits.TOTAL_RELAUNCH_LIMIT,
currentRelaunchCount,
relaunchTask.getId(),
tracker.status.getErrorMsg()
));
}
}
private void shutDownTasks()
{
cleanFailedTasksWhichAreRelaunched();
for (final Map.Entry<String, TaskTracker> taskEntry : taskTrackers.entrySet()) {
final String taskId = taskEntry.getKey();
final TaskTracker tracker = taskEntry.getValue();
@ -441,6 +615,32 @@ public class MSQWorkerTaskLauncher
context.workerManager().cancel(taskId);
}
}
}
/**
* Cleans the task indentified in {@link MSQWorkerTaskLauncher#relaunchTasks()} for relaunch. Asks the overlord to cancel the task.
*/
private void cleanFailedTasksWhichAreRelaunched()
{
Iterator<String> tasksToCancel = tasksToCleanup.iterator();
while (tasksToCancel.hasNext()) {
String taskId = tasksToCancel.next();
try {
if (canceledWorkerTasks.add(taskId)) {
try {
context.workerManager().cancel(taskId);
}
catch (Exception ignore) {
//ignoring cancellation exception
}
}
}
finally {
tasksToCancel.remove();
}
}
}
/**
@ -489,12 +689,16 @@ public class MSQWorkerTaskLauncher
{
private final int workerNumber;
private final long startTimeMs = System.currentTimeMillis();
private final MSQWorkerTask msqWorkerTask;
private TaskStatus status;
private TaskLocation initialLocation;
public TaskTracker(int workerNumber)
private boolean isRetrying = false;
public TaskTracker(int workerNumber, MSQWorkerTask msqWorkerTask)
{
this.workerNumber = workerNumber;
this.msqWorkerTask = msqWorkerTask;
}
public boolean unknownLocation()
@ -518,5 +722,23 @@ public class MSQWorkerTaskLauncher
&& unknownLocation()
&& System.currentTimeMillis() - startTimeMs > maxTaskStartDelayMillis;
}
/**
* Enables retrying for the task
*/
public void enableRetrying()
{
isRetrying = true;
}
/**
* Checks is the task is retrying,
*
* @return
*/
public boolean isRetrying()
{
return isRetrying;
}
}
}

View File

@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.indexing;
import org.apache.druid.msq.indexing.error.MSQFault;
public interface RetryTask
{
/**
* Retry task when {@link MSQFault} is encountered.
*
* @param workerTask
* @param msqFault
*/
void retry(MSQWorkerTask workerTask, MSQFault msqFault);
}

View File

@ -44,6 +44,12 @@ public abstract class BaseMSQFault implements MSQFault
BaseMSQFault(final String errorCode, @Nullable final String errorMessage)
{
this.errorCode = Preconditions.checkNotNull(errorCode, "errorCode");
Preconditions.checkArgument(
!errorCode.contains(MSQFaultUtils.ERROR_CODE_DELIMITER),
"Error code[%s] contains restricted characters[%s]",
errorCode,
MSQFaultUtils.ERROR_CODE_DELIMITER
);
this.errorMessage = errorMessage;
}
@ -99,7 +105,7 @@ public abstract class BaseMSQFault implements MSQFault
@Override
public String toString()
{
return getCodeWithMessage();
return MSQFaultUtils.generateMessageWithErrorCode(this);
}
private static String format(

View File

@ -26,7 +26,7 @@ import com.fasterxml.jackson.annotation.JsonTypeName;
public class CanceledFault extends BaseMSQFault
{
public static final CanceledFault INSTANCE = new CanceledFault();
static final String CODE = "Canceled";
public static final String CODE = "Canceled";
CanceledFault()
{

View File

@ -44,7 +44,7 @@ public class DurableStorageConfigurationFault extends BaseMSQFault
+ "Check the documentation on how to enable durable storage mode. "
+ "If you want to still query without durable storage mode, set %s to false in the query context. Got error %s",
MSQDurableStorageModule.MSQ_INTERMEDIATE_STORAGE_ENABLED,
MultiStageQueryContext.CTX_ENABLE_DURABLE_SHUFFLE_STORAGE,
MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE,
errorMessage
);
this.errorMessage = errorMessage;

View File

@ -19,8 +19,6 @@
package org.apache.druid.msq.indexing.error;
import com.google.common.base.Preconditions;
import javax.annotation.Nullable;
/**
@ -35,8 +33,8 @@ public class MSQException extends RuntimeException
final MSQFault fault
)
{
super(fault.getCodeWithMessage(), cause);
this.fault = Preconditions.checkNotNull(fault, "fault");
super(MSQFaultUtils.generateMessageWithErrorCode(fault), cause);
this.fault = fault;
}
public MSQException(final MSQFault fault)

View File

@ -36,14 +36,4 @@ public interface MSQFault
@Nullable
String getErrorMessage();
default String getCodeWithMessage()
{
final String message = getErrorMessage();
if (message != null && !message.isEmpty()) {
return getErrorCode() + ": " + message;
} else {
return getErrorCode();
}
}
}

View File

@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.indexing.error;
public class MSQFaultUtils
{
public static final String ERROR_CODE_DELIMITER = ": ";
/**
* Generate string message with error code delimited by {@link MSQFaultUtils#ERROR_CODE_DELIMITER}
*/
public static String generateMessageWithErrorCode(MSQFault msqFault)
{
final String message = msqFault.getErrorMessage();
if (message != null && !message.isEmpty()) {
return msqFault.getErrorCode() + ERROR_CODE_DELIMITER + message;
} else {
return msqFault.getErrorCode();
}
}
/**
* Gets the error code from the message. If the message is empty or null, {@link UnknownFault#CODE} is returned. This method
* does not gurantee that the error code we get out of the message is a valid error code.
*/
public static String getErrorCodeFromMessage(String message)
{
if (message == null || message.isEmpty() || !message.contains(ERROR_CODE_DELIMITER)) {
return UnknownFault.CODE;
}
return message.split(ERROR_CODE_DELIMITER, 2)[0];
}
}

View File

@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.indexing.error;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import java.util.Objects;
@JsonTypeName(TooManyAttemptsForJob.CODE)
public class TooManyAttemptsForJob extends BaseMSQFault
{
static final String CODE = "TooManyAttemptsForJob";
private final int maxRelaunchCount;
private final String taskId;
private final int currentRelaunchCount;
private final String rootErrorMessage;
@JsonCreator
public TooManyAttemptsForJob(
@JsonProperty("maxRelaunchCount") int maxRelaunchCount,
@JsonProperty("currentRelaunchCount") int currentRelaunchCount,
@JsonProperty("taskId") String taskId,
@JsonProperty("rootErrorMessage") String rootErrorMessage
)
{
super(
CODE,
"Total relaunch count across all workers %d exceeded max relaunch limit %d . Latest task[%s] failure reason: %s",
currentRelaunchCount,
maxRelaunchCount,
taskId,
rootErrorMessage
);
this.maxRelaunchCount = maxRelaunchCount;
this.currentRelaunchCount = currentRelaunchCount;
this.taskId = taskId;
this.rootErrorMessage = rootErrorMessage;
}
@JsonProperty
public int getMaxRelaunchCount()
{
return maxRelaunchCount;
}
@JsonProperty
public String getTaskId()
{
return taskId;
}
@JsonProperty
public int getCurrentRelaunchCount()
{
return currentRelaunchCount;
}
@JsonProperty
public String getRootErrorMessage()
{
return rootErrorMessage;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
TooManyAttemptsForJob that = (TooManyAttemptsForJob) o;
return maxRelaunchCount == that.maxRelaunchCount
&& currentRelaunchCount == that.currentRelaunchCount
&& Objects.equals(
taskId,
that.taskId
)
&& Objects.equals(rootErrorMessage, that.rootErrorMessage);
}
@Override
public int hashCode()
{
return Objects.hash(super.hashCode(), maxRelaunchCount, taskId, currentRelaunchCount, rootErrorMessage);
}
}

View File

@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.indexing.error;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import java.util.Objects;
@JsonTypeName(TooManyAttemptsForWorker.CODE)
public class TooManyAttemptsForWorker extends BaseMSQFault
{
static final String CODE = "TooManyAttemptsForWorker";
private final int maxPerWorkerRelaunchCount;
private final String taskId;
private final int workerNumber;
private final String rootErrorMessage;
@JsonCreator
public TooManyAttemptsForWorker(
@JsonProperty("maxPerWorkerRelaunchCount") int maxPerWorkerRelaunchCount,
@JsonProperty("taskId") String taskId,
@JsonProperty("workerNumber") int workerNumber,
@JsonProperty("rootErrorMessage") String rootErrorMessage
)
{
super(
CODE,
"Worker[%d] exceeded max relaunch count of %d for task[%s]. Latest failure reason: %s.",
workerNumber,
maxPerWorkerRelaunchCount,
taskId,
rootErrorMessage
);
this.maxPerWorkerRelaunchCount = maxPerWorkerRelaunchCount;
this.taskId = taskId;
this.workerNumber = workerNumber;
this.rootErrorMessage = rootErrorMessage;
}
@JsonProperty
public int getMaxPerWorkerRelaunchCount()
{
return maxPerWorkerRelaunchCount;
}
@JsonProperty
public int getWorkerNumber()
{
return workerNumber;
}
@JsonProperty
public String getTaskId()
{
return taskId;
}
@JsonProperty
public String getRootErrorMessage()
{
return rootErrorMessage;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
TooManyAttemptsForWorker that = (TooManyAttemptsForWorker) o;
return maxPerWorkerRelaunchCount == that.maxPerWorkerRelaunchCount
&& workerNumber == that.workerNumber
&& Objects.equals(taskId, that.taskId)
&& Objects.equals(rootErrorMessage, that.rootErrorMessage);
}
@Override
public int hashCode()
{
return Objects.hash(super.hashCode(), maxPerWorkerRelaunchCount, taskId, workerNumber, rootErrorMessage);
}
}

View File

@ -28,7 +28,7 @@ import java.util.Objects;
@JsonTypeName(WorkerRpcFailedFault.CODE)
public class WorkerRpcFailedFault extends BaseMSQFault
{
static final String CODE = "WorkerRpcFailed";
public static final String CODE = "WorkerRpcFailed";
private final String workerTaskId;

View File

@ -48,24 +48,24 @@ import java.util.function.Supplier;
/**
* Definition of a stage in a multi-stage {@link QueryDefinition}.
*
* <p>
* Each stage has a list of {@link InputSpec} describing its inputs. The position of each spec within the list is
* its "input number". Some inputs are broadcast to all workers (see {@link #getBroadcastInputNumbers()}). Other,
* non-broadcast inputs are split up across workers.
*
* <p>
* The number of workers in a stage is at most {@link #getMaxWorkerCount()}. It may be less, depending on the
* {@link WorkerAssignmentStrategy} in play and depending on the number of distinct inputs available. (For example:
* if there is only one input file, then there can be only one worker.)
*
* <p>
* Each stage has a {@link FrameProcessorFactory} describing the work it does. Output frames written by these
* processors have the signature given by {@link #getSignature()}.
*
* <p>
* Each stage has a {@link ShuffleSpec} describing the shuffle that occurs as part of the stage. The shuffle spec is
* optional: if none is provided, then the {@link FrameProcessorFactory} directly writes to output partitions. If a
* shuffle spec is provided, then the {@link FrameProcessorFactory} is expected to sort each output frame individually
* according to {@link ShuffleSpec#getClusterBy()}. The execution system handles the rest, including sorting data across
* frames and producing the appropriate output partitions.
*
* <p>
* The rarely-used parameter {@link #getShuffleCheckHasMultipleValues()} controls whether the execution system
* checks, while shuffling, if the key used for shuffling has any multi-value fields. When this is true, the method
* {@link ClusterByStatisticsCollector#hasMultipleValues} is enabled on collectors
@ -259,6 +259,19 @@ public class StageDefinition
return id.getStageNumber();
}
/**
* Returns true, if the shuffling stage requires key statistics from the workers.
* <br></br>
* Returns false, if the stage does not shuffle.
* <br></br>
* <br></br>
* It's possible we're shuffling using partition boundaries that are known ahead of time
* For eg: we know there's exactly one partition in query shapes like `select with limit`.
* <br></br>
* In such cases, we return a false.
*
* @return
*/
public boolean mustGatherResultKeyStatistics()
{
return shuffleSpec != null && shuffleSpec.needsStatistics();
@ -269,11 +282,11 @@ public class StageDefinition
)
{
if (shuffleSpec == null) {
throw new ISE("No shuffle");
throw new ISE("No shuffle for stage[%d]", getStageNumber());
} else if (mustGatherResultKeyStatistics() && collector == null) {
throw new ISE("Statistics required, but not gathered");
throw new ISE("Statistics required, but not gathered for stage[%d]", getStageNumber());
} else if (!mustGatherResultKeyStatistics() && collector != null) {
throw new ISE("Statistics gathered, but not required");
throw new ISE("Statistics gathered, but not required for stage[%d]", getStageNumber());
} else {
return shuffleSpec.generatePartitions(collector, MAX_PARTITIONS);
}

View File

@ -23,15 +23,24 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import it.unimi.dsi.fastutil.ints.Int2IntAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.exec.QueryValidator;
import org.apache.druid.msq.indexing.error.CanceledFault;
import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.indexing.error.MSQFault;
import org.apache.druid.msq.indexing.error.MSQFaultUtils;
import org.apache.druid.msq.indexing.error.UnknownFault;
import org.apache.druid.msq.indexing.error.WorkerFailedFault;
import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault;
import org.apache.druid.msq.input.InputSpecSlicer;
import org.apache.druid.msq.input.InputSpecSlicerFactory;
import org.apache.druid.msq.input.stage.ReadablePartitions;
@ -41,10 +50,12 @@ import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
@ -56,7 +67,7 @@ import java.util.stream.Collectors;
/**
* Kernel for the controller of a multi-stage query.
*
* <p>
* Instances of this class are state machines for query execution. Kernels do not do any RPC or deal with any data.
* This separation of decision-making from the "real world" allows the decision-making to live in one,
* easy-to-follow place.
@ -65,6 +76,7 @@ import java.util.stream.Collectors;
*/
public class ControllerQueryKernel
{
private static final Logger log = new Logger(ControllerQueryKernel.class);
private final QueryDefinition queryDef;
/**
@ -107,9 +119,32 @@ public class ControllerQueryKernel
*/
private final Set<StageId> effectivelyFinishedStages = new HashSet<>();
public ControllerQueryKernel(final QueryDefinition queryDef)
/**
* Map<StageId, Map <WorkerNumber, WorkOrder>>
* Stores the work order per worker per stage so that we can retrieve that in case of worker retry
*/
private final Map<StageId, Int2ObjectMap<WorkOrder>> stageWorkOrders;
/**
* {@link MSQFault#getErrorCode()} which are retried.
*/
private static final Set<String> RETRIABLE_ERROR_CODES = ImmutableSet.of(
CanceledFault.CODE,
UnknownFault.CODE,
WorkerRpcFailedFault.CODE
);
private final int maxRetainedPartitionSketchBytes;
private final boolean faultToleranceEnabled;
public ControllerQueryKernel(
final QueryDefinition queryDef,
int maxRetainedPartitionSketchBytes,
boolean faultToleranceEnabled
)
{
this.queryDef = queryDef;
this.maxRetainedPartitionSketchBytes = maxRetainedPartitionSketchBytes;
this.faultToleranceEnabled = faultToleranceEnabled;
this.inflowMap = ImmutableMap.copyOf(computeStageInflowMap(queryDef));
this.outflowMap = ImmutableMap.copyOf(computeStageOutflowMap(queryDef));
@ -117,6 +152,8 @@ public class ControllerQueryKernel
this.pendingInflowMap = computeStageInflowMap(queryDef);
this.pendingOutflowMap = computeStageOutflowMap(queryDef);
stageWorkOrders = new HashMap<>();
initializeReadyToRunStages();
}
@ -208,7 +245,7 @@ public class ControllerQueryKernel
}
/**
* Returns true if all the stages comprising the query definition have been sucessful in producing their results
* Returns true if all the stages comprising the query definition have been successful in producing their results
*/
public boolean isSuccess()
{
@ -226,7 +263,7 @@ public class ControllerQueryKernel
@Nullable final Int2ObjectMap<Object> extraInfos
)
{
final Int2ObjectMap<WorkOrder> retVal = new Int2ObjectAVLTreeMap<>();
final Int2ObjectMap<WorkOrder> workerToWorkOrder = new Int2ObjectAVLTreeMap<>();
final ControllerStageTracker stageKernel = getStageKernelOrThrow(getStageId(stageNumber));
final WorkerInputs workerInputs = stageKernel.getWorkerInputs();
@ -246,10 +283,10 @@ public class ControllerQueryKernel
);
QueryValidator.validateWorkOrder(workOrder);
retVal.put(workerNumber, workOrder);
workerToWorkOrder.put(workerNumber, workOrder);
}
return retVal;
stageWorkOrders.put(new StageId(queryDef.getQueryId(), stageNumber), workerToWorkOrder);
return workerToWorkOrder;
}
private void createNewKernels(
@ -265,7 +302,8 @@ public class ControllerQueryKernel
stageDef,
stageWorkerCountMap,
slicer,
assignmentStrategy
assignmentStrategy,
maxRetainedPartitionSketchBytes
);
stageTracker.put(nextStage, stageKernel);
}
@ -324,6 +362,30 @@ public class ControllerQueryKernel
return getStageKernelOrThrow(stageId).getResultPartitions();
}
/**
* Delegates call to {@link ControllerStageTracker#getWorkersToSendPartitionBoundaries()}
*/
public IntSet getWorkersToSendPartitionBoundaries(final StageId stageId)
{
return getStageKernelOrThrow(stageId).getWorkersToSendPartitionBoundaries();
}
/**
* Delegates call to {@link ControllerQueryKernel#workOrdersSentForWorker(StageId, int)}
*/
public void workOrdersSentForWorker(final StageId stageId, int worker)
{
getStageKernelOrThrow(stageId).workOrderSentForWorker(worker);
}
/**
* Delegates call to {@link ControllerStageTracker#partitionBoundariesSentForWorker(int)} ()}
*/
public void partitionBoundariesSentForWorker(final StageId stageId, int worker)
{
getStageKernelOrThrow(stageId).partitionBoundariesSentForWorker(worker);
}
/**
* Delegates call to {@link ControllerStageTracker#getResultPartitionBoundaries()}
*/
@ -340,14 +402,6 @@ public class ControllerQueryKernel
return getStageKernelOrThrow(stageId).getCompleteKeyStatisticsInformation();
}
/**
* Delegates call to {@link ControllerStageTracker#setClusterByPartitionBoundaries(ClusterByPartitions)} ()}
*/
public void setClusterByPartitionBoundaries(final StageId stageId, ClusterByPartitions clusterByPartitions)
{
getStageKernelOrThrow(stageId).setClusterByPartitionBoundaries(clusterByPartitions);
}
/**
* Delegates call to {@link ControllerStageTracker#collectorEncounteredAnyMultiValueField()}
*/
@ -366,7 +420,7 @@ public class ControllerQueryKernel
/**
* Checks if the stage can be started, delegates call to {@link ControllerStageTracker#start()} for internal phase
* transition and registers the transition in this queryKernel
* transition and registers the transition in this queryKernel. Work orders need to be created via {@link ControllerQueryKernel#createWorkOrders(int, Int2ObjectMap)} before calling this method.
*/
public void startStage(final StageId stageId)
{
@ -374,6 +428,9 @@ public class ControllerQueryKernel
if (stageKernel.getPhase() != ControllerStagePhase.NEW) {
throw new ISE("Cannot start the stage: [%s]", stageId);
}
if (stageWorkOrders.get(stageId) == null) {
throw new ISE("Work orders not present for stage %s", stageId);
}
stageKernel.start();
transitionStageKernel(stageId, ControllerStagePhase.READING_INPUT);
}
@ -381,7 +438,7 @@ public class ControllerQueryKernel
/**
* Checks if the stage can be finished, delegates call to {@link ControllerStageTracker#finish()} for internal phase
* transition and registers the transition in this query kernel
*
* <p>
* If the method is called with strict = true, we confirm if the stage can be marked as finished or else
* throw illegal argument exception
*/
@ -393,6 +450,7 @@ public class ControllerQueryKernel
getStageKernelOrThrow(stageId).finish();
effectivelyFinishedStages.remove(stageId);
transitionStageKernel(stageId, ControllerStagePhase.FINISHED);
stageWorkOrders.remove(stageId);
}
/**
@ -404,7 +462,7 @@ public class ControllerQueryKernel
}
/**
* Delegates call to {@link ControllerStageTracker#addPartialKeyStatisticsForWorker(int, PartialKeyStatisticsInformation)}.
* Delegates call to {@link ControllerStageTracker#addPartialKeyInformationForWorker(int, PartialKeyStatisticsInformation)}.
* If calling this causes transition for the stage kernel, then this gets registered in this query kernel
*/
public void addPartialKeyStatisticsForStageAndWorker(
@ -414,7 +472,7 @@ public class ControllerQueryKernel
)
{
ControllerStageTracker stageKernel = getStageKernelOrThrow(stageId);
ControllerStagePhase newPhase = stageKernel.addPartialKeyStatisticsForWorker(
ControllerStagePhase newPhase = stageKernel.addPartialKeyInformationForWorker(
workerNumber,
partialKeyStatisticsInformation
);
@ -479,8 +537,23 @@ public class ControllerQueryKernel
return stageKernel;
}
private WorkOrder getWorkOrder(int workerNumber, StageId stageId)
{
Int2ObjectMap<WorkOrder> stageWorkOrder = stageWorkOrders.get(stageId);
if (stageWorkOrder == null) {
throw new ISE("Stage[%d] work orders not found", stageId.getStageNumber());
}
WorkOrder workOrder = stageWorkOrder.get(workerNumber);
if (workOrder == null) {
throw new ISE("Work order for worker[%d] not found for stage[%d]", workerNumber, stageId.getStageNumber());
}
return workOrder;
}
/**
* Whenever a stage kernel changes it phase, the change must be "registered" by calling this method with the stageId
* Whenever a stage kernel changes its phase, the change must be "registered" by calling this method with the stageId
* and the new phase
*/
public void transitionStageKernel(StageId stageId, ControllerStagePhase newPhase)
@ -507,6 +580,12 @@ public class ControllerQueryKernel
}
if (ControllerStagePhase.isPostReadingPhase(newPhase)) {
// when fault tolerance is enabled, we cannot delete the input data eagerly as we need the input stage for retry until
// results for the current stage are ready.
if (faultToleranceEnabled && newPhase == ControllerStagePhase.POST_READING) {
return;
}
// Once the stage has consumed all the data/input from its dependent stages, we remove it from all the stages
// whose input it was dependent on
for (StageId inputStage : inflowMap.get(stageId)) {
@ -568,4 +647,136 @@ public class ControllerQueryKernel
return retVal;
}
/**
* Checks the {@link MSQFault#getErrorCode()} is eligible for retry.
* <br/>
* If yes, transitions the stage to{@link ControllerStagePhase#RETRYING} and returns all the {@link WorkOrder}
* <br/>
* else throw {@link MSQException}
*
* @param workerNumber
* @param msqFault
* @return List of {@link WorkOrder} that needs to be retried.
*/
public List<WorkOrder> getWorkInCaseWorkerEligibleForRetryElseThrow(int workerNumber, MSQFault msqFault)
{
final String errorCode;
if (msqFault instanceof WorkerFailedFault) {
errorCode = MSQFaultUtils.getErrorCodeFromMessage((((WorkerFailedFault) msqFault).getErrorMsg()));
} else {
errorCode = msqFault.getErrorCode();
}
log.info("Parsed out errorCode[%s] to check eligibility for retry", errorCode);
if (RETRIABLE_ERROR_CODES.contains(errorCode)) {
return getWorkInCaseWorkerEligibleForRetry(workerNumber);
} else {
throw new MSQException(msqFault);
}
}
/**
* Gets all the stages currently being tracked and filters out all effectively finished stages.
* <br/>
* From the remaining stages, checks if (stage,worker) needs to be retried.
* <br/>
* If yes adds the workOrder for that stage to the return list and transitions the stage kernel to {@link ControllerStagePhase#RETRYING}
*
* @param worker
* @return List of {@link WorkOrder} that needs to be retried.
*/
private List<WorkOrder> getWorkInCaseWorkerEligibleForRetry(int worker)
{
List<StageId> trackedSet = new ArrayList<>(getActiveStages());
trackedSet.removeAll(getEffectivelyFinishedStageIds());
List<WorkOrder> workOrders = new ArrayList<>();
for (StageId stageId : trackedSet) {
ControllerStageTracker controllerStageTracker = getStageKernelOrThrow(stageId);
if (ControllerStagePhase.RETRYING.canTransitionFrom(controllerStageTracker.getPhase())
&& controllerStageTracker.retryIfNeeded(worker)) {
workOrders.add(getWorkOrder(worker, stageId));
// should be a no-op.
transitionStageKernel(stageId, ControllerStagePhase.RETRYING);
}
}
return workOrders;
}
/**
* For each stage, fetches the workers who are ready with their {@link ClusterByStatisticsSnapshot}
*/
public Map<StageId, Set<Integer>> getStagesAndWorkersToFetchClusterStats()
{
List<StageId> trackedSet = new ArrayList<>(getActiveStages());
trackedSet.removeAll(getEffectivelyFinishedStageIds());
Map<StageId, Set<Integer>> stageToWorkers = new HashMap<>();
for (StageId stageId : trackedSet) {
ControllerStageTracker controllerStageTracker = getStageKernelOrThrow(stageId);
if (controllerStageTracker.getStageDefinition().mustGatherResultKeyStatistics()) {
stageToWorkers.put(stageId, controllerStageTracker.getWorkersToFetchClusterStatisticsFrom());
}
}
return stageToWorkers;
}
/**
* Delegates call to {@link ControllerStageTracker#startFetchingStatsFromWorker(int)} for each worker
*/
public void startFetchingStatsFromWorker(StageId stageId, Set<Integer> workers)
{
ControllerStageTracker controllerStageTracker = getStageKernelOrThrow(stageId);
for (int worker : workers) {
controllerStageTracker.startFetchingStatsFromWorker(worker);
}
}
/**
* Delegates call to {@link ControllerStageTracker#mergeClusterByStatisticsCollectorForAllTimeChunks(int, ClusterByStatisticsSnapshot)}
*/
public void mergeClusterByStatisticsCollectorForAllTimeChunks(
StageId stageId,
int workerNumber,
ClusterByStatisticsSnapshot clusterByStatsSnapshot
)
{
getStageKernelOrThrow(stageId).mergeClusterByStatisticsCollectorForAllTimeChunks(
workerNumber,
clusterByStatsSnapshot
);
}
/**
* Delegates call to {@link ControllerStageTracker#mergeClusterByStatisticsCollectorForTimeChunk(int, Long, ClusterByStatisticsSnapshot)}
*/
public void mergeClusterByStatisticsCollectorForTimeChunk(
StageId stageId,
int workerNumber,
Long timeChunk,
ClusterByStatisticsSnapshot clusterByStatsSnapshot
)
{
getStageKernelOrThrow(stageId).mergeClusterByStatisticsCollectorForTimeChunk(
workerNumber,
timeChunk,
clusterByStatsSnapshot
);
}
/**
* Delegates call to {@link ControllerStageTracker#allPartialKeyInformationFetched()}
*/
public boolean allPartialKeyInformationPresent(StageId stageId)
{
return getStageKernelOrThrow(stageId).allPartialKeyInformationFetched();
}
}

View File

@ -25,7 +25,7 @@ import java.util.Set;
/**
* Phases that a stage can be in, as far as the controller is concerned.
*
* <p>
* Used by {@link ControllerStageTracker}.
*/
public enum ControllerStagePhase
@ -44,13 +44,16 @@ public enum ControllerStagePhase
@Override
public boolean canTransitionFrom(final ControllerStagePhase priorPhase)
{
return priorPhase == NEW;
return priorPhase == RETRYING || priorPhase == NEW;
}
},
// Waiting to fetch key statistics in the background from the workers and incrementally generate partitions.
// This phase is only transitioned to once all partialKeyInformation are recieved from workers.
// Transitioning to this phase should also enqueue the task to fetch key statistics to WorkerSketchFetcher.
// This phase is only transitioned to once all partialKeyInformation are received from workers.
// Transitioning to this phase should also enqueue the task to fetch key statistics if `SEQUENTIAL` strategy is used.
// In `PARALLEL` strategy, we start fetching the key statistics as soon as they are available on the worker.
// This stage is not required in non-pre shuffle contexts
MERGING_STATISTICS {
@Override
public boolean canTransitionFrom(final ControllerStagePhase priorPhase)
@ -59,7 +62,7 @@ public enum ControllerStagePhase
}
},
// Post the inputs have been read and mapped to frames, in the `POST_READING` stage, we pre-shuffle and determing the partition boundaries.
// Post the inputs have been read and mapped to frames, in the `POST_READING` stage, we pre-shuffle and determining the partition boundaries.
// This step for a stage spits out the statistics of the data as a whole (and not just the individual records). This
// phase is not required in non-pre shuffle contexts.
POST_READING {
@ -96,6 +99,20 @@ public enum ControllerStagePhase
{
return true;
}
},
// Stages whose workers are currently under relaunch. We can transition out of Retrying state only when all the work orders
// of this stage have been sent.
// We can transition into Retrying phase when the prior phase did not publish its final results yet.
RETRYING {
@Override
public boolean canTransitionFrom(final ControllerStagePhase priorPhase)
{
return priorPhase == READING_INPUT
|| priorPhase == POST_READING
|| priorPhase == MERGING_STATISTICS
|| priorPhase == RETRYING;
}
};
public abstract boolean canTransitionFrom(ControllerStagePhase priorPhase);

View File

@ -22,12 +22,18 @@ package org.apache.druid.msq.kernel.controller;
import it.unimi.dsi.fastutil.ints.Int2IntAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntSortedMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.java.util.common.Either;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.exec.ClusterStatisticsMergeMode;
import org.apache.druid.msq.indexing.error.InsertTimeNullFault;
import org.apache.druid.msq.indexing.error.MSQFault;
import org.apache.druid.msq.indexing.error.TooManyPartitionsFault;
@ -39,29 +45,51 @@ import org.apache.druid.msq.input.stage.ReadablePartitions;
import org.apache.druid.msq.input.stage.StageInputSlice;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
import javax.annotation.Nullable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.IntStream;
/**
* Controller-side state machine for each stage. Used by {@link ControllerQueryKernel} to form the overall state
* machine for an entire query.
*
* <p>
* Package-private: stage trackers are an internal implementation detail of {@link ControllerQueryKernel}, not meant
* for separate use.
*/
class ControllerStageTracker
{
private static final Logger log = new Logger(ControllerStageTracker.class);
private static final long STATIC_TIME_CHUNK_FOR_PARALLEL_MERGE = Granularities.ALL.bucketStart(-1);
private final StageDefinition stageDef;
private final int workerCount;
private final WorkerInputs workerInputs;
private final IntSet workersWithReportedKeyStatistics = new IntAVLTreeSet();
private final IntSet workersWithResultsComplete = new IntAVLTreeSet();
// worker-> workerStagePhase
// Controller keeps track of the stage with this map.
// Currently, we rely on the serial nature of the state machine to keep things in sync between the controller and the worker.
// So the worker state in the controller can go out of sync with the actual worker state.
private final Int2ObjectMap<ControllerWorkerStagePhase> workerToPhase = new Int2ObjectOpenHashMap<>();
// workers which have reported partial key information.
private final IntSet workerReportedPartialKeyInformation = new IntAVLTreeSet();
// workers from which key collector is fetched.
private final IntSet workersFromWhichKeyCollectorFetched = new IntAVLTreeSet();
private final int maxRetainedPartitionSketchBytes;
private ControllerStagePhase phase = ControllerStagePhase.NEW;
@Nullable
@ -75,6 +103,17 @@ class ControllerStageTracker
@Nullable
private ClusterByPartitions resultPartitionBoundaries;
// created when mergingStatsForTimeChunk is called. Should be cleared once timeChunkToBoundaries is set for the timechunk
private final Map<Long, ClusterByStatisticsCollector> timeChunkToCollector = new HashMap<>();
private final Map<Long, ClusterByPartitions> timeChunkToBoundaries = new TreeMap<>();
long totalPartitionCount;
// states used for tracking worker to timechunks and vice versa so that we know when to generate partition boundaries for (timeChunk,worker)
private Map<Integer, Set<Long>> workerToRemainingTimeChunks = null;
private Map<Long, Set<Integer>> timeChunkToRemainingWorkers = null;
@Nullable
private Object resultObject;
@ -83,12 +122,16 @@ class ControllerStageTracker
private ControllerStageTracker(
final StageDefinition stageDef,
final WorkerInputs workerInputs
final WorkerInputs workerInputs,
final int maxRetainedPartitionSketchBytes
)
{
this.stageDef = stageDef;
this.workerCount = workerInputs.workerCount();
this.workerInputs = workerInputs;
this.maxRetainedPartitionSketchBytes = maxRetainedPartitionSketchBytes;
initializeWorkerState(workerCount);
if (stageDef.mustGatherResultKeyStatistics()) {
this.completeKeyStatisticsInformation =
@ -99,6 +142,17 @@ class ControllerStageTracker
}
}
/**
* Initialize stage for each worker to {@link ControllerWorkerStagePhase#NEW}
*
* @param workerCount
*/
private void initializeWorkerState(int workerCount)
{
IntStream.range(0, workerCount)
.forEach(wokerNumber -> workerToPhase.put(wokerNumber, ControllerWorkerStagePhase.NEW));
}
/**
* Given a stage definition and number of workers to available per stage, this method creates a stage tracker.
* This method determines the actual number of workers to use (which in turn depends on the input slices and
@ -108,11 +162,16 @@ class ControllerStageTracker
final StageDefinition stageDef,
final Int2IntMap stageWorkerCountMap,
final InputSpecSlicer slicer,
final WorkerAssignmentStrategy assignmentStrategy
final WorkerAssignmentStrategy assignmentStrategy,
final int maxRetainedPartitionSketchBytes
)
{
final WorkerInputs workerInputs = WorkerInputs.create(stageDef, stageWorkerCountMap, slicer, assignmentStrategy);
return new ControllerStageTracker(stageDef, workerInputs);
return new ControllerStageTracker(
stageDef,
workerInputs,
maxRetainedPartitionSketchBytes
);
}
/**
@ -165,10 +224,88 @@ class ControllerStageTracker
}
}
/**
* Get workers which need to be sent partition boundaries
*
* @return
*/
IntSet getWorkersToSendPartitionBoundaries()
{
if (!getStageDefinition().doesShuffle()) {
throw new ISE("Result partition information is not relevant to this stage because it does not shuffle");
}
IntAVLTreeSet workers = new IntAVLTreeSet();
for (int worker : workerToPhase.keySet()) {
if (ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES.equals(workerToPhase.get(worker))) {
workers.add(worker);
}
}
return workers;
}
/**
* Indicates that the work order for worker has been sent. Transitions the state to {@link ControllerWorkerStagePhase#READING_INPUT}
* if no more work orders need to be sent.
*
* @param worker
*/
void workOrderSentForWorker(int worker)
{
workerToPhase.compute(worker, (wk, state) -> {
if (state == null) {
throw new ISE("Worker[%d] not found for stage[%s]", wk, stageDef.getStageNumber());
}
if (!ControllerWorkerStagePhase.READING_INPUT.canTransitionFrom(state)) {
throw new ISE(
"Worker[%d] cannot transistion from state[%s] to state[%s] while sending work order",
worker,
state,
ControllerWorkerStagePhase.READING_INPUT
);
}
return ControllerWorkerStagePhase.READING_INPUT;
});
if (phase != ControllerStagePhase.READING_INPUT) {
if (allWorkOrdersSent()) {
// if no more work orders need to be sent, change state to reading input from retrying.
transitionTo(ControllerStagePhase.READING_INPUT);
}
}
}
/**
* Indicates that the partition boundaries for worker has been sent.
*
* @param worker
*/
void partitionBoundariesSentForWorker(int worker)
{
workerToPhase.compute(worker, (wk, state) -> {
if (state == null) {
throw new ISE("Worker[%d] not found for stage[%s]", wk, stageDef.getStageNumber());
}
if (!ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT.canTransitionFrom(state)) {
throw new ISE(
"Worker[%d] cannot transistion from state[%s] to state[%s] while sending partition boundaries",
worker,
state,
ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT
);
}
return ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT;
});
}
/**
* Whether the result key statistics collector for this stage has encountered any multi-valued input at
* any key position.
*
* <p>
* This method exists because {@link org.apache.druid.timeline.partition.DimensionRangeShardSpec} does not
* support partitioning on multi-valued strings, so we need to know if any multi-valued strings exist in order
* to decide whether we can use this kind of shard spec.
@ -177,7 +314,7 @@ class ControllerStageTracker
{
if (completeKeyStatisticsInformation == null) {
throw new ISE("Stage does not gather result key statistics");
} else if (workersWithReportedKeyStatistics.size() != workerCount) {
} else if (workerReportedPartialKeyInformation.size() != workerCount) {
throw new ISE("Result key statistics are not ready");
} else {
return completeKeyStatisticsInformation.hasMultipleValues();
@ -234,21 +371,20 @@ class ControllerStageTracker
}
/**
* Adds result key statistics for a particular worker number. If statistics have already been added for this worker,
* then this call ignores the new ones and does nothing.
* Adds partial key statistics information for a particular worker number. If information is already added for this worker,
* then this call ignores the new information.
*
* @param workerNumber the worker
* @param workerNumber the worker
* @param partialKeyStatisticsInformation partial key statistics
*/
ControllerStagePhase addPartialKeyStatisticsForWorker(
ControllerStagePhase addPartialKeyInformationForWorker(
final int workerNumber,
final PartialKeyStatisticsInformation partialKeyStatisticsInformation
)
{
if (phase != ControllerStagePhase.READING_INPUT) {
throw new ISE("Cannot add result key statistics from stage [%s]", phase);
}
if (!stageDef.mustGatherResultKeyStatistics() || !stageDef.doesShuffle() || completeKeyStatisticsInformation == null) {
if (!stageDef.mustGatherResultKeyStatistics()
|| !stageDef.doesShuffle()
|| completeKeyStatisticsInformation == null) {
throw new ISE("Stage does not gather result key statistics");
}
@ -256,23 +392,72 @@ class ControllerStageTracker
throw new IAE("Invalid workerNumber [%s]", workerNumber);
}
try {
if (workersWithReportedKeyStatistics.add(workerNumber)) {
ControllerWorkerStagePhase currentPhase = workerToPhase.get(workerNumber);
if (partialKeyStatisticsInformation.getTimeSegments().contains(null)) {
// Time should not contain null value
failForReason(InsertTimeNullFault.instance());
return getPhase();
if (currentPhase == null) {
throw new ISE("Worker[%d] not found for stage[%s]", workerNumber, stageDef.getStageNumber());
}
try {
if (ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_ALL_KEY_STATS_TO_BE_FETCHED.canTransitionFrom(currentPhase)) {
workerToPhase.put(workerNumber, ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_ALL_KEY_STATS_TO_BE_FETCHED);
// if partial key stats already received for worker, do not update the sketch.
if (workerReportedPartialKeyInformation.add(workerNumber)) {
if (partialKeyStatisticsInformation.getTimeSegments().contains(null)) {
// Time should not contain null value
failForReason(InsertTimeNullFault.instance());
return getPhase();
}
completeKeyStatisticsInformation.mergePartialInformation(workerNumber, partialKeyStatisticsInformation);
}
completeKeyStatisticsInformation.mergePartialInformation(workerNumber, partialKeyStatisticsInformation);
if (resultPartitions != null) {
// we already have result partitions. No need to fetch the stats from worker
// can happen in case of worker retry
workerToPhase.put(
workerNumber,
ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES
);
}
if (workersWithReportedKeyStatistics.size() == workerCount) {
if (workersFromWhichKeyCollectorFetched.contains(workerNumber)) {
// we already have fetched the key collector from this worker. No need to fetch it again.
// can happen in case of worker retry
workerToPhase.put(
workerNumber,
ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES
);
}
if (allPartialKeyInformationFetched()) {
completeKeyStatisticsInformation.complete();
if (workerToRemainingTimeChunks == null && timeChunkToRemainingWorkers == null) {
initializeTimeChunkWorkerTrackers();
}
// All workers have sent the partial key statistics information.
// Transition to MERGING_STATISTICS state to queue fetch clustering statistics from workers.
transitionTo(ControllerStagePhase.MERGING_STATISTICS);
if (phase != ControllerStagePhase.FAILED) {
transitionTo(ControllerStagePhase.MERGING_STATISTICS);
}
// if all the results have been fetched, we can straight way transition to post reading.
if (allResultsStatsFetched()) {
if (phase != ControllerStagePhase.FAILED) {
transitionTo(ControllerStagePhase.POST_READING);
}
}
}
} else {
throw new ISE(
"Worker[%d] for stage[%d] expected to be in state[%s]. Found state[%s]",
workerNumber,
(stageDef.getStageNumber()),
ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_ALL_KEY_STATS_TO_BE_FETCHED,
currentPhase
);
}
}
catch (Exception e) {
@ -283,6 +468,262 @@ class ControllerStageTracker
return getPhase();
}
private void initializeTimeChunkWorkerTrackers()
{
workerToRemainingTimeChunks = new HashMap<>();
timeChunkToRemainingWorkers = new HashMap<>();
completeKeyStatisticsInformation.getTimeSegmentVsWorkerMap().forEach((timeChunk, workers) -> {
for (int worker : workers) {
this.workerToRemainingTimeChunks.compute(worker, (wk, timeChunks) -> {
if (timeChunks == null) {
timeChunks = new HashSet<>();
}
timeChunks.add(timeChunk);
return timeChunks;
});
}
timeChunkToRemainingWorkers.put(timeChunk, workers);
});
}
/**
* Merges the {@link ClusterByStatisticsSnapshot} for the worker, time chunk with the stage {@link ClusterByStatisticsCollector} being
* tracked at {@link #timeChunkToCollector} for the same time chunk. This method is called when
* {@link ClusterStatisticsMergeMode#SEQUENTIAL} is chosen eventually.
* <br></br>
* <br></br>
* If all the stats from the worker are merged, we transition the worker to {@link ControllerWorkerStagePhase#PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES};
* <br></br>
* If all the stats from all the workers are merged, we transition the stage to {@link ControllerStagePhase#POST_READING}
*/
void mergeClusterByStatisticsCollectorForTimeChunk(
int workerNumber,
Long timeChunk,
ClusterByStatisticsSnapshot clusterByStatisticsSnapshot
)
{
if (!stageDef.mustGatherResultKeyStatistics()
|| !stageDef.doesShuffle()) {
throw new ISE("Stage does not gather result key statistics");
}
if (workerNumber < 0 || workerNumber >= workerCount) {
throw new IAE("Invalid workerNumber [%s]", workerNumber);
}
if (completeKeyStatisticsInformation == null || !completeKeyStatisticsInformation.isComplete()) {
throw new ISE(
"Cannot merge worker[%d] time chunk until all the key information is received for stage[%d]",
workerNumber,
stageDef.getStageNumber()
);
}
ControllerWorkerStagePhase workerStagePhase = workerToPhase.get(workerNumber);
if (workerStagePhase == null) {
throw new ISE("Worker[%d] not found for stage[%s]", workerNumber, stageDef.getStageNumber());
}
// only merge in case this worker has remaining time chunks
workerToRemainingTimeChunks.computeIfPresent(workerNumber, (wk, timeChunks) -> {
if (timeChunks.remove(timeChunk)) {
// merge the key collector
timeChunkToCollector.compute(
timeChunk,
(ignored, collector) -> {
if (collector == null) {
collector = stageDef.createResultKeyStatisticsCollector(maxRetainedPartitionSketchBytes);
}
collector.addAll(clusterByStatisticsSnapshot);
return collector;
}
);
// if work for one time chunk is finished, generate the ClusterByPartitions for that timeChunk and clear the collector so that we free up controller memory.
timeChunkToRemainingWorkers.compute(timeChunk, (tc, workers) -> {
if (workers == null || workers.isEmpty()) {
throw new ISE(
"Remaining workers should not be empty until all the work is finished for time chunk[%d] for stage[%d]",
timeChunk,
stageDef.getStageNumber()
);
}
workers.remove(workerNumber);
if (workers.isEmpty()) {
// generate partition boundaries since all work is finished for the time chunk
ClusterByStatisticsCollector collector = timeChunkToCollector.get(tc);
Either<Long, ClusterByPartitions> countOrPartitions = stageDef.generatePartitionsForShuffle(collector);
totalPartitionCount += getPartitionCountFromEither(countOrPartitions);
if (totalPartitionCount > stageDef.getMaxPartitionCount()) {
failForReason(new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
return null;
}
timeChunkToBoundaries.put(tc, countOrPartitions.valueOrThrow());
// clear the collector to give back memory
collector.clear();
timeChunkToCollector.remove(tc);
return null;
}
return workers;
});
}
return timeChunks.isEmpty() ? null : timeChunks;
});
// if all time chunks for worker are taken care off transition worker.
if (workerToRemainingTimeChunks.get(workerNumber) == null) {
// adding worker to a set so that we do not fetch the worker collectors again.
workersFromWhichKeyCollectorFetched.add(workerNumber);
if (ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES.canTransitionFrom(
workerStagePhase)) {
workerToPhase.put(workerNumber, ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES);
} else {
throw new ISE(
"Worker[%d] for stage[%d] expected to be in state[%s]. Found state[%s]",
workerNumber,
(stageDef.getStageNumber()),
ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES,
workerStagePhase
);
}
}
// if all time chunks have the partition boundaries, merge them to set resultPartitionBoundaries
if (workerToRemainingTimeChunks.isEmpty()) {
if (resultPartitionBoundaries == null) {
timeChunkToBoundaries.forEach((ignored, partitions) -> {
if (resultPartitionBoundaries == null) {
resultPartitionBoundaries = partitions;
} else {
abutAndAppendPartitionBoundaries(resultPartitionBoundaries.ranges(), partitions.ranges());
}
});
timeChunkToBoundaries.clear();
setClusterByPartitionBoundaries(resultPartitionBoundaries);
} else {
// we already have result partitions. We can safely transition to POST READING and submit the result boundaries to the workers.
transitionTo(ControllerStagePhase.POST_READING);
}
}
}
/**
* Merges the entire {@link ClusterByStatisticsSnapshot} for the worker with the stage {@link ClusterByStatisticsCollector} being
* tracked at {@link #timeChunkToCollector} with key {@link ControllerStageTracker#STATIC_TIME_CHUNK_FOR_PARALLEL_MERGE}. This method is called when
* {@link ClusterStatisticsMergeMode#PARALLEL} is chosen eventually.
* <br></br>
* <br></br>
* If all the stats from the worker are merged, we transition the worker to {@link ControllerWorkerStagePhase#PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES}.
* <br></br>
* If all the stats from all the workers are merged, we transition the stage to {@link ControllerStagePhase#POST_READING}.
*/
void mergeClusterByStatisticsCollectorForAllTimeChunks(
int workerNumber,
ClusterByStatisticsSnapshot clusterByStatsSnapshot
)
{
if (!stageDef.mustGatherResultKeyStatistics()
|| !stageDef.doesShuffle()) {
throw new ISE("Stage does not gather result key statistics");
}
if (workerNumber < 0 || workerNumber >= workerCount) {
throw new IAE("Invalid workerNumber [%s]", workerNumber);
}
ControllerWorkerStagePhase workerStagePhase = workerToPhase.get(workerNumber);
if (workerStagePhase == null) {
throw new ISE("Worker[%d] not found for stage[%s]", workerNumber, stageDef.getStageNumber());
}
// To prevent the case where we do not fetch the collector twice, like when worker is retried, we should be okay with the
// older collector from the previous run of the worker.
if (workersFromWhichKeyCollectorFetched.add(workerNumber)) {
// in case of parallel merge we use the "ALL" granularity start time to put the sketches
timeChunkToCollector.compute(
STATIC_TIME_CHUNK_FOR_PARALLEL_MERGE,
(timeChunk, stats) -> {
if (stats == null) {
stats = stageDef.createResultKeyStatisticsCollector(maxRetainedPartitionSketchBytes);
}
stats.addAll(clusterByStatsSnapshot);
return stats;
}
);
} else {
log.debug("Already have key collector for worker[%d] stage[%d]", workerNumber, stageDef.getStageNumber());
}
if (ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES.canTransitionFrom(workerStagePhase)) {
workerToPhase.put(workerNumber, ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES);
} else {
throw new ISE(
"Worker[%d] for stage[%d] expected to be in state[%s]. Found state[%s]",
workerNumber,
(stageDef.getStageNumber()),
ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES,
workerStagePhase
);
}
if (allResultsStatsFetched()) {
if (completeKeyStatisticsInformation == null || !completeKeyStatisticsInformation.isComplete()) {
throw new ISE(
"Cannot generate partition boundaries until all the key information is received for worker[%d] stage[%d]",
workerNumber,
stageDef.getStageNumber()
);
}
if (resultPartitions == null) {
Either<Long, ClusterByPartitions> countOrPartitions = stageDef.generatePartitionsForShuffle(timeChunkToCollector.get(
STATIC_TIME_CHUNK_FOR_PARALLEL_MERGE));
totalPartitionCount += getPartitionCountFromEither(countOrPartitions);
if (totalPartitionCount > stageDef.getMaxPartitionCount()) {
failForReason(new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
return;
}
resultPartitionBoundaries = countOrPartitions.valueOrThrow();
setClusterByPartitionBoundaries(resultPartitionBoundaries);
} else {
log.debug("Already have result partitions for stage[%d]", stageDef.getStageNumber());
}
timeChunkToCollector.computeIfPresent(
STATIC_TIME_CHUNK_FOR_PARALLEL_MERGE,
(key, collector) -> collector.clear()
);
timeChunkToCollector.clear();
}
}
/**
* Returns true if all {@link ClusterByStatisticsSnapshot} are fetched from each worker else false.
*/
private boolean allResultsStatsFetched()
{
return workerToPhase.values().stream()
.filter(stagePhase -> stagePhase.equals(ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES)
|| stagePhase.equals(ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT)
|| stagePhase.equals(ControllerWorkerStagePhase.RESULTS_READY))
.count()
== workerCount;
}
/**
* Sets the {@link #resultPartitions} and {@link #resultPartitionBoundaries} and transitions the phase to POST_READING.
*/
@ -297,7 +738,7 @@ class ControllerStageTracker
}
if (!ControllerStagePhase.MERGING_STATISTICS.equals(getPhase())) {
throw new ISE("Cannot set partition boundires from key statistics from stage [%s]", getPhase());
throw new ISE("Cannot set partition boundaries from key statistics from stage [%s]", getPhase());
}
this.resultPartitionBoundaries = clusterByPartitions;
@ -326,12 +767,23 @@ class ControllerStageTracker
throw new NullPointerException("resultObject must not be null");
}
// This is unidirectional flow of data. While this works in the current state of MSQ where partial fault tolerance
// is implemented and a query flows in one direction only, rolling back of workers' state and query kernel's
// phase should be allowed to fully support fault tolerance in cases such as:
// 1. Rolling back worker's state in case it fails (and then retries)
// 2. Rolling back query kernel's phase in case the results are lost (and needs workers to retry the computation)
if (workersWithResultsComplete.add(workerNumber)) {
ControllerWorkerStagePhase currentPhase = workerToPhase.get(workerNumber);
if (currentPhase == null) {
throw new ISE("Worker[%d] not found for stage[%s]", workerNumber, stageDef.getStageNumber());
}
if (ControllerWorkerStagePhase.RESULTS_READY.canTransitionFrom(currentPhase)) {
if (stageDef.mustGatherResultKeyStatistics() && currentPhase == ControllerWorkerStagePhase.READING_INPUT) {
throw new ISE(
"Worker[%d] for stage[%d] expected to be in state[%s]. Found state[%s]",
workerNumber,
(stageDef.getStageNumber()),
ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT,
currentPhase
);
}
workerToPhase.put(workerNumber, ControllerWorkerStagePhase.RESULTS_READY);
if (this.resultObject == null) {
this.resultObject = resultObject;
} else {
@ -339,15 +791,34 @@ class ControllerStageTracker
this.resultObject = getStageDefinition().getProcessorFactory()
.mergeAccumulatedResult(this.resultObject, resultObject);
}
} else {
throw new ISE(
"Worker[%d] for stage[%d] expected to be in state[%s]. Found state[%s]",
workerNumber,
(stageDef.getStageNumber()),
stageDef.mustGatherResultKeyStatistics()
? ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT
: ControllerWorkerStagePhase.READING_INPUT,
currentPhase
);
}
if (workersWithResultsComplete.size() == workerCount) {
if (allResultsPresent()) {
transitionTo(ControllerStagePhase.RESULTS_READY);
return true;
}
return false;
}
private boolean allResultsPresent()
{
return workerToPhase.values()
.stream()
.filter(stagePhase -> stagePhase.equals(ControllerWorkerStagePhase.RESULTS_READY))
.count() == workerCount;
}
/**
* Reason for failure of this stage.
*/
@ -370,20 +841,23 @@ class ControllerStageTracker
/**
* Sets {@link #resultPartitions} (always) and {@link #resultPartitionBoundaries} without using key statistics.
*
* <p>
* If {@link StageDefinition#mustGatherResultKeyStatistics()} is true, this method should not be called.
*/
private void generateResultPartitionsAndBoundariesWithoutKeyStatistics()
{
if (resultPartitions != null) {
throw new ISE("Result partitions have already been generated");
// In case of retrying workers, we are perfectly fine using the partition boundaries generated before the retry
// took place. Hence, ignoring the request to generate result partitions.
log.debug("Partition boundaries already generated for stage %d", stageDef.getStageNumber());
return;
}
final int stageNumber = stageDef.getStageNumber();
if (stageDef.doesShuffle()) {
if (stageDef.mustGatherResultKeyStatistics()) {
throw new ISE("Cannot generate result partitions without key statistics");
if (stageDef.mustGatherResultKeyStatistics() && !allPartialKeyInformationFetched()) {
throw new ISE("Cannot generate result partitions without all worker key statistics");
}
final Either<Long, ClusterByPartitions> maybeResultPartitionBoundaries =
@ -421,6 +895,45 @@ class ControllerStageTracker
}
}
/**
* True if all {@link PartialKeyStatisticsInformation} are present for a shuffling stage which require statistics, else false.
* If the stage does not gather result statistics, we return a true.
*/
public boolean allPartialKeyInformationFetched()
{
if (!stageDef.mustGatherResultKeyStatistics()) {
return true;
}
return workerToPhase.values()
.stream()
.filter(stagePhase -> stagePhase.equals(ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_ALL_KEY_STATS_TO_BE_FETCHED)
|| stagePhase.equals(ControllerWorkerStagePhase.PRESHUFFLE_FETCHING_ALL_KEY_STATS)
|| stagePhase.equals(ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES)
|| stagePhase.equals(ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT)
|| stagePhase.equals(ControllerWorkerStagePhase.RESULTS_READY))
.count()
== workerCount;
}
/**
* True if all {@link org.apache.druid.msq.kernel.WorkOrder} are sent else false.
*/
private boolean allWorkOrdersSent()
{
return workerToPhase.values()
.stream()
.filter(stagePhase ->
stagePhase.equals(ControllerWorkerStagePhase.READING_INPUT)
|| stagePhase.equals(ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_ALL_KEY_STATS_TO_BE_FETCHED)
|| stagePhase.equals(ControllerWorkerStagePhase.PRESHUFFLE_FETCHING_ALL_KEY_STATS)
|| stagePhase.equals(ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES)
|| stagePhase.equals(ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT)
|| stagePhase.equals(ControllerWorkerStagePhase.RESULTS_READY)
)
.count()
== workerCount;
}
/**
* Marks the stage as failed and sets the reason for the same.
*
@ -433,7 +946,7 @@ class ControllerStageTracker
this.failureReason = fault;
}
void transitionTo(final ControllerStagePhase newPhase)
private void transitionTo(final ControllerStagePhase newPhase)
{
if (newPhase.canTransitionFrom(phase)) {
phase = newPhase;
@ -441,4 +954,107 @@ class ControllerStageTracker
throw new IAE("Cannot transition from [%s] to [%s]", phase, newPhase);
}
}
/**
* Retry true if the worker needs to be retried based on state else returns false.
*
* @param workerNumber
*/
public boolean retryIfNeeded(int workerNumber)
{
if (phase.equals(ControllerStagePhase.FINISHED) || phase.equals(ControllerStagePhase.RESULTS_READY)) {
// do nothing
return false;
}
if (!isTrackingWorker(workerNumber)) {
// not tracking this worker
return false;
}
if (workerToPhase.get(workerNumber).equals(ControllerWorkerStagePhase.RESULTS_READY)
|| workerToPhase.get(workerNumber).equals(ControllerWorkerStagePhase.FINISHED)) {
// do nothing
return false;
}
workerToPhase.put(workerNumber, ControllerWorkerStagePhase.NEW);
transitionTo(ControllerStagePhase.RETRYING);
return true;
}
private boolean isTrackingWorker(int workerNumber)
{
return workerToPhase.get(workerNumber) != null;
}
/**
* Returns the workers who are ready with {@link ClusterByStatisticsSnapshot}
*/
public Set<Integer> getWorkersToFetchClusterStatisticsFrom()
{
Set<Integer> workersToFetchStats = new HashSet<>();
workerToPhase.forEach((worker, phase) -> {
if (phase.equals(ControllerWorkerStagePhase.PRESHUFFLE_WAITING_FOR_ALL_KEY_STATS_TO_BE_FETCHED)) {
workersToFetchStats.add(worker);
}
});
return workersToFetchStats;
}
/**
* Transitions the worker to {@link ControllerWorkerStagePhase#PRESHUFFLE_FETCHING_ALL_KEY_STATS) indicating fetching has begun.
*/
public void startFetchingStatsFromWorker(int worker)
{
ControllerWorkerStagePhase workerStagePhase = workerToPhase.get(worker);
if (ControllerWorkerStagePhase.PRESHUFFLE_FETCHING_ALL_KEY_STATS.canTransitionFrom(workerStagePhase)) {
workerToPhase.put(worker, ControllerWorkerStagePhase.PRESHUFFLE_FETCHING_ALL_KEY_STATS);
} else {
throw new ISE(
"Worker[%d] for stage[%d] expected to be in state[%s]. Found state[%s]",
worker,
(stageDef.getStageNumber()),
ControllerWorkerStagePhase.PRESHUFFLE_FETCHING_ALL_KEY_STATS,
workerStagePhase
);
}
}
/**
* Takes a list of sorted {@link ClusterByPartitions} {@param timeSketchPartitions} and adds it to a sorted list
* {@param finalPartitionBoundaries}. If {@param finalPartitionBoundaries} is not empty, the end time of the last
* partition of {@param finalPartitionBoundaries} is changed to abut with the starting time of the first partition
* of {@param timeSketchPartitions}.
* <p>
* This is used to make the partitions generated continuous.
*/
private void abutAndAppendPartitionBoundaries(
List<ClusterByPartition> finalPartitionBoundaries,
List<ClusterByPartition> timeSketchPartitions
)
{
if (!finalPartitionBoundaries.isEmpty()) {
// Stitch up the end time of the last partition with the start time of the first partition.
ClusterByPartition clusterByPartition = finalPartitionBoundaries.remove(finalPartitionBoundaries.size() - 1);
finalPartitionBoundaries.add(new ClusterByPartition(
clusterByPartition.getStart(),
timeSketchPartitions.get(0).getStart()
));
}
finalPartitionBoundaries.addAll(timeSketchPartitions);
}
/**
* Gets the partition size from an {@link Either}. If it is an error, the long denotes the number of partitions
* (in the case of creating too many partitions), otherwise checks the size of the list.
*/
private static long getPartitionCountFromEither(Either<Long, ClusterByPartitions> either)
{
if (either.isError()) {
return either.error();
} else {
return either.valueOrThrow().size();
}
}
}

View File

@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.kernel.controller;
/**
* Worker phases that a stage can be in being tracked by the controller.
* <p>
* Used by {@link ControllerStageTracker}.
*/
public enum ControllerWorkerStagePhase
{
NEW {
@Override
public boolean canTransitionFrom(final ControllerWorkerStagePhase priorPhase)
{
return false;
}
},
READING_INPUT {
@Override
public boolean canTransitionFrom(final ControllerWorkerStagePhase priorPhase)
{
return priorPhase == NEW;
}
},
PRESHUFFLE_WAITING_FOR_ALL_KEY_STATS_TO_BE_FETCHED {
@Override
public boolean canTransitionFrom(final ControllerWorkerStagePhase priorPhase)
{
return priorPhase == READING_INPUT;
}
},
PRESHUFFLE_FETCHING_ALL_KEY_STATS {
@Override
public boolean canTransitionFrom(final ControllerWorkerStagePhase priorPhase)
{
return priorPhase == PRESHUFFLE_WAITING_FOR_ALL_KEY_STATS_TO_BE_FETCHED;
}
},
PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES {
@Override
public boolean canTransitionFrom(final ControllerWorkerStagePhase priorPhase)
{
return priorPhase == PRESHUFFLE_FETCHING_ALL_KEY_STATS;
}
},
PRESHUFFLE_WRITING_OUTPUT {
@Override
public boolean canTransitionFrom(final ControllerWorkerStagePhase priorPhase)
{
return priorPhase == PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES;
}
},
RESULTS_READY {
@Override
public boolean canTransitionFrom(final ControllerWorkerStagePhase priorPhase)
{
return priorPhase == READING_INPUT || priorPhase == PRESHUFFLE_WRITING_OUTPUT;
}
},
FINISHED {
@Override
public boolean canTransitionFrom(final ControllerWorkerStagePhase priorPhase)
{
return priorPhase == RESULTS_READY;
}
},
// Something went wrong.
FAILED {
@Override
public boolean canTransitionFrom(final ControllerWorkerStagePhase priorPhase)
{
return true;
}
};
public abstract boolean canTransitionFrom(ControllerWorkerStagePhase priorPhase);
}

View File

@ -19,7 +19,7 @@
package org.apache.druid.msq.statistics;
import com.google.common.collect.ImmutableSortedMap;
import org.apache.druid.java.util.common.ISE;
import java.util.HashSet;
import java.util.Set;
@ -36,6 +36,8 @@ public class CompleteKeyStatisticsInformation
private double bytesRetained;
private boolean complete;
public CompleteKeyStatisticsInformation(
final SortedMap<Long, Set<Integer>> timeChunks,
boolean multipleValues,
@ -53,9 +55,13 @@ public class CompleteKeyStatisticsInformation
* {@param partialKeyStatisticsInformation}, {@link #multipleValues} is set to true if
* {@param partialKeyStatisticsInformation} contains multipleValues and the bytes retained by the partial sketch
* is added to {@link #bytesRetained}.
* This method should not be called after {@link CompleteKeyStatisticsInformation#complete()} is called.
*/
public void mergePartialInformation(int workerNumber, PartialKeyStatisticsInformation partialKeyStatisticsInformation)
{
if (complete) {
throw new ISE("Key stats for all workers have been received. This method should not be called.");
}
for (Long timeSegment : partialKeyStatisticsInformation.getTimeSegments()) {
this.timeSegmentVsWorkerMap
.computeIfAbsent(timeSegment, key -> new HashSet<>())
@ -67,7 +73,10 @@ public class CompleteKeyStatisticsInformation
public SortedMap<Long, Set<Integer>> getTimeSegmentVsWorkerMap()
{
return ImmutableSortedMap.copyOfSorted(timeSegmentVsWorkerMap);
if (!complete) {
throw new ISE("Key stats for all the workers have not been received. This method cant be called yet.");
}
return timeSegmentVsWorkerMap;
}
public boolean hasMultipleValues()
@ -79,4 +88,17 @@ public class CompleteKeyStatisticsInformation
{
return bytesRetained;
}
/**
* Does not allow update via {@link CompleteKeyStatisticsInformation#mergePartialInformation(int, PartialKeyStatisticsInformation)} once this method is called.
*/
public void complete()
{
complete = true;
}
public boolean isComplete()
{
return complete;
}
}

View File

@ -58,10 +58,14 @@ public class MultiStageQueryContext
public static final String CTX_FINALIZE_AGGREGATIONS = "finalizeAggregations";
private static final boolean DEFAULT_FINALIZE_AGGREGATIONS = true;
public static final String CTX_ENABLE_DURABLE_SHUFFLE_STORAGE = "durableShuffleStorage";
public static final String CTX_DURABLE_SHUFFLE_STORAGE = "durableShuffleStorage";
private static final boolean DEFAULT_DURABLE_SHUFFLE_STORAGE = false;
public static final String CTX_FAULT_TOLERANCE = "faultTolerance";
public static final boolean DEFAULT_FAULT_TOLERANCE = false;
public static final String CTX_CLUSTER_STATISTICS_MERGE_MODE = "clusterStatisticsMergeMode";
public static final String DEFAULT_CLUSTER_STATISTICS_MERGE_MODE = ClusterStatisticsMergeMode.PARALLEL.toString();
private static final boolean DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE = false;
public static final String CTX_DESTINATION = "destination";
private static final String DEFAULT_DESTINATION = null;
@ -91,8 +95,16 @@ public class MultiStageQueryContext
public static boolean isDurableStorageEnabled(final QueryContext queryContext)
{
return queryContext.getBoolean(
CTX_ENABLE_DURABLE_SHUFFLE_STORAGE,
DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE
CTX_DURABLE_SHUFFLE_STORAGE,
DEFAULT_DURABLE_SHUFFLE_STORAGE
);
}
public static boolean isFaultToleranceEnabled(final QueryContext queryContext)
{
return queryContext.getBoolean(
CTX_FAULT_TOLERANCE,
DEFAULT_FAULT_TOLERANCE
);
}

View File

@ -19,21 +19,47 @@
package org.apache.druid.msq.exec;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.indexing.common.actions.SegmentTransactionalInsertAction;
import org.apache.druid.indexing.common.actions.TaskActionClient;
import org.apache.druid.indexing.overlord.SegmentPublishResult;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.indexing.error.InsertLockPreemptedFault;
import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
import org.easymock.EasyMock;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import java.io.IOException;
import java.util.Collections;
import static org.mockito.Mockito.doReturn;
public class ControllerImplTest
{
@Mock
private StageDefinition stageDefinition;
@Mock
private ClusterBy clusterBy;
private AutoCloseable mocks;
@Before
public void setUp()
{
mocks = MockitoAnnotations.openMocks(this);
doReturn(StageId.fromString("1_1")).when(stageDefinition).getId();
doReturn(clusterBy).when(stageDefinition).getClusterBy();
}
@Test
public void test_performSegmentPublish_ok() throws IOException
{
@ -101,4 +127,84 @@ public class ControllerImplTest
Assert.assertEquals(InsertLockPreemptedFault.instance(), e.getFault());
}
@Test
public void test_belowThresholds_ShouldBeParallel()
{
// Cluster by bucket count not 0
doReturn(1).when(clusterBy).getBucketByCount();
// Worker count below threshold
doReturn(1).when(stageDefinition).getMaxWorkerCount();
Assert.assertEquals(
ClusterStatisticsMergeMode.PARALLEL,
ControllerImpl.finalizeClusterStatisticsMergeMode(
stageDefinition,
ClusterStatisticsMergeMode.AUTO
)
);
}
@Test
public void test_noClusterByColumns_shouldBeParallel()
{
// Cluster by bucket count 0
doReturn(ClusterBy.none()).when(stageDefinition).getClusterBy();
// Worker count above threshold
doReturn((int) Limits.MAX_WORKERS_FOR_PARALLEL_MERGE + 1).when(stageDefinition).getMaxWorkerCount();
Assert.assertEquals(
ClusterStatisticsMergeMode.PARALLEL,
ControllerImpl.finalizeClusterStatisticsMergeMode(
stageDefinition,
ClusterStatisticsMergeMode.AUTO
)
);
}
@Test
public void test_numWorkersAboveThreshold_shouldBeSequential()
{
// Cluster by bucket count not 0
doReturn(1).when(clusterBy).getBucketByCount();
// Worker count above threshold
doReturn((int) Limits.MAX_WORKERS_FOR_PARALLEL_MERGE + 1).when(stageDefinition).getMaxWorkerCount();
Assert.assertEquals(
ClusterStatisticsMergeMode.SEQUENTIAL,
ControllerImpl.finalizeClusterStatisticsMergeMode(
stageDefinition,
ClusterStatisticsMergeMode.AUTO
)
);
}
@Test
public void test_mode_should_not_change()
{
Assert.assertEquals(
ClusterStatisticsMergeMode.SEQUENTIAL,
ControllerImpl.finalizeClusterStatisticsMergeMode(null, ClusterStatisticsMergeMode.SEQUENTIAL)
);
Assert.assertEquals(
ClusterStatisticsMergeMode.PARALLEL,
ControllerImpl.finalizeClusterStatisticsMergeMode(null, ClusterStatisticsMergeMode.PARALLEL)
);
}
@After
public void tearDown() throws Exception
{
mocks.close();
}
}

View File

@ -44,21 +44,46 @@ import org.apache.druid.timeline.SegmentId;
import org.hamcrest.CoreMatchers;
import org.junit.Test;
import org.junit.internal.matchers.ThrowableMessageMatcher;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.mockito.Mockito;
import javax.annotation.Nonnull;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
@RunWith(Parameterized.class)
public class MSQInsertTest extends MSQTestBase
{
private final HashFunction fn = Hashing.murmur3_128();
@Parameterized.Parameters(name = "{index}:with context {0}")
public static Collection<Object[]> data()
{
Object[][] data = new Object[][]{
{DEFAULT, DEFAULT_MSQ_CONTEXT},
{DURABLE_STORAGE, DURABLE_STORAGE_MSQ_CONTEXT},
{FAULT_TOLERANCE, FAULT_TOLERANCE_MSQ_CONTEXT},
{SEQUENTIAL_MERGE, SEQUENTIAL_MERGE_MSQ_CONTEXT}
};
return Arrays.asList(data);
}
@Parameterized.Parameter(0)
public String contextName;
@Parameterized.Parameter(1)
public Map<String, Object> context;
@Test
public void testInsertOnFoo1()
{
@ -70,6 +95,7 @@ public class MSQInsertTest extends MSQTestBase
testIngestQuery().setSql(
"insert into foo1 select __time, dim1 , count(*) as cnt from foo where dim1 is not null group by 1, 2 PARTITIONED by day clustered by dim1")
.setExpectedDataSource("foo1")
.setQueryContext(context)
.setExpectedRowSignature(rowSignature)
.setExpectedSegment(expectedFooSegments())
.setExpectedResultRows(expectedFooRows())
@ -100,6 +126,7 @@ public class MSQInsertTest extends MSQTestBase
+ ") group by 1 PARTITIONED by day ")
.setExpectedDataSource("foo1")
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedSegment(ImmutableSet.of(SegmentId.of(
"foo1",
Intervals.of("2016-06-27/P1D"),
@ -129,6 +156,7 @@ public class MSQInsertTest extends MSQTestBase
"insert into foo1 select floor(__time to day) as __time , dim1 , count(*) as cnt from foo where dim1 is not null group by 1, 2 PARTITIONED by day clustered by dim1")
.setExpectedDataSource("foo1")
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedSegment(expectedFooSegments())
.setExpectedResultRows(expectedFooRows())
.verifyResults();
@ -155,6 +183,7 @@ public class MSQInsertTest extends MSQTestBase
.setQueryContext(context)
.setExpectedDataSource("foo1")
.setExpectedRowSignature(rowSignature)
.setQueryContext(MSQInsertTest.this.context)
.setExpectedSegment(expectedFooSegments())
.setExpectedResultRows(expectedFooRows())
.verifyResults();
@ -172,6 +201,7 @@ public class MSQInsertTest extends MSQTestBase
"INSERT INTO foo1 SELECT dim3 FROM foo WHERE dim3 IS NOT NULL PARTITIONED BY ALL TIME")
.setExpectedDataSource("foo1")
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedSegment(ImmutableSet.of(SegmentId.of("foo1", Intervals.ETERNITY, "test", 0)))
.setExpectedResultRows(expectedMultiValueFooRows())
.verifyResults();
@ -188,6 +218,7 @@ public class MSQInsertTest extends MSQTestBase
"INSERT INTO foo1 SELECT dim3 FROM foo WHERE dim3 IS NOT NULL GROUP BY 1 PARTITIONED BY ALL TIME")
.setExpectedDataSource("foo1")
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedSegment(ImmutableSet.of(SegmentId.of("foo1", Intervals.ETERNITY, "test", 0)))
.setExpectedResultRows(expectedMultiValueFooRowsGroupBy())
.verifyResults();
@ -199,6 +230,7 @@ public class MSQInsertTest extends MSQTestBase
testIngestQuery().setSql(
"INSERT INTO foo1 SELECT count(dim3) FROM foo WHERE dim3 IS NOT NULL GROUP BY 1 PARTITIONED BY ALL TIME")
.setExpectedDataSource("foo1")
.setQueryContext(context)
.setExpectedValidationErrorMatcher(CoreMatchers.allOf(
CoreMatchers.instanceOf(SqlPlanningException.class),
ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString(
@ -219,6 +251,7 @@ public class MSQInsertTest extends MSQTestBase
"INSERT INTO foo1 SELECT MV_TO_ARRAY(dim3) AS dim3 FROM foo GROUP BY 1 PARTITIONED BY ALL TIME")
.setExpectedDataSource("foo1")
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedSegment(ImmutableSet.of(SegmentId.of("foo1", Intervals.ETERNITY, "test", 0)))
.setExpectedResultRows(expectedMultiValueFooRowsToArray())
.verifyResults();
@ -227,24 +260,29 @@ public class MSQInsertTest extends MSQTestBase
@Test
public void testInsertOnFoo1WithMultiValueDimGroupByWithoutGroupByEnable()
{
Map<String, Object> context = ImmutableMap.<String, Object>builder()
.putAll(DEFAULT_MSQ_CONTEXT)
.put("groupByEnableMultiValueUnnesting", false)
.build();
Map<String, Object> localContext = ImmutableMap.<String, Object>builder()
.putAll(context)
.put("groupByEnableMultiValueUnnesting", false)
.build();
testIngestQuery().setSql(
"INSERT INTO foo1 SELECT dim3, count(*) AS cnt1 FROM foo GROUP BY dim3 PARTITIONED BY ALL TIME")
.setQueryContext(context)
.setQueryContext(localContext)
.setExpectedExecutionErrorMatcher(CoreMatchers.allOf(
CoreMatchers.instanceOf(ISE.class),
ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString(
"Encountered multi-value dimension [dim3] that cannot be processed with 'groupByEnableMultiValueUnnesting' set to false."))
ThrowableMessageMatcher.hasMessage(!FAULT_TOLERANCE.equals(contextName)
? CoreMatchers.containsString(
"Encountered multi-value dimension [dim3] that cannot be processed with 'groupByEnableMultiValueUnnesting' set to false.")
:
CoreMatchers.containsString("exceeded max relaunch count")
)
))
.verifyExecutionError();
}
@Test
public void testRolltestRollUpOnFoo1UpOnFoo1()
public void testRollUpOnFoo1UpOnFoo1()
{
RowSignature rowSignature = RowSignature.builder()
.add("__time", ColumnType.LONG)
@ -254,7 +292,9 @@ public class MSQInsertTest extends MSQTestBase
testIngestQuery().setSql(
"insert into foo1 select __time, dim1 , count(*) as cnt from foo where dim1 is not null group by 1, 2 PARTITIONED by day clustered by dim1")
.setExpectedDataSource("foo1")
.setQueryContext(ROLLUP_CONTEXT)
.setQueryContext(new ImmutableMap.Builder<String, Object>().putAll(context)
.putAll(ROLLUP_CONTEXT_PARAMS)
.build())
.setExpectedRollUp(true)
.addExpectedAggregatorFactory(new LongSumAggregatorFactory("cnt", "cnt"))
.setExpectedRowSignature(rowSignature)
@ -272,11 +312,11 @@ public class MSQInsertTest extends MSQTestBase
.add("dim1", ColumnType.STRING)
.add("cnt", ColumnType.LONG).build();
testIngestQuery().setSql(
"insert into foo1 select floor(__time to day) as __time , dim1 , count(*) as cnt from foo where dim1 is not null group by 1, 2 PARTITIONED by day clustered by dim1")
.setExpectedDataSource("foo1")
.setQueryContext(ROLLUP_CONTEXT)
.setQueryContext(new ImmutableMap.Builder<String, Object>().putAll(context).putAll(
ROLLUP_CONTEXT_PARAMS).build())
.setExpectedRollUp(true)
.setExpectedQueryGranularity(Granularities.DAY)
.addExpectedAggregatorFactory(new LongSumAggregatorFactory("cnt", "cnt"))
@ -300,7 +340,8 @@ public class MSQInsertTest extends MSQTestBase
testIngestQuery().setSql(
"insert into foo1 select floor(__time to day) as __time , dim1 , count(distinct m1) as cnt from foo where dim1 is not null group by 1, 2 PARTITIONED by day clustered by dim1")
.setExpectedDataSource("foo1")
.setQueryContext(ROLLUP_CONTEXT)
.setQueryContext(new ImmutableMap.Builder<String, Object>().putAll(context).putAll(
ROLLUP_CONTEXT_PARAMS).build())
.setExpectedRollUp(true)
.setExpectedQueryGranularity(Granularities.DAY)
.addExpectedAggregatorFactory(new HyperUniquesAggregatorFactory("cnt", "cnt", false, true))
@ -324,7 +365,8 @@ public class MSQInsertTest extends MSQTestBase
testIngestQuery().setSql(
"insert into foo1 select __time , dim1 , count(distinct m1) as cnt from foo where dim1 is not null group by 1, 2 PARTITIONED by day clustered by dim1")
.setExpectedDataSource("foo1")
.setQueryContext(ROLLUP_CONTEXT)
.setQueryContext(new ImmutableMap.Builder<String, Object>().putAll(context).putAll(
ROLLUP_CONTEXT_PARAMS).build())
.setExpectedRollUp(true)
.addExpectedAggregatorFactory(new HyperUniquesAggregatorFactory("cnt", "cnt", false, true))
.setExpectedRowSignature(rowSignature)
@ -355,7 +397,8 @@ public class MSQInsertTest extends MSQTestBase
+ " '[{\"name\": \"timestamp\", \"type\": \"string\"}, {\"name\": \"page\", \"type\": \"string\"}, {\"name\": \"user\", \"type\": \"string\"}]'\n"
+ " )\n"
+ ") group by 1 PARTITIONED by day ")
.setQueryContext(ROLLUP_CONTEXT)
.setQueryContext(new ImmutableMap.Builder<String, Object>().putAll(context).putAll(
ROLLUP_CONTEXT_PARAMS).build())
.setExpectedRollUp(true)
.setExpectedDataSource("foo1")
.setExpectedRowSignature(rowSignature)
@ -392,7 +435,8 @@ public class MSQInsertTest extends MSQTestBase
+ " '[{\"name\": \"timestamp\", \"type\": \"string\"}, {\"name\": \"namespace\", \"type\": \"string\"}, {\"name\": \"user\", \"type\": \"string\"}]'\n"
+ " )\n"
+ ") group by 1,2 PARTITIONED by day ")
.setQueryContext(ROLLUP_CONTEXT)
.setQueryContext(new ImmutableMap.Builder<String, Object>().putAll(context).putAll(
ROLLUP_CONTEXT_PARAMS).build())
.setExpectedRollUp(true)
.setExpectedDataSource("foo1")
.setExpectedRowSignature(rowSignature)
@ -433,6 +477,7 @@ public class MSQInsertTest extends MSQTestBase
+ "CLUSTERED BY dim1")
.setExpectedDataSource("foo1")
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedValidationErrorMatcher(CoreMatchers.allOf(
CoreMatchers.instanceOf(SqlPlanningException.class),
ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString(
@ -478,6 +523,7 @@ public class MSQInsertTest extends MSQTestBase
+ ") PARTITIONED by day")
.setExpectedDataSource("foo1")
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedMSQFault(new ColumnNameRestrictedFault("__bucket"))
.verifyResults();
}
@ -485,13 +531,13 @@ public class MSQInsertTest extends MSQTestBase
@Test
public void testInsertQueryWithInvalidSubtaskCount()
{
Map<String, Object> context = ImmutableMap.<String, Object>builder()
.putAll(DEFAULT_MSQ_CONTEXT)
.put(MultiStageQueryContext.CTX_MAX_NUM_TASKS, 1)
.build();
Map<String, Object> localContext = ImmutableMap.<String, Object>builder()
.putAll(context)
.put(MultiStageQueryContext.CTX_MAX_NUM_TASKS, 1)
.build();
testIngestQuery().setSql(
"insert into foo1 select __time, dim1 , count(*) as cnt from foo where dim1 is not null group by 1, 2 PARTITIONED by day clustered by dim1")
.setQueryContext(context)
.setQueryContext(localContext)
.setExpectedExecutionErrorMatcher(
ThrowableMessageMatcher.hasMessage(
CoreMatchers.startsWith(
@ -522,6 +568,7 @@ public class MSQInsertTest extends MSQTestBase
+ " )\n"
+ ") group by 1 PARTITIONED by day ")
.setExpectedDataSource("foo")
.setQueryContext(context)
.setExpectedMSQFault(new RowTooLargeFault(500))
.setExpectedExecutionErrorMatcher(CoreMatchers.allOf(
CoreMatchers.instanceOf(ISE.class),
@ -544,6 +591,7 @@ public class MSQInsertTest extends MSQTestBase
ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString(
"INSERT and REPLACE queries cannot have a LIMIT unless PARTITIONED BY is \"ALL\""))
))
.setQueryContext(context)
.verifyPlanningErrors();
}
@ -561,6 +609,7 @@ public class MSQInsertTest extends MSQTestBase
ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString(
"INSERT and REPLACE queries cannot have an OFFSET"))
))
.setQueryContext(context)
.verifyPlanningErrors();
}

View File

@ -33,18 +33,42 @@ import org.apache.druid.timeline.partition.DimensionRangeShardSpec;
import org.hamcrest.CoreMatchers;
import org.junit.Test;
import org.junit.internal.matchers.ThrowableMessageMatcher;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import javax.annotation.Nonnull;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
@RunWith(Parameterized.class)
public class MSQReplaceTest extends MSQTestBase
{
@Parameterized.Parameters(name = "{index}:with context {0}")
public static Collection<Object[]> data()
{
Object[][] data = new Object[][]{
{DEFAULT, DEFAULT_MSQ_CONTEXT},
{DURABLE_STORAGE, DURABLE_STORAGE_MSQ_CONTEXT},
{FAULT_TOLERANCE, FAULT_TOLERANCE_MSQ_CONTEXT},
{SEQUENTIAL_MERGE, SEQUENTIAL_MERGE_MSQ_CONTEXT}
};
return Arrays.asList(data);
}
@Parameterized.Parameter(0)
public String contextName;
@Parameterized.Parameter(1)
public Map<String, Object> context;
@Test
public void testReplaceOnFooWithAll()
{
@ -59,6 +83,7 @@ public class MSQReplaceTest extends MSQTestBase
+ "PARTITIONED BY DAY ")
.setExpectedDataSource("foo")
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedDestinationIntervals(Intervals.ONLY_ETERNITY)
.setExpectedSegment(
ImmutableSet.of(
@ -91,15 +116,23 @@ public class MSQReplaceTest extends MSQTestBase
.add("m1", ColumnType.FLOAT)
.build();
testIngestQuery().setSql(" REPLACE INTO foo OVERWRITE WHERE __time >= TIMESTAMP '2000-01-02' AND __time < TIMESTAMP '2000-01-03' "
+ "SELECT __time, m1 "
+ "FROM foo "
+ "WHERE __time >= TIMESTAMP '2000-01-02' AND __time < TIMESTAMP '2000-01-03' "
+ "PARTITIONED by DAY ")
testIngestQuery().setSql(
" REPLACE INTO foo OVERWRITE WHERE __time >= TIMESTAMP '2000-01-02' AND __time < TIMESTAMP '2000-01-03' "
+ "SELECT __time, m1 "
+ "FROM foo "
+ "WHERE __time >= TIMESTAMP '2000-01-02' AND __time < TIMESTAMP '2000-01-03' "
+ "PARTITIONED by DAY ")
.setExpectedDataSource("foo")
.setExpectedDestinationIntervals(ImmutableList.of(Intervals.of("2000-01-02T00:00:00.000Z/2000-01-03T00:00:00.000Z")))
.setExpectedDestinationIntervals(ImmutableList.of(Intervals.of(
"2000-01-02T00:00:00.000Z/2000-01-03T00:00:00.000Z")))
.setExpectedRowSignature(rowSignature)
.setExpectedSegment(ImmutableSet.of(SegmentId.of("foo", Intervals.of("2000-01-02T/P1D"), "test", 0)))
.setQueryContext(context)
.setExpectedSegment(ImmutableSet.of(SegmentId.of(
"foo",
Intervals.of("2000-01-02T/P1D"),
"test",
0
)))
.setExpectedResultRows(ImmutableList.of(new Object[]{946771200000L, 2.0f}))
.verifyResults();
}
@ -127,10 +160,27 @@ public class MSQReplaceTest extends MSQTestBase
.setExpectedDataSource("foo1")
.setExpectedDestinationIntervals(Intervals.ONLY_ETERNITY)
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedSegment(ImmutableSet.of(
SegmentId.of("foo1", Intervals.of("2016-06-27T00:00:00.000Z/2016-06-27T01:00:00.000Z"), "test", 0),
SegmentId.of("foo1", Intervals.of("2016-06-27T01:00:00.000Z/2016-06-27T02:00:00.000Z"), "test", 0),
SegmentId.of("foo1", Intervals.of("2016-06-27T02:00:00.000Z/2016-06-27T03:00:00.000Z"), "test", 0))
SegmentId.of(
"foo1",
Intervals.of("2016-06-27T00:00:00.000Z/2016-06-27T01:00:00.000Z"),
"test",
0
),
SegmentId.of(
"foo1",
Intervals.of("2016-06-27T01:00:00.000Z/2016-06-27T02:00:00.000Z"),
"test",
0
),
SegmentId.of(
"foo1",
Intervals.of("2016-06-27T02:00:00.000Z/2016-06-27T03:00:00.000Z"),
"test",
0
)
)
)
.setExpectedResultRows(
ImmutableList.of(
@ -158,23 +208,31 @@ public class MSQReplaceTest extends MSQTestBase
final File toRead = MSQTestFileUtils.getResourceAsTemporaryFile(this, "/wikipedia-sampled.json");
final String toReadFileNameAsJson = queryFramework().queryJsonMapper().writeValueAsString(toRead.getAbsolutePath());
testIngestQuery().setSql(" REPLACE INTO foo1 OVERWRITE WHERE __time >= TIMESTAMP '2016-06-27 01:00:00.00' AND __time < TIMESTAMP '2016-06-27 02:00:00.00' "
+ " SELECT "
+ " floor(TIME_PARSE(\"timestamp\") to hour) AS __time, "
+ " user "
+ "FROM TABLE(\n"
+ " EXTERN(\n"
+ " '{ \"files\": [" + toReadFileNameAsJson + "],\"type\":\"local\"}',\n"
+ " '{\"type\": \"json\"}',\n"
+ " '[{\"name\": \"timestamp\", \"type\": \"string\"}, {\"name\": \"page\", \"type\": \"string\"}, {\"name\": \"user\", \"type\": \"string\"}]'\n"
+ " )\n"
+ ") "
+ "where \"timestamp\" >= TIMESTAMP '2016-06-27 01:00:00.00' AND \"timestamp\" < TIMESTAMP '2016-06-27 02:00:00.00' "
+ "PARTITIONED BY HOUR ")
testIngestQuery().setSql(
" REPLACE INTO foo1 OVERWRITE WHERE __time >= TIMESTAMP '2016-06-27 01:00:00.00' AND __time < TIMESTAMP '2016-06-27 02:00:00.00' "
+ " SELECT "
+ " floor(TIME_PARSE(\"timestamp\") to hour) AS __time, "
+ " user "
+ "FROM TABLE(\n"
+ " EXTERN(\n"
+ " '{ \"files\": [" + toReadFileNameAsJson + "],\"type\":\"local\"}',\n"
+ " '{\"type\": \"json\"}',\n"
+ " '[{\"name\": \"timestamp\", \"type\": \"string\"}, {\"name\": \"page\", \"type\": \"string\"}, {\"name\": \"user\", \"type\": \"string\"}]'\n"
+ " )\n"
+ ") "
+ "where \"timestamp\" >= TIMESTAMP '2016-06-27 01:00:00.00' AND \"timestamp\" < TIMESTAMP '2016-06-27 02:00:00.00' "
+ "PARTITIONED BY HOUR ")
.setExpectedDataSource("foo1")
.setExpectedRowSignature(rowSignature)
.setExpectedDestinationIntervals(ImmutableList.of(Intervals.of("2016-06-27T01:00:00.000Z/2016-06-27T02:00:00.000Z")))
.setExpectedSegment(ImmutableSet.of(SegmentId.of("foo1", Intervals.of("2016-06-27T01:00:00.000Z/2016-06-27T02:00:00.000Z"), "test", 0)))
.setQueryContext(context)
.setExpectedDestinationIntervals(ImmutableList.of(Intervals.of(
"2016-06-27T01:00:00.000Z/2016-06-27T02:00:00.000Z")))
.setExpectedSegment(ImmutableSet.of(SegmentId.of(
"foo1",
Intervals.of("2016-06-27T01:00:00.000Z/2016-06-27T02:00:00.000Z"),
"test",
0
)))
.setExpectedResultRows(
ImmutableList.of(
new Object[]{1466989200000L, "2001:DA8:207:E132:94DC:BA03:DFDF:8F9F"},
@ -197,6 +255,7 @@ public class MSQReplaceTest extends MSQTestBase
{
testIngestQuery().setSql("REPLACE INTO foo1 OVERWRITE SELECT * FROM foo PARTITIONED BY ALL TIME")
.setExpectedDataSource("foo1")
.setQueryContext(context)
.setExpectedValidationErrorMatcher(
CoreMatchers.allOf(
CoreMatchers.instanceOf(SqlPlanningException.class),
@ -222,8 +281,14 @@ public class MSQReplaceTest extends MSQTestBase
+ "PARTITIONED BY ALL TIME ")
.setExpectedDataSource("foo")
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedDestinationIntervals(Intervals.ONLY_ETERNITY)
.setExpectedSegment(ImmutableSet.of(SegmentId.of("foo", Intervals.of("2000-01-01T/P1M"), "test", 0)))
.setExpectedSegment(ImmutableSet.of(SegmentId.of(
"foo",
Intervals.of("2000-01-01T/P1M"),
"test",
0
)))
.setExpectedResultRows(
ImmutableList.of(
new Object[]{946684800000L, 1.0f},
@ -253,8 +318,14 @@ public class MSQReplaceTest extends MSQTestBase
+ "PARTITIONED BY MONTH")
.setExpectedDataSource("foo")
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedDestinationIntervals(Intervals.ONLY_ETERNITY)
.setExpectedSegment(ImmutableSet.of(SegmentId.of("foo", Intervals.of("2000-01-01T/P1M"), "test", 0)))
.setExpectedSegment(ImmutableSet.of(SegmentId.of(
"foo",
Intervals.of("2000-01-01T/P1M"),
"test",
0
)))
.setExpectedResultRows(
ImmutableList.of(
new Object[]{946684800000L, 1.0f},
@ -266,8 +337,9 @@ public class MSQReplaceTest extends MSQTestBase
)
)
.setExpectedSegment(ImmutableSet.of(
SegmentId.of("foo", Intervals.of("2000-01-01T/P1M"), "test", 0),
SegmentId.of("foo", Intervals.of("2001-01-01T/P1M"), "test", 0))
SegmentId.of("foo", Intervals.of("2000-01-01T/P1M"), "test", 0),
SegmentId.of("foo", Intervals.of("2001-01-01T/P1M"), "test", 0)
)
)
.verifyResults();
}
@ -288,15 +360,26 @@ public class MSQReplaceTest extends MSQTestBase
+ "PARTITIONED BY MONTH")
.setExpectedDataSource("foo")
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedDestinationIntervals(Collections.singletonList(Intervals.of("2000-01-01T/2000-03-01T")))
.setExpectedSegment(ImmutableSet.of(SegmentId.of("foo", Intervals.of("2000-01-01T/P1M"), "test", 0)))
.setExpectedSegment(ImmutableSet.of(SegmentId.of(
"foo",
Intervals.of("2000-01-01T/P1M"),
"test",
0
)))
.setExpectedResultRows(
ImmutableList.of(
new Object[]{946684800000L, 1.0f},
new Object[]{946771200000L, 2.0f}
)
)
.setExpectedSegment(ImmutableSet.of(SegmentId.of("foo", Intervals.of("2000-01-01T/P1M"), "test", 0)))
.setExpectedSegment(ImmutableSet.of(SegmentId.of(
"foo",
Intervals.of("2000-01-01T/P1M"),
"test",
0
)))
.verifyResults();
}
@ -316,15 +399,26 @@ public class MSQReplaceTest extends MSQTestBase
+ "PARTITIONED BY MONTH")
.setExpectedDataSource("foo")
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedDestinationIntervals(Collections.singletonList(Intervals.of("2000-01-01T/2002-01-01T")))
.setExpectedSegment(ImmutableSet.of(SegmentId.of("foo", Intervals.of("2000-01-01T/P1M"), "test", 0)))
.setExpectedSegment(ImmutableSet.of(SegmentId.of(
"foo",
Intervals.of("2000-01-01T/P1M"),
"test",
0
)))
.setExpectedResultRows(
ImmutableList.of(
new Object[]{946684800000L, 1.0f},
new Object[]{946771200000L, 2.0f}
)
)
.setExpectedSegment(ImmutableSet.of(SegmentId.of("foo", Intervals.of("2000-01-01T/P1M"), "test", 0)))
.setExpectedSegment(ImmutableSet.of(SegmentId.of(
"foo",
Intervals.of("2000-01-01T/P1M"),
"test",
0
)))
.verifyResults();
}
@ -337,6 +431,7 @@ public class MSQReplaceTest extends MSQTestBase
+ "FROM foo "
+ "LIMIT 50"
+ "PARTITIONED BY MONTH")
.setQueryContext(context)
.setExpectedValidationErrorMatcher(CoreMatchers.allOf(
CoreMatchers.instanceOf(SqlPlanningException.class),
ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString(
@ -360,6 +455,7 @@ public class MSQReplaceTest extends MSQTestBase
ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString(
"INSERT and REPLACE queries cannot have an OFFSET"))
))
.setQueryContext(context)
.verifyPlanningErrors();
}
@ -380,6 +476,7 @@ public class MSQReplaceTest extends MSQTestBase
.setExpectedDataSource("foo")
.setQueryContext(DEFAULT_MSQ_CONTEXT)
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedDestinationIntervals(Collections.singletonList(Intervals.of("2000-01-01T/2000-03-01T")))
.setExpectedSegment(ImmutableSet.of(SegmentId.of(
"foo",
@ -413,6 +510,7 @@ public class MSQReplaceTest extends MSQTestBase
.setExpectedDataSource("foo")
.setQueryContext(DEFAULT_MSQ_CONTEXT)
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedDestinationIntervals(Collections.singletonList(Intervals.of("2000-01-01T/2002-01-01T")))
.setExpectedSegment(ImmutableSet.of(SegmentId.of(
"foo",
@ -445,6 +543,7 @@ public class MSQReplaceTest extends MSQTestBase
.setQueryContext(DEFAULT_MSQ_CONTEXT)
.setExpectedShardSpec(DimensionRangeShardSpec.class)
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedSegment(expectedFooSegments())
.setExpectedResultRows(expectedFooRows())
.verifyResults();
@ -466,6 +565,7 @@ public class MSQReplaceTest extends MSQTestBase
+ "PARTITIONED BY ALL TIME ")
.setExpectedDataSource("foobar")
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedDestinationIntervals(Intervals.ONLY_ETERNITY)
.setExpectedSegment(ImmutableSet.of(SegmentId.of("foobar", Intervals.ETERNITY, "test", 0)))
.setExpectedResultRows(

View File

@ -66,6 +66,8 @@ import org.apache.druid.sql.calcite.util.CalciteTests;
import org.hamcrest.CoreMatchers;
import org.junit.Test;
import org.junit.internal.matchers.ThrowableMessageMatcher;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
@ -74,12 +76,33 @@ import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@RunWith(Parameterized.class)
public class MSQSelectTest extends MSQTestBase
{
@Parameterized.Parameters(name = "{index}:with context {0}")
public static Collection<Object[]> data()
{
Object[][] data = new Object[][]{
{DEFAULT, DEFAULT_MSQ_CONTEXT},
{DURABLE_STORAGE, DURABLE_STORAGE_MSQ_CONTEXT},
{FAULT_TOLERANCE, FAULT_TOLERANCE_MSQ_CONTEXT},
{SEQUENTIAL_MERGE, SEQUENTIAL_MERGE_MSQ_CONTEXT}
};
return Arrays.asList(data);
}
@Parameterized.Parameter(0)
public String contextName;
@Parameterized.Parameter(1)
public Map<String, Object> context;
@Test
public void testCalculator()
{
@ -101,7 +124,7 @@ public class MSQSelectTest extends MSQTestBase
)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("EXPR$0")
.context(defaultScanQueryContext(resultSignature))
.context(defaultScanQueryContext(context, resultSignature))
.build()
)
.columnMappings(ColumnMappings.identity(resultSignature))
@ -109,6 +132,7 @@ public class MSQSelectTest extends MSQTestBase
.build()
)
.setExpectedRowSignature(resultSignature)
.setQueryContext(context)
.setExpectedResultRows(ImmutableList.of(new Object[]{2L})).verifyResults();
}
@ -129,13 +153,14 @@ public class MSQSelectTest extends MSQTestBase
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("cnt", "dim1")
.context(defaultScanQueryContext(resultSignature))
.context(defaultScanQueryContext(context, resultSignature))
.build()
)
.columnMappings(ColumnMappings.identity(resultSignature))
.tuningConfig(MSQTuningConfig.defaultConfig())
.build()
)
.setQueryContext(context)
.setExpectedRowSignature(resultSignature)
.setExpectedResultRows(ImmutableList.of(
new Object[]{1L, !useDefault ? "" : null},
@ -164,6 +189,7 @@ public class MSQSelectTest extends MSQTestBase
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("dim2", "m1")
.context(defaultScanQueryContext(
context,
RowSignature.builder()
.add("dim2", ColumnType.STRING)
.add("m1", ColumnType.LONG)
@ -175,6 +201,7 @@ public class MSQSelectTest extends MSQTestBase
.build()
)
.setExpectedRowSignature(resultSignature)
.setQueryContext(context)
.setExpectedResultRows(ImmutableList.of(
new Object[]{1L, "en"},
new Object[]{1L, "ru"},
@ -207,7 +234,7 @@ public class MSQSelectTest extends MSQTestBase
))
.setAggregatorSpecs(aggregators(new CountAggregatorFactory(
"a0")))
.setContext(DEFAULT_MSQ_CONTEXT)
.setContext(context)
.build())
.columnMappings(
new ColumnMappings(ImmutableList.of(
@ -219,6 +246,7 @@ public class MSQSelectTest extends MSQTestBase
.build())
.setExpectedRowSignature(rowSignature)
.setExpectedResultRows(ImmutableList.of(new Object[]{1L, 6L}))
.setQueryContext(context)
.verifyResults();
}
@ -249,7 +277,7 @@ public class MSQSelectTest extends MSQTestBase
null
)
)
.setContext(DEFAULT_MSQ_CONTEXT)
.setContext(context)
.build();
testSelectQuery()
@ -266,6 +294,7 @@ public class MSQSelectTest extends MSQTestBase
.tuningConfig(MSQTuningConfig.defaultConfig())
.build())
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedResultRows(
ImmutableList.of(
new Object[]{6f, 1L},
@ -294,13 +323,13 @@ public class MSQSelectTest extends MSQTestBase
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(dimensions(new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT)))
.setContext(DEFAULT_MSQ_CONTEXT)
.setContext(context)
.build()
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0")))
.setContext(DEFAULT_MSQ_CONTEXT)
.setContext(context)
.build();
testSelectQuery()
@ -314,6 +343,7 @@ public class MSQSelectTest extends MSQTestBase
)
.setExpectedRowSignature(resultSignature)
.setExpectedResultRows(ImmutableList.of(new Object[]{6L}))
.setQueryContext(context)
.verifyResults();
}
@ -353,6 +383,7 @@ public class MSQSelectTest extends MSQTestBase
.columns("dim2", "m1", "m2")
.context(
defaultScanQueryContext(
context,
RowSignature.builder()
.add("dim2", ColumnType.STRING)
.add("m1", ColumnType.FLOAT)
@ -371,6 +402,7 @@ public class MSQSelectTest extends MSQTestBase
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.context(
defaultScanQueryContext(
context,
RowSignature.builder().add("m1", ColumnType.FLOAT).build()
)
)
@ -419,7 +451,7 @@ public class MSQSelectTest extends MSQTestBase
)
)
.setContext(DEFAULT_MSQ_CONTEXT)
.setContext(context)
.build();
testSelectQuery()
@ -442,6 +474,7 @@ public class MSQSelectTest extends MSQTestBase
)
.setExpectedRowSignature(resultSignature)
.setExpectedResultRows(expectedResults)
.setQueryContext(context)
.verifyResults();
}
@ -482,6 +515,7 @@ public class MSQSelectTest extends MSQTestBase
.columns("dim2", "m1", "m2")
.context(
defaultScanQueryContext(
context,
RowSignature.builder()
.add("dim2", ColumnType.STRING)
.add("m1", ColumnType.FLOAT)
@ -535,7 +569,7 @@ public class MSQSelectTest extends MSQTestBase
)
)
.setContext(DEFAULT_MSQ_CONTEXT)
.setContext(context)
.build();
testSelectQuery()
@ -558,6 +592,7 @@ public class MSQSelectTest extends MSQTestBase
)
.setExpectedRowSignature(resultSignature)
.setExpectedResultRows(expectedResults)
.setQueryContext(context)
.verifyResults();
}
@ -588,7 +623,7 @@ public class MSQSelectTest extends MSQTestBase
null
)
)
.setContext(DEFAULT_MSQ_CONTEXT)
.setContext(context)
.build();
testSelectQuery()
@ -608,6 +643,7 @@ public class MSQSelectTest extends MSQTestBase
.build()
)
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedResultRows(
ImmutableList.of(
new Object[]{6f, 6d},
@ -647,7 +683,7 @@ public class MSQSelectTest extends MSQTestBase
3
)
)
.setContext(DEFAULT_MSQ_CONTEXT)
.setContext(context)
.build();
testSelectQuery()
@ -667,6 +703,7 @@ public class MSQSelectTest extends MSQTestBase
.build()
)
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedResultRows(
ImmutableList.of(
new Object[]{6f, 6d},
@ -704,7 +741,7 @@ public class MSQSelectTest extends MSQTestBase
2
)
)
.setContext(DEFAULT_MSQ_CONTEXT)
.setContext(context)
.build();
testSelectQuery()
@ -724,6 +761,7 @@ public class MSQSelectTest extends MSQTestBase
.build()
)
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedResultRows(
ImmutableList.of(
new Object[]{5f, 5d},
@ -768,7 +806,7 @@ public class MSQSelectTest extends MSQTestBase
)
.setDimensions(dimensions(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
.setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0")))
.setContext(DEFAULT_MSQ_CONTEXT)
.setContext(context)
.build();
testSelectQuery()
@ -783,6 +821,7 @@ public class MSQSelectTest extends MSQTestBase
+ " )\n"
+ ") group by 1")
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedResultRows(ImmutableList.of(new Object[]{1466985600000L, 20L}))
.setExpectedMSQSpec(
MSQSpec
@ -815,6 +854,7 @@ public class MSQSelectTest extends MSQTestBase
CoreMatchers.instanceOf(SqlPlanningException.class),
ThrowableMessageMatcher.hasMessage(CoreMatchers.startsWith("Encountered \"from <EOF>\""))
))
.setQueryContext(context)
.verifyPlanningErrors();
}
@ -823,6 +863,7 @@ public class MSQSelectTest extends MSQTestBase
{
testSelectQuery()
.setSql("SELECT * FROM INFORMATION_SCHEMA.SCHEMATA")
.setQueryContext(context)
.setExpectedValidationErrorMatcher(
CoreMatchers.allOf(
CoreMatchers.instanceOf(SqlPlanningException.class),
@ -838,6 +879,7 @@ public class MSQSelectTest extends MSQTestBase
{
testSelectQuery()
.setSql("SELECT * FROM sys.segments")
.setQueryContext(context)
.setExpectedValidationErrorMatcher(
CoreMatchers.allOf(
CoreMatchers.instanceOf(SqlPlanningException.class),
@ -853,6 +895,7 @@ public class MSQSelectTest extends MSQTestBase
{
testSelectQuery()
.setSql("select s.segment_id, s.num_rows, f.dim1 from sys.segments as s, foo as f")
.setQueryContext(context)
.setExpectedValidationErrorMatcher(
CoreMatchers.allOf(
CoreMatchers.instanceOf(SqlPlanningException.class),
@ -869,6 +912,7 @@ public class MSQSelectTest extends MSQTestBase
testSelectQuery()
.setSql("with segment_source as (SELECT * FROM sys.segments) "
+ "select segment_source.segment_id, segment_source.num_rows from segment_source")
.setQueryContext(context)
.setExpectedValidationErrorMatcher(
CoreMatchers.allOf(
CoreMatchers.instanceOf(SqlPlanningException.class),
@ -899,6 +943,7 @@ public class MSQSelectTest extends MSQTestBase
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("dim2", "m1")
.context(defaultScanQueryContext(
context,
RowSignature.builder()
.add("dim2", ColumnType.STRING)
.add("m1", ColumnType.LONG)
@ -911,6 +956,7 @@ public class MSQSelectTest extends MSQTestBase
.build()
)
.setExpectedRowSignature(resultSignature)
.setQueryContext(context)
.setExpectedResultRows(ImmutableList.of(
new Object[]{1L, "en"},
new Object[]{1L, "ru"},
@ -933,12 +979,13 @@ public class MSQSelectTest extends MSQTestBase
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("dim3")
.context(defaultScanQueryContext(resultSignature))
.context(defaultScanQueryContext(context, resultSignature))
.build())
.columnMappings(ColumnMappings.identity(resultSignature))
.tuningConfig(MSQTuningConfig.defaultConfig())
.build())
.setExpectedRowSignature(resultSignature)
.setQueryContext(context)
.setExpectedResultRows(ImmutableList.of(
new Object[]{ImmutableList.of("a", "b")},
new Object[]{ImmutableList.of("b", "c")},
@ -952,10 +999,7 @@ public class MSQSelectTest extends MSQTestBase
@Test
public void testGroupByWithMultiValue()
{
Map<String, Object> context = ImmutableMap.<String, Object>builder()
.putAll(DEFAULT_MSQ_CONTEXT)
.put("groupByEnableMultiValueUnnesting", true)
.build();
Map<String, Object> localContext = enableMultiValueUnnesting(context, true);
RowSignature rowSignature = RowSignature.builder()
.add("dim3", ColumnType.STRING)
.add("cnt1", ColumnType.LONG)
@ -963,7 +1007,7 @@ public class MSQSelectTest extends MSQTestBase
testSelectQuery()
.setSql("select dim3, count(*) as cnt1 from foo group by dim3")
.setQueryContext(context)
.setQueryContext(localContext)
.setExpectedMSQSpec(
MSQSpec.builder()
.query(
@ -981,7 +1025,7 @@ public class MSQSelectTest extends MSQTestBase
)
.setAggregatorSpecs(aggregators(new CountAggregatorFactory(
"a0")))
.setContext(context)
.setContext(localContext)
.build()
)
.columnMappings(
@ -1003,18 +1047,20 @@ public class MSQSelectTest extends MSQTestBase
@Test
public void testGroupByWithMultiValueWithoutGroupByEnable()
{
Map<String, Object> context = ImmutableMap.<String, Object>builder()
.putAll(DEFAULT_MSQ_CONTEXT)
.put("groupByEnableMultiValueUnnesting", false)
.build();
Map<String, Object> localContext = enableMultiValueUnnesting(context, false);
testSelectQuery()
.setSql("select dim3, count(*) as cnt1 from foo group by dim3")
.setQueryContext(context)
.setQueryContext(localContext)
.setExpectedExecutionErrorMatcher(CoreMatchers.allOf(
CoreMatchers.instanceOf(ISE.class),
ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString(
"Encountered multi-value dimension [dim3] that cannot be processed with 'groupByEnableMultiValueUnnesting' set to false."))
ThrowableMessageMatcher.hasMessage(
!FAULT_TOLERANCE.equals(contextName)
? CoreMatchers.containsString(
"Encountered multi-value dimension [dim3] that cannot be processed with 'groupByEnableMultiValueUnnesting' set to false.")
:
CoreMatchers.containsString("exceeded max relaunch count")
)
))
.verifyExecutionError();
}
@ -1022,10 +1068,7 @@ public class MSQSelectTest extends MSQTestBase
@Test
public void testGroupByWithMultiValueMvToArray()
{
Map<String, Object> context = ImmutableMap.<String, Object>builder()
.putAll(DEFAULT_MSQ_CONTEXT)
.put("groupByEnableMultiValueUnnesting", true)
.build();
Map<String, Object> localContext = enableMultiValueUnnesting(context, true);
RowSignature rowSignature = RowSignature.builder()
.add("EXPR$0", ColumnType.STRING_ARRAY)
@ -1034,7 +1077,7 @@ public class MSQSelectTest extends MSQTestBase
testSelectQuery()
.setSql("select MV_TO_ARRAY(dim3), count(*) as cnt1 from foo group by dim3")
.setQueryContext(context)
.setQueryContext(localContext)
.setExpectedMSQSpec(MSQSpec.builder()
.query(GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
@ -1058,7 +1101,7 @@ public class MSQSelectTest extends MSQTestBase
)
)
)
.setContext(context)
.setContext(localContext)
.build()
)
.columnMappings(
@ -1079,10 +1122,7 @@ public class MSQSelectTest extends MSQTestBase
@Test
public void testGroupByArrayWithMultiValueMvToArray()
{
Map<String, Object> context = ImmutableMap.<String, Object>builder()
.putAll(DEFAULT_MSQ_CONTEXT)
.put("groupByEnableMultiValueUnnesting", true)
.build();
Map<String, Object> localContext = enableMultiValueUnnesting(context, true);
RowSignature rowSignature = RowSignature.builder()
.add("EXPR$0", ColumnType.STRING_ARRAY)
@ -1102,7 +1142,7 @@ public class MSQSelectTest extends MSQTestBase
testSelectQuery()
.setSql("select MV_TO_ARRAY(dim3), count(*) as cnt1 from foo group by MV_TO_ARRAY(dim3)")
.setQueryContext(context)
.setQueryContext(localContext)
.setExpectedMSQSpec(MSQSpec.builder()
.query(GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
@ -1126,7 +1166,7 @@ public class MSQSelectTest extends MSQTestBase
)
)
.setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0")))
.setContext(context)
.setContext(localContext)
.build()
)
.columnMappings(
@ -1144,21 +1184,24 @@ public class MSQSelectTest extends MSQTestBase
.verifyResults();
}
@Test
public void testGroupByWithMultiValueMvToArrayWithoutGroupByEnable()
{
Map<String, Object> context = ImmutableMap.<String, Object>builder()
.putAll(DEFAULT_MSQ_CONTEXT)
.put("groupByEnableMultiValueUnnesting", false)
.build();
Map<String, Object> localContext = enableMultiValueUnnesting(context, false);
testSelectQuery()
.setSql("select MV_TO_ARRAY(dim3), count(*) as cnt1 from foo group by dim3")
.setQueryContext(context)
.setQueryContext(localContext)
.setExpectedExecutionErrorMatcher(CoreMatchers.allOf(
CoreMatchers.instanceOf(ISE.class),
ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString(
"Encountered multi-value dimension [dim3] that cannot be processed with 'groupByEnableMultiValueUnnesting' set to false."))
ThrowableMessageMatcher.hasMessage(
!FAULT_TOLERANCE.equals(contextName)
? CoreMatchers.containsString(
"Encountered multi-value dimension [dim3] that cannot be processed with 'groupByEnableMultiValueUnnesting' set to false.")
:
CoreMatchers.containsString("exceeded max relaunch count")
)
))
.verifyExecutionError();
}
@ -1186,12 +1229,12 @@ public class MSQSelectTest extends MSQTestBase
)
)
)
.setContext(DEFAULT_MSQ_CONTEXT)
.setContext(context)
.build();
testSelectQuery()
.setSql("select __time, count(dim3) as cnt1 from foo group by __time")
.setQueryContext(DEFAULT_MSQ_CONTEXT)
.setQueryContext(context)
.setExpectedMSQSpec(MSQSpec.builder()
.query(expectedQuery)
.columnMappings(
@ -1224,8 +1267,10 @@ public class MSQSelectTest extends MSQTestBase
.add("cnt1", ColumnType.LONG)
.build();
testSelectQuery()
.setSql("select cnt,count(*) as cnt1 from foo group by cnt")
.setQueryContext(context)
.setExpectedMSQSpec(MSQSpec.builder()
.query(GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
@ -1241,7 +1286,7 @@ public class MSQSelectTest extends MSQTestBase
))
.setAggregatorSpecs(aggregators(new CountAggregatorFactory(
"a0")))
.setContext(DEFAULT_MSQ_CONTEXT)
.setContext(context)
.build())
.columnMappings(
new ColumnMappings(ImmutableList.of(
@ -1254,13 +1299,15 @@ public class MSQSelectTest extends MSQTestBase
.setExpectedRowSignature(rowSignature)
.setExpectedResultRows(ImmutableList.of(new Object[]{1L, 6L}))
.verifyResults();
File successFile = new File(
localFileStorageDir,
DurableStorageUtils.getSuccessFilePath("query-test-query", 0, 0)
);
if (DURABLE_STORAGE.equals(contextName) || FAULT_TOLERANCE.equals(contextName)) {
new File(
localFileStorageDir,
DurableStorageUtils.getSuccessFilePath("query-test-query", 0, 0)
);
Mockito.verify(localFileStorageConnector, Mockito.times(2))
.write(ArgumentMatchers.endsWith("__success"));
Mockito.verify(localFileStorageConnector, Mockito.times(2))
.write(ArgumentMatchers.endsWith("__success"));
}
}
@Test
@ -1318,6 +1365,7 @@ public class MSQSelectTest extends MSQTestBase
.build())
.setExpectedMSQFault(new CannotParseExternalDataFault(
"Unable to add the row to the frame. Type conversion might be required."))
.setQueryContext(context)
.verifyResults();
}
@ -1354,4 +1402,13 @@ public class MSQSelectTest extends MSQTestBase
));
return expected;
}
private static Map<String, Object> enableMultiValueUnnesting(Map<String, Object> context, boolean value)
{
Map<String, Object> localContext = ImmutableMap.<String, Object>builder()
.putAll(context)
.put("groupByEnableMultiValueUnnesting", value)
.build();
return localContext;
}
}

View File

@ -23,10 +23,12 @@ import com.google.errorprone.annotations.concurrent.GuardedBy;
import org.apache.druid.indexer.TaskLocation;
import org.apache.druid.indexer.TaskState;
import org.apache.druid.indexer.TaskStatus;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.indexing.MSQWorkerTask;
import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher;
import org.apache.druid.msq.indexing.error.MSQErrorReport;
import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.indexing.error.MSQFaultUtils;
import org.apache.druid.msq.indexing.error.TaskStartTimeoutFault;
import org.apache.druid.msq.indexing.error.TooManyColumnsFault;
import org.apache.druid.msq.indexing.error.TooManyWorkersFault;
@ -138,6 +140,24 @@ public class MSQTasksTest
);
}
@Test
public void test_getWorkerFromTaskId()
{
Assert.assertEquals(1, MSQTasks.workerFromTaskId("xxxx-worker1_0"));
Assert.assertEquals(10, MSQTasks.workerFromTaskId("xxxx-worker10_0"));
Assert.assertEquals(0, MSQTasks.workerFromTaskId("xxdsadxx-worker0_0"));
Assert.assertEquals(90, MSQTasks.workerFromTaskId("dx-worker90_0"));
Assert.assertEquals(9, MSQTasks.workerFromTaskId("12dsa1-worker9_0"));
Assert.assertThrows(ISE.class, () -> MSQTasks.workerFromTaskId("xxxx-worker-0"));
Assert.assertThrows(ISE.class, () -> MSQTasks.workerFromTaskId("worker-0"));
Assert.assertThrows(ISE.class, () -> MSQTasks.workerFromTaskId("xxxx-worker1-0"));
Assert.assertThrows(ISE.class, () -> MSQTasks.workerFromTaskId("xxxx-worker0-"));
Assert.assertThrows(ISE.class, () -> MSQTasks.workerFromTaskId("xxxx-worr1_0"));
Assert.assertThrows(ISE.class, () -> MSQTasks.workerFromTaskId("xxxx-worker-1-0"));
Assert.assertThrows(ISE.class, () -> MSQTasks.workerFromTaskId("xx"));
}
@Test
public void test_queryWithoutEnoughSlots_shouldThrowException()
{
@ -150,6 +170,7 @@ public class MSQTasksTest
CONTROLLER_ID,
"foo",
controllerContext,
(task, fault) -> {},
false,
-1L,
TimeUnit.SECONDS.toMillis(5)
@ -162,8 +183,8 @@ public class MSQTasksTest
}
catch (Exception e) {
Assert.assertEquals(
new TaskStartTimeoutFault(numTasks + 1).getCodeWithMessage(),
((MSQException) e.getCause()).getFault().getCodeWithMessage()
MSQFaultUtils.generateMessageWithErrorCode(new TaskStartTimeoutFault(numTasks + 1)),
MSQFaultUtils.generateMessageWithErrorCode(((MSQException) e.getCause()).getFault())
);
}
}
@ -221,7 +242,7 @@ public class MSQTasksTest
}
@Override
public synchronized String run(String controllerId, MSQWorkerTask task)
public synchronized String run(String taskId, MSQWorkerTask task)
{
allTasks.add(task.getId());

View File

@ -40,14 +40,14 @@ public class WorkerImplTest
@Test
public void testFetchStatsThrows()
{
WorkerImpl worker = new WorkerImpl(new MSQWorkerTask("controller", "ds", 1, new HashMap<>()), workerContext);
WorkerImpl worker = new WorkerImpl(new MSQWorkerTask("controller", "ds", 1, new HashMap<>(), 0), workerContext);
Assert.assertThrows(ISE.class, () -> worker.fetchStatisticsSnapshot(new StageId("xx", 1)));
}
@Test
public void testFetchStatsWithTimeChunkThrows()
{
WorkerImpl worker = new WorkerImpl(new MSQWorkerTask("controller", "ds", 1, new HashMap<>()), workerContext);
WorkerImpl worker = new WorkerImpl(new MSQWorkerTask("controller", "ds", 1, new HashMap<>(), 0), workerContext);
Assert.assertThrows(ISE.class, () -> worker.fetchStatisticsSnapshotForTimeChunk(new StageId("xx", 1), 1L));
}

View File

@ -1,160 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.exec;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import java.util.Collections;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
public class WorkerSketchFetcherAutoModeTest
{
@Mock
private CompleteKeyStatisticsInformation completeKeyStatisticsInformation;
@Mock
private StageDefinition stageDefinition;
@Mock
private ClusterBy clusterBy;
private AutoCloseable mocks;
private WorkerSketchFetcher target;
@Before
public void setUp()
{
mocks = MockitoAnnotations.openMocks(this);
target = spy(new WorkerSketchFetcher(mock(WorkerClient.class), ClusterStatisticsMergeMode.AUTO, 300_000_000));
// Don't actually try to fetch sketches
doReturn(null).when(target).inMemoryFullSketchMerging(any(), any(), any());
doReturn(null).when(target).sequentialTimeChunkMerging(any(), any(), any());
doReturn(StageId.fromString("1_1")).when(stageDefinition).getId();
doReturn(clusterBy).when(stageDefinition).getClusterBy();
}
@After
public void tearDown() throws Exception
{
mocks.close();
}
@Test
public void test_submitFetcherTask_belowThresholds_ShouldBeParallel()
{
// Bytes below threshold
doReturn(10.0).when(completeKeyStatisticsInformation).getBytesRetained();
// Cluster by bucket count not 0
doReturn(1).when(clusterBy).getBucketByCount();
// Worker count below threshold
doReturn(1).when(stageDefinition).getMaxWorkerCount();
target.submitFetcherTask(
completeKeyStatisticsInformation,
Collections.emptyList(),
stageDefinition,
IntSet.of()
);
verify(target, times(1)).inMemoryFullSketchMerging(any(), any(), any());
verify(target, times(0)).sequentialTimeChunkMerging(any(), any(), any());
}
@Test
public void test_submitFetcherTask_workerCountAboveThreshold_shouldBeSequential()
{
// Bytes below threshold
doReturn(10.0).when(completeKeyStatisticsInformation).getBytesRetained();
// Cluster by bucket count not 0
doReturn(1).when(clusterBy).getBucketByCount();
// Worker count below threshold
doReturn((int) WorkerSketchFetcher.WORKER_THRESHOLD + 1).when(stageDefinition).getMaxWorkerCount();
target.submitFetcherTask(
completeKeyStatisticsInformation,
Collections.emptyList(),
stageDefinition,
IntSet.of()
);
verify(target, times(0)).inMemoryFullSketchMerging(any(), any(), any());
verify(target, times(1)).sequentialTimeChunkMerging(any(), any(), any());
}
@Test
public void test_submitFetcherTask_noClusterByColumns_shouldBeParallel()
{
// Bytes above threshold
doReturn(WorkerSketchFetcher.BYTES_THRESHOLD + 10.0).when(completeKeyStatisticsInformation).getBytesRetained();
// Cluster by bucket count 0
doReturn(ClusterBy.none()).when(stageDefinition).getClusterBy();
// Worker count above threshold
doReturn((int) WorkerSketchFetcher.WORKER_THRESHOLD + 1).when(stageDefinition).getMaxWorkerCount();
target.submitFetcherTask(
completeKeyStatisticsInformation,
Collections.emptyList(),
stageDefinition,
IntSet.of()
);
verify(target, times(1)).inMemoryFullSketchMerging(any(), any(), any());
verify(target, times(0)).sequentialTimeChunkMerging(any(), any(), any());
}
@Test
public void test_submitFetcherTask_bytesRetainedAboveThreshold_shouldBeSequential()
{
// Bytes above threshold
doReturn(WorkerSketchFetcher.BYTES_THRESHOLD + 10.0).when(completeKeyStatisticsInformation).getBytesRetained();
// Cluster by bucket count not 0
doReturn(1).when(clusterBy).getBucketByCount();
// Worker count below threshold
doReturn(1).when(stageDefinition).getMaxWorkerCount();
target.submitFetcherTask(
completeKeyStatisticsInformation,
Collections.emptyList(),
stageDefinition,
IntSet.of()
);
verify(target, times(0)).inMemoryFullSketchMerging(any(), any(), any());
verify(target, times(1)).sequentialTimeChunkMerging(any(), any(), any());
}
}

View File

@ -23,16 +23,11 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.util.concurrent.Futures;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.java.util.common.Either;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.apache.druid.msq.kernel.controller.ControllerQueryKernel;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import org.junit.After;
@ -43,71 +38,51 @@ import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.SortedMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import static org.easymock.EasyMock.mock;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
public class WorkerSketchFetcherTest
{
@Mock
private CompleteKeyStatisticsInformation completeKeyStatisticsInformation;
@Mock
private MSQWorkerTaskLauncher workerTaskLauncher;
@Mock
private ControllerQueryKernel kernel;
@Mock
private StageDefinition stageDefinition;
@Mock
private ClusterBy clusterBy;
@Mock
private ClusterByStatisticsCollector mergedClusterByStatisticsCollector1;
@Mock
private ClusterByStatisticsCollector mergedClusterByStatisticsCollector2;
@Mock
private WorkerClient workerClient;
private ClusterByPartitions expectedPartitions1;
private ClusterByPartitions expectedPartitions2;
private AutoCloseable mocks;
private WorkerSketchFetcher target;
private static final String TASK_0 = "task-worker0_0";
private static final String TASK_1 = "task-worker1_0";
private static final String TASK_2 = "task-worker2_1";
private static final List<String> TASK_IDS = ImmutableList.of(TASK_0, TASK_1, TASK_2);
@Before
public void setUp()
{
mocks = MockitoAnnotations.openMocks(this);
doReturn(StageId.fromString("1_1")).when(stageDefinition).getId();
doReturn(clusterBy).when(stageDefinition).getClusterBy();
doReturn(25_000).when(stageDefinition).getMaxPartitionCount();
expectedPartitions1 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(
mock(RowKey.class),
mock(RowKey.class)
)));
expectedPartitions2 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(
mock(RowKey.class),
mock(RowKey.class)
)));
doReturn(ImmutableSortedMap.of(123L, ImmutableSet.of(1, 2))).when(completeKeyStatisticsInformation)
.getTimeSegmentVsWorkerMap();
doReturn(Either.value(expectedPartitions1)).when(stageDefinition)
.generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector1));
doReturn(Either.value(expectedPartitions2)).when(stageDefinition)
.generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector2));
doReturn(
mergedClusterByStatisticsCollector1,
mergedClusterByStatisticsCollector2
).when(stageDefinition).createResultKeyStatisticsCollector(anyInt());
doReturn(true).when(workerTaskLauncher).isTaskLatest(any());
}
@After
@ -120,104 +95,279 @@ public class WorkerSketchFetcherTest
}
@Test
public void test_submitFetcherTask_parallelFetch_mergePerformedCorrectly()
throws ExecutionException, InterruptedException
public void test_submitFetcherTask_parallelFetch() throws InterruptedException
{
// Store snapshots in a queue
final Queue<ClusterByStatisticsSnapshot> snapshotQueue = new ConcurrentLinkedQueue<>();
final List<String> workerIds = ImmutableList.of("0", "1", "2", "3", "4");
final CountDownLatch latch = new CountDownLatch(workerIds.size());
target = spy(new WorkerSketchFetcher(workerClient, ClusterStatisticsMergeMode.PARALLEL, 300_000_000));
final CountDownLatch latch = new CountDownLatch(TASK_IDS.size());
target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, true));
// When fetching snapshots, return a mock and add it to queue
doAnswer(invocation -> {
ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class);
snapshotQueue.add(snapshot);
latch.countDown();
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt());
IntSet workersForStage = new IntAVLTreeSet();
workersForStage.addAll(ImmutableSet.of(0, 1, 2, 3, 4));
target.inMemoryFullSketchMerging((kernelConsumer) -> {
kernelConsumer.accept(kernel);
latch.countDown();
}, stageDefinition.getId(), ImmutableSet.copyOf(TASK_IDS), ((queryKernel, integer, msqFault) -> {}));
CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
completeKeyStatisticsInformation,
workerIds,
stageDefinition,
workersForStage
);
latch.await(5, TimeUnit.SECONDS);
Assert.assertEquals(0, latch.getCount());
// Assert that the final result is complete and all other sketches returned have been merged.
eitherCompletableFuture.join();
Thread.sleep(1000);
Assert.assertTrue(eitherCompletableFuture.isDone() && !eitherCompletableFuture.isCompletedExceptionally());
Assert.assertFalse(snapshotQueue.isEmpty());
// Verify that all statistics were added to controller.
for (ClusterByStatisticsSnapshot snapshot : snapshotQueue) {
verify(mergedClusterByStatisticsCollector1, times(1)).addAll(eq(snapshot));
}
// Check that the partitions returned by the merged collector is returned by the final future.
Assert.assertEquals(expectedPartitions1, eitherCompletableFuture.get().valueOrThrow());
}
@Test
public void test_submitFetcherTask_sequentialFetch_mergePerformedCorrectly()
throws ExecutionException, InterruptedException
public void test_submitFetcherTask_sequentialFetch() throws InterruptedException
{
// Store snapshots in a queue
final Queue<ClusterByStatisticsSnapshot> snapshotQueue = new ConcurrentLinkedQueue<>();
doReturn(true).when(completeKeyStatisticsInformation).isComplete();
final CountDownLatch latch = new CountDownLatch(TASK_IDS.size() - 1);
SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap = ImmutableSortedMap.of(
1L,
ImmutableSet.of(0, 1, 2),
2L,
ImmutableSet.of(0, 1, 4)
);
doReturn(timeSegmentVsWorkerMap).when(completeKeyStatisticsInformation).getTimeSegmentVsWorkerMap();
final CyclicBarrier barrier = new CyclicBarrier(3);
target = spy(new WorkerSketchFetcher(workerClient, ClusterStatisticsMergeMode.SEQUENTIAL, 300_000_000));
target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, true));
// When fetching snapshots, return a mock and add it to queue
doAnswer(invocation -> {
ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class);
snapshotQueue.add(snapshot);
barrier.await();
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyInt(), anyLong());
IntSet workersForStage = new IntAVLTreeSet();
workersForStage.addAll(ImmutableSet.of(0, 1, 2, 3, 4));
CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
target.sequentialTimeChunkMerging(
(kernelConsumer) -> {
kernelConsumer.accept(kernel);
latch.countDown();
},
completeKeyStatisticsInformation,
ImmutableList.of("0", "1", "2", "3", "4"),
stageDefinition,
workersForStage
stageDefinition.getId(),
ImmutableSet.copyOf(TASK_IDS),
((queryKernel, integer, msqFault) -> {})
);
// Assert that the final result is complete and all other sketches returned have been merged.
eitherCompletableFuture.join();
Thread.sleep(1000);
Assert.assertTrue(eitherCompletableFuture.isDone() && !eitherCompletableFuture.isCompletedExceptionally());
Assert.assertFalse(snapshotQueue.isEmpty());
// Verify that all statistics were added to controller.
snapshotQueue.stream().limit(3).forEach(snapshot -> {
verify(mergedClusterByStatisticsCollector1, times(1)).addAll(eq(snapshot));
});
snapshotQueue.stream().skip(3).limit(3).forEach(snapshot -> {
verify(mergedClusterByStatisticsCollector2, times(1)).addAll(eq(snapshot));
});
ClusterByPartitions expectedResult =
new ClusterByPartitions(
ImmutableList.of(
new ClusterByPartition(expectedPartitions1.get(0).getStart(), expectedPartitions2.get(0).getStart()),
new ClusterByPartition(expectedPartitions2.get(0).getStart(), expectedPartitions2.get(0).getEnd())
)
);
// Check that the partitions returned by the merged collector is returned by the final future.
Assert.assertEquals(expectedResult, eitherCompletableFuture.get().valueOrThrow());
latch.await(5, TimeUnit.SECONDS);
Assert.assertEquals(0, latch.getCount());
}
@Test
public void test_sequentialMerge_nonCompleteInformation()
{
doReturn(false).when(completeKeyStatisticsInformation).isComplete();
target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, true));
Assert.assertThrows(ISE.class, () -> target.sequentialTimeChunkMerging(
(ignore) -> {},
completeKeyStatisticsInformation,
stageDefinition.getId(),
ImmutableSet.of(""),
((queryKernel, integer, msqFault) -> {})
));
}
@Test
public void test_inMemoryRetryEnabled_retryInvoked() throws InterruptedException
{
final CountDownLatch latch = new CountDownLatch(TASK_IDS.size());
target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, true));
workersWithFailedFetchParallel(ImmutableSet.of(TASK_1));
CountDownLatch retryLatch = new CountDownLatch(1);
target.inMemoryFullSketchMerging(
(kernelConsumer) -> {
kernelConsumer.accept(kernel);
latch.countDown();
},
stageDefinition.getId(),
ImmutableSet.copyOf(TASK_IDS),
((queryKernel, integer, msqFault) -> {
if (integer.equals(1) && msqFault.getErrorMessage().contains(TASK_1)) {
retryLatch.countDown();
}
})
);
latch.await(5, TimeUnit.SECONDS);
retryLatch.await(5, TimeUnit.SECONDS);
Assert.assertEquals(0, latch.getCount());
Assert.assertEquals(0, retryLatch.getCount());
}
@Test
public void test_SequentialRetryEnabled_retryInvoked() throws InterruptedException
{
doReturn(true).when(completeKeyStatisticsInformation).isComplete();
final CountDownLatch latch = new CountDownLatch(TASK_IDS.size() - 1);
target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, true));
workersWithFailedFetchSequential(ImmutableSet.of(TASK_1));
CountDownLatch retryLatch = new CountDownLatch(1);
target.sequentialTimeChunkMerging(
(kernelConsumer) -> {
kernelConsumer.accept(kernel);
latch.countDown();
},
completeKeyStatisticsInformation,
stageDefinition.getId(),
ImmutableSet.copyOf(TASK_IDS),
((queryKernel, integer, msqFault) -> {
if (integer.equals(1) && msqFault.getErrorMessage().contains(TASK_1)) {
retryLatch.countDown();
}
})
);
latch.await(5, TimeUnit.SECONDS);
retryLatch.await(5, TimeUnit.SECONDS);
Assert.assertEquals(0, latch.getCount());
Assert.assertEquals(0, retryLatch.getCount());
}
@Test
public void test_InMemoryRetryDisabled_multipleFailures() throws InterruptedException
{
target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, false));
workersWithFailedFetchParallel(ImmutableSet.of(TASK_1, TASK_0));
try {
target.inMemoryFullSketchMerging(
(kernelConsumer) -> kernelConsumer.accept(kernel),
stageDefinition.getId(),
ImmutableSet.copyOf(TASK_IDS),
((queryKernel, integer, msqFault) -> {
throw new ISE("Should not be here");
})
);
}
catch (Exception e) {
Assert.assertTrue(e.getMessage().contains("Task fetch failed"));
}
while (!target.executorService.isShutdown()) {
Thread.sleep(100);
}
Assert.assertTrue((target.getError().getMessage().contains("Task fetch failed")));
}
@Test
public void test_InMemoryRetryDisabled_singleFailure() throws InterruptedException
{
target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, false));
workersWithFailedFetchParallel(ImmutableSet.of(TASK_1));
try {
target.inMemoryFullSketchMerging(
(kernelConsumer) -> kernelConsumer.accept(kernel),
stageDefinition.getId(),
ImmutableSet.copyOf(TASK_IDS),
((queryKernel, integer, msqFault) -> {
throw new ISE("Should not be here");
})
);
}
catch (Exception e) {
Assert.assertTrue(e.getMessage().contains("Task fetch failed"));
}
while (!target.executorService.isShutdown()) {
Thread.sleep(100);
}
Assert.assertTrue((target.getError().getMessage().contains("Task fetch failed")));
}
@Test
public void test_SequentialRetryDisabled_multipleFailures() throws InterruptedException
{
doReturn(true).when(completeKeyStatisticsInformation).isComplete();
target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, false));
workersWithFailedFetchSequential(ImmutableSet.of(TASK_1, TASK_0));
try {
target.sequentialTimeChunkMerging(
(kernelConsumer) -> {
kernelConsumer.accept(kernel);
},
completeKeyStatisticsInformation,
stageDefinition.getId(),
ImmutableSet.copyOf(TASK_IDS),
((queryKernel, integer, msqFault) -> {
throw new ISE("Should not be here");
})
);
}
catch (Exception e) {
Assert.assertTrue(e.getMessage().contains("Task fetch failed"));
}
while (!target.executorService.isShutdown()) {
Thread.sleep(100);
}
Assert.assertTrue(target.getError().getMessage().contains("Task fetch failed"));
}
@Test
public void test_SequentialRetryDisabled_singleFailure() throws InterruptedException
{
doReturn(true).when(completeKeyStatisticsInformation).isComplete();
target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, false));
workersWithFailedFetchSequential(ImmutableSet.of(TASK_1));
try {
target.sequentialTimeChunkMerging(
(kernelConsumer) -> {
kernelConsumer.accept(kernel);
},
completeKeyStatisticsInformation,
stageDefinition.getId(),
ImmutableSet.copyOf(TASK_IDS),
((queryKernel, integer, msqFault) -> {
throw new ISE("Should not be here");
})
);
}
catch (Exception e) {
Assert.assertTrue(e.getMessage().contains("Task fetch failed"));
}
while (!target.executorService.isShutdown()) {
Thread.sleep(100);
}
Assert.assertTrue(target.getError().getMessage().contains(TASK_1));
}
private void workersWithFailedFetchSequential(Set<String> failedTasks)
{
doAnswer(invocation -> {
ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class);
if (failedTasks.contains((String) invocation.getArgument(0))) {
return Futures.immediateFailedFuture(new Exception("Task fetch failed :" + invocation.getArgument(0)));
}
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyInt(), anyLong());
}
private void workersWithFailedFetchParallel(Set<String> failedTasks)
{
doAnswer(invocation -> {
ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class);
if (failedTasks.contains((String) invocation.getArgument(0))) {
return Futures.immediateFailedFuture(new Exception("Task fetch failed :" + invocation.getArgument(0)));
}
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt());
}
}

View File

@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.indexing;
import com.google.common.collect.ImmutableMap;
import org.junit.Assert;
import org.junit.Test;
import java.util.HashSet;
import java.util.Set;
public class MSQWorkerTaskTest
{
private final String controllerTaskId = "ctr";
private final String dataSource = "ds";
private final int workerNumber = 1;
private final ImmutableMap<String, Object> context = ImmutableMap.of("key", "val");
private final int retryCount = 0;
private final MSQWorkerTask msqWorkerTask = new MSQWorkerTask(
controllerTaskId,
dataSource,
workerNumber,
context,
retryCount
);
@Test
public void testEquals()
{
Assert.assertNotEquals(msqWorkerTask, 0);
Assert.assertEquals(msqWorkerTask, msqWorkerTask);
Assert.assertEquals(
msqWorkerTask,
new MSQWorkerTask(controllerTaskId, dataSource, workerNumber, context, retryCount)
);
Assert.assertEquals(
msqWorkerTask.getRetryTask(),
new MSQWorkerTask(controllerTaskId, dataSource, workerNumber, context, retryCount + 1)
);
Assert.assertNotEquals(msqWorkerTask, msqWorkerTask.getRetryTask());
}
@Test
public void testHashCode()
{
Set<MSQWorkerTask> msqWorkerTaskSet = new HashSet<>();
msqWorkerTaskSet.add(msqWorkerTask);
msqWorkerTaskSet.add(new MSQWorkerTask(controllerTaskId, dataSource, workerNumber, context, retryCount));
Assert.assertTrue(msqWorkerTaskSet.size() == 1);
msqWorkerTaskSet.add(msqWorkerTask.getRetryTask());
Assert.assertTrue(msqWorkerTaskSet.size() == 2);
msqWorkerTaskSet.add(new MSQWorkerTask(controllerTaskId + 1, dataSource, workerNumber, context, retryCount));
Assert.assertTrue(msqWorkerTaskSet.size() == 3);
msqWorkerTaskSet.add(new MSQWorkerTask(controllerTaskId, dataSource + 1, workerNumber, context, retryCount));
Assert.assertTrue(msqWorkerTaskSet.size() == 4);
msqWorkerTaskSet.add(new MSQWorkerTask(controllerTaskId, dataSource, workerNumber + 1, context, retryCount));
Assert.assertTrue(msqWorkerTaskSet.size() == 5);
msqWorkerTaskSet.add(new MSQWorkerTask(
controllerTaskId,
dataSource,
workerNumber,
ImmutableMap.of("key1", "v1"),
retryCount
));
Assert.assertTrue(msqWorkerTaskSet.size() == 6);
msqWorkerTaskSet.add(new MSQWorkerTask(controllerTaskId, dataSource, workerNumber, context, retryCount + 1));
Assert.assertTrue(msqWorkerTaskSet.size() == 6);
}
@Test
public void testGetter()
{
Assert.assertEquals(controllerTaskId, msqWorkerTask.getControllerTaskId());
Assert.assertEquals(dataSource, msqWorkerTask.getDataSource());
Assert.assertEquals(workerNumber, msqWorkerTask.getWorkerNumber());
Assert.assertEquals(retryCount, msqWorkerTask.getRetryCount());
}
}

View File

@ -158,7 +158,7 @@ public class WorkerChatHandlerTest
@Override
public MSQWorkerTask task()
{
return new MSQWorkerTask("controller", "ds", 1, new HashMap<>());
return new MSQWorkerTask("controller", "ds", 1, new HashMap<>(), 0);
}
@Override

View File

@ -75,6 +75,8 @@ public class MSQFaultSerdeTest
assertFaultSerde(new TooManyPartitionsFault(10));
assertFaultSerde(new TooManyWarningsFault(10, "the error"));
assertFaultSerde(new TooManyWorkersFault(10, 5));
assertFaultSerde(new TooManyAttemptsForWorker(2, "taskId", 1, "rootError"));
assertFaultSerde(new TooManyAttemptsForJob(2, 2, "taskId", "rootError"));
assertFaultSerde(UnknownFault.forMessage(null));
assertFaultSerde(UnknownFault.forMessage("the message"));
assertFaultSerde(new WorkerFailedFault("the worker task", "the error msg"));

View File

@ -339,7 +339,9 @@ public class MSQWarningsTest extends MSQTestBase
+ " '[{\"name\": \"timestamp\", \"type\": \"string\"}, {\"name\": \"page\", \"type\": \"string\"}, {\"name\": \"user\", \"type\": \"string\"}]'\n"
+ " )\n"
+ ") group by 1 PARTITIONED by day ")
.setQueryContext(ROLLUP_CONTEXT)
.setQueryContext(new ImmutableMap.Builder<String, Object>().putAll(DEFAULT_MSQ_CONTEXT)
.putAll(ROLLUP_CONTEXT_PARAMS)
.build())
.setExpectedRollUp(true)
.setExpectedDataSource("foo1")
.setExpectedRowSignature(rowSignature)

View File

@ -19,7 +19,18 @@
package org.apache.druid.msq.kernel;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.SortColumn;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.input.stage.StageInputSpec;
import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollectorImpl;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.junit.Assert;
import org.junit.Test;
public class StageDefinitionTest
@ -32,4 +43,62 @@ public class StageDefinitionTest
.usingGetClass()
.verify();
}
@Test
public void testGeneratePartitionsForNullShuffle()
{
StageDefinition stageDefinition = new StageDefinition(
new StageId("query", 1),
ImmutableList.of(new StageInputSpec(0)),
ImmutableSet.of(),
new OffsetLimitFrameProcessorFactory(0, 1L),
RowSignature.empty(),
null,
0,
false
);
Assert.assertThrows(ISE.class, () -> stageDefinition.generatePartitionsForShuffle(null));
}
@Test
public void testGeneratePartitionsForNonNullShuffleWithNullCollector()
{
StageDefinition stageDefinition = new StageDefinition(
new StageId("query", 1),
ImmutableList.of(new StageInputSpec(0)),
ImmutableSet.of(),
new OffsetLimitFrameProcessorFactory(0, 1L),
RowSignature.empty(),
new MaxCountShuffleSpec(new ClusterBy(ImmutableList.of(new SortColumn("test", false)), 1), 2, false),
1,
false
);
Assert.assertThrows(ISE.class, () -> stageDefinition.generatePartitionsForShuffle(null));
}
@Test
public void testGeneratePartitionsForNonNullShuffleWithNonNullCollector()
{
StageDefinition stageDefinition = new StageDefinition(
new StageId("query", 1),
ImmutableList.of(new StageInputSpec(0)),
ImmutableSet.of(),
new OffsetLimitFrameProcessorFactory(0, 1L),
RowSignature.empty(),
new MaxCountShuffleSpec(new ClusterBy(ImmutableList.of(new SortColumn("test", false)), 0), 1, false),
1,
false
);
Assert.assertThrows(
ISE.class,
() -> stageDefinition.generatePartitionsForShuffle(ClusterByStatisticsCollectorImpl.create(new ClusterBy(
ImmutableList.of(new SortColumn("test", false)),
1
), RowSignature.builder().add("test", ColumnType.STRING).build(), 1000, 100, false, false))
);
}
}

View File

@ -21,31 +21,36 @@ package org.apache.druid.msq.kernel.controller;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.frame.key.ClusterByPartitions;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.frame.key.KeyTestUtils;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.msq.indexing.error.MSQFault;
import org.apache.druid.msq.indexing.error.UnknownFault;
import org.apache.druid.msq.input.InputSpecSlicerFactory;
import org.apache.druid.msq.input.MapInputSpecSlicer;
import org.apache.druid.msq.input.stage.StageInputSpec;
import org.apache.druid.msq.input.stage.StageInputSpecSlicer;
import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import javax.annotation.Nonnull;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest
{
public static final UnknownFault RETRIABLE_FAULT = UnknownFault.forMessage("");
public ControllerQueryKernelTester testControllerQueryKernel(int numWorkers)
{
@ -82,7 +87,11 @@ public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest
public ControllerQueryKernelTester queryDefinition(QueryDefinition queryDefinition)
{
this.queryDefinition = Preconditions.checkNotNull(queryDefinition);
this.controllerQueryKernel = new ControllerQueryKernel(queryDefinition);
this.controllerQueryKernel = new ControllerQueryKernel(
queryDefinition,
100_000_000,
true
);
return this;
}
@ -110,25 +119,54 @@ public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest
createAndGetNewStageNumbers(false);
// Initial phase would always be new as we can call this method only once for each
StageId stageId = new StageId(queryDefinition.getQueryId(), stageNumber);
switch (controllerStagePhase) {
case NEW:
break;
case READING_INPUT:
controllerQueryKernel.startStage(new StageId(queryDefinition.getQueryId(), stageNumber));
controllerQueryKernel.createWorkOrders(stageId.getStageNumber(), null);
controllerQueryKernel.startStage(stageId);
for (int i = 0; i < queryDefinition.getStageDefinition(stageId).getMaxWorkerCount(); ++i) {
controllerQueryKernel.workOrdersSentForWorker(stageId, i);
}
break;
case POST_READING:
case MERGING_STATISTICS:
setupStage(stageNumber, ControllerStagePhase.READING_INPUT, true);
final ClusterByStatisticsCollector collector = getMockCollector(
stageNumber);
for (int i = 0; i < queryDefinition.getStageDefinition(stageId).getMaxWorkerCount(); ++i) {
controllerQueryKernel.addPartialKeyStatisticsForStageAndWorker(
stageId,
i,
collector.snapshot().partialKeyStatistics()
);
controllerQueryKernel.startFetchingStatsFromWorker(stageId, ImmutableSet.of(i));
}
for (int i = 0; i < queryDefinition.getStageDefinition(stageId).getMaxWorkerCount(); ++i) {
controllerQueryKernel.mergeClusterByStatisticsCollectorForAllTimeChunks(
stageId,
i,
collector.snapshot()
);
}
break;
case POST_READING:
if (queryDefinition.getStageDefinition(stageNumber).mustGatherResultKeyStatistics()) {
for (int i = 0; i < numWorkers; ++i) {
controllerQueryKernel.addPartialKeyStatisticsForStageAndWorker(
new StageId(queryDefinition.getQueryId(), stageNumber),
i,
ClusterByStatisticsSnapshot.empty().partialKeyStatistics()
);
setupStage(stageNumber, ControllerStagePhase.MERGING_STATISTICS, true);
for (int i = 0; i < queryDefinition.getStageDefinition(stageId).getMaxWorkerCount(); ++i) {
controllerQueryKernel.partitionBoundariesSentForWorker(stageId, i);
}
} else {
throw new IAE("Stage %d doesn't gather key result statistics", stageNumber);
}
@ -141,9 +179,9 @@ public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest
} else {
setupStage(stageNumber, ControllerStagePhase.READING_INPUT, true);
}
for (int i = 0; i < numWorkers; ++i) {
for (int i = 0; i < queryDefinition.getStageDefinition(stageId).getMaxWorkerCount(); ++i) {
controllerQueryKernel.setResultsCompleteForStageAndWorker(
new StageId(queryDefinition.getQueryId(), stageNumber),
stageId,
i,
new Object()
);
@ -152,11 +190,11 @@ public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest
case FINISHED:
setupStage(stageNumber, ControllerStagePhase.RESULTS_READY, true);
controllerQueryKernel.finishStage(new StageId(queryDefinition.getQueryId(), stageNumber), false);
controllerQueryKernel.finishStage(stageId, false);
break;
case FAILED:
controllerQueryKernel.failStage(new StageId(queryDefinition.getQueryId(), stageNumber));
controllerQueryKernel.failStage(stageId);
break;
}
if (!recursiveCall) {
@ -165,6 +203,7 @@ public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest
return this;
}
public ControllerQueryKernelTester init()
{
@ -225,7 +264,18 @@ public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest
public void startStage(int stageNumber)
{
Preconditions.checkArgument(initialized);
controllerQueryKernel.createWorkOrders(stageNumber, null);
controllerQueryKernel.startStage(new StageId(queryDefinition.getQueryId(), stageNumber));
}
public void startWorkOrder(int stageNumber)
{
StageId stageId = new StageId(queryDefinition.getQueryId(), stageNumber);
Preconditions.checkArgument(initialized);
IntStream.range(0, queryDefinition.getStageDefinition(stageId).getMaxWorkerCount())
.forEach(n -> controllerQueryKernel.workOrdersSentForWorker(stageId, n));
}
@ -240,50 +290,56 @@ public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest
controllerQueryKernel.finishStage(new StageId(queryDefinition.getQueryId(), stageNumber), strict);
}
public ClusterByStatisticsCollector addResultKeyStatisticsForStageAndWorker(int stageNumber, int workerNumber)
public void addPartialKeyStatsInformation(int stageNumber, int... workers)
{
Preconditions.checkArgument(initialized);
// Simulate 1000 keys being encountered in the data, so the kernel can generate some partitions.
final ClusterByStatisticsCollector keyStatsCollector =
queryDefinition.getStageDefinition(stageNumber).createResultKeyStatisticsCollector(10_000_000);
for (int i = 0; i < 1000; i++) {
final RowKey key = KeyTestUtils.createKey(
MockQueryDefinitionBuilder.STAGE_SIGNATURE,
String.valueOf(i)
final ClusterByStatisticsCollector keyStatsCollector = getMockCollector(stageNumber);
StageId stageId = new StageId(queryDefinition.getQueryId(), stageNumber);
for (int worker : workers) {
controllerQueryKernel.addPartialKeyStatisticsForStageAndWorker(
stageId,
worker,
keyStatsCollector.snapshot().partialKeyStatistics()
);
keyStatsCollector.add(key, 1);
}
controllerQueryKernel.addPartialKeyStatisticsForStageAndWorker(
new StageId(queryDefinition.getQueryId(), stageNumber),
workerNumber,
keyStatsCollector.snapshot().partialKeyStatistics()
);
return keyStatsCollector;
}
public void setResultsCompleteForStageAndWorker(int stageNumber, int workerNumber)
public void statsBeingFetchedForWorkers(int stageNumber, Integer... workers)
{
Preconditions.checkArgument(initialized);
controllerQueryKernel.setResultsCompleteForStageAndWorker(
controllerQueryKernel.startFetchingStatsFromWorker(
new StageId(queryDefinition.getQueryId(), stageNumber),
workerNumber,
new Object()
ImmutableSet.copyOf(workers)
);
}
public void setPartitionBoundaries(int stageNumber, ClusterByStatisticsCollector clusterByStatisticsCollector)
public void mergeClusterByStatsForAllTimeChunksForWorkers(int stageNumber, Integer... workers)
{
Preconditions.checkArgument(initialized);
StageId stageId = new StageId(queryDefinition.getQueryId(), stageNumber);
StageDefinition stageDefinition = controllerQueryKernel.getStageDefinition(stageId);
ClusterByPartitions clusterByPartitions =
stageDefinition
.generatePartitionsForShuffle(clusterByStatisticsCollector)
.valueOrThrow();
controllerQueryKernel.setClusterByPartitionBoundaries(stageId, clusterByPartitions);
final ClusterByStatisticsCollector keyStatsCollector = getMockCollector(stageNumber);
for (int worker : workers) {
controllerQueryKernel.mergeClusterByStatisticsCollectorForAllTimeChunks(
stageId,
worker,
keyStatsCollector.snapshot()
);
}
}
public void setResultsCompleteForStageAndWorkers(int stageNumber, int... workers)
{
Preconditions.checkArgument(initialized);
final StageId stageId = new StageId(queryDefinition.getQueryId(), stageNumber);
for (int worker : workers) {
controllerQueryKernel.setResultsCompleteForStageAndWorker(
stageId,
worker,
new Object()
);
}
}
public void failStage(int stageNumber)
@ -302,16 +358,17 @@ public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest
if (controllerStageTracker.getPhase() != expectedControllerStagePhase) {
throw new ISE(
StringUtils.format(
"Stage kernel for stage number %d is in %s phase which is different from the expected phase",
"Stage kernel for stage number %d is in %s phase which is different from the expected phase %s",
stageNumber,
controllerStageTracker.getPhase()
controllerStageTracker.getPhase(),
expectedControllerStagePhase
)
);
}
}
/**
* Checks if the state of the BaseControllerQueryKernel is initialized properly. Currently this is just stubbed to
* Checks if the state of the BaseControllerQueryKernel is initialized properly. Currently, this is just stubbed to
* return true irrespective of the actual state
*/
private boolean isValidInitState()
@ -325,5 +382,71 @@ public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest
.map(StageId::getStageNumber)
.collect(Collectors.toSet());
}
public void sendWorkOrdersForWorkers(int stageNumber, int... workers)
{
Preconditions.checkArgument(initialized);
final StageId stageId = new StageId(queryDefinition.getQueryId(), stageNumber);
for (int worker : workers) {
controllerQueryKernel.workOrdersSentForWorker(stageId, worker);
}
}
public void sendPartitionBoundariesForStageAndWorkers(int stageNumber, int... workers)
{
Preconditions.checkArgument(initialized);
final StageId stageId = new StageId(queryDefinition.getQueryId(), stageNumber);
for (int worker : workers) {
controllerQueryKernel.partitionBoundariesSentForWorker(stageId, worker);
}
}
public void sendPartitionBoundariesForStage(int stageNumber)
{
Preconditions.checkArgument(initialized);
final StageId stageId = new StageId(queryDefinition.getQueryId(), stageNumber);
for (int worker : controllerQueryKernel.getWorkersToSendPartitionBoundaries(stageId)) {
controllerQueryKernel.partitionBoundariesSentForWorker(stageId, worker);
}
}
public List<WorkOrder> getRetriableWorkOrdersAndChangeState(int workeNumber, MSQFault msqFault)
{
return controllerQueryKernel.getWorkInCaseWorkerEligibleForRetryElseThrow(workeNumber, msqFault);
}
public void failWorkerAndAssertWorkOrderes(int workeNumber, int retriedStage)
{
// fail one worker
List<WorkOrder> workOrderList = getRetriableWorkOrdersAndChangeState(
workeNumber,
RETRIABLE_FAULT
);
// does not enable the current stage to enable running from start
Assert.assertTrue(createAndGetNewStageNumbers().size() == 0);
// only work order of failed worker should be there
Assert.assertTrue(workOrderList.size() == 1);
Assert.assertTrue(workOrderList.get(0).getWorkerNumber() == workeNumber);
Assert.assertTrue(workOrderList.get(0).getStageNumber() == retriedStage);
}
@Nonnull
private ClusterByStatisticsCollector getMockCollector(int stageNumber)
{
final ClusterByStatisticsCollector keyStatsCollector =
queryDefinition.getStageDefinition(stageNumber).createResultKeyStatisticsCollector(10_000_000);
for (int i = 0; i < 1000; i++) {
final RowKey key = KeyTestUtils.createKey(
MockQueryDefinitionBuilder.STAGE_SIGNATURE,
String.valueOf(i)
);
keyStatsCollector.add(key, 1);
}
return keyStatsCollector;
}
}
}

View File

@ -20,13 +20,15 @@
package org.apache.druid.msq.kernel.controller;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.msq.kernel.worker.WorkerStagePhase;
import org.junit.Assert;
import org.junit.Test;
import java.util.Set;
public class ControllerQueryKernelTests extends BaseControllerQueryKernelTest
public class ControllerQueryKernelTest extends BaseControllerQueryKernelTest
{
@Test
@ -147,14 +149,16 @@ public class ControllerQueryKernelTests extends BaseControllerQueryKernelTest
Assert.assertEquals(ImmutableSet.of(0), newStageNumbers);
Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers);
controllerQueryKernelTester.startStage(0);
ClusterByStatisticsCollector clusterByStatisticsCollector =
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(
0,
0
);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 0);
controllerQueryKernelTester.addPartialKeyStatsInformation(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.MERGING_STATISTICS);
controllerQueryKernelTester.setPartitionBoundaries(0, clusterByStatisticsCollector);
controllerQueryKernelTester.setResultsCompleteForStageAndWorker(0, 0);
controllerQueryKernelTester.statsBeingFetchedForWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.MERGING_STATISTICS);
controllerQueryKernelTester.mergeClusterByStatsForAllTimeChunksForWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.POST_READING);
controllerQueryKernelTester.sendPartitionBoundariesForStageAndWorkers(0, 0);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RESULTS_READY);
newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers();
@ -162,24 +166,25 @@ public class ControllerQueryKernelTests extends BaseControllerQueryKernelTest
Assert.assertEquals(ImmutableSet.of(1), newStageNumbers);
Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers);
controllerQueryKernelTester.startStage(1);
clusterByStatisticsCollector =
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(
1,
0
);
controllerQueryKernelTester.sendWorkOrdersForWorkers(1, 0);
controllerQueryKernelTester.addPartialKeyStatsInformation(1, 0);
controllerQueryKernelTester.statsBeingFetchedForWorkers(1, 0);
controllerQueryKernelTester.assertStagePhase(1, ControllerStagePhase.READING_INPUT);
clusterByStatisticsCollector.addAll(
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(
1,
1
)
);
controllerQueryKernelTester.mergeClusterByStatsForAllTimeChunksForWorkers(1, 0);
controllerQueryKernelTester.assertStagePhase(1, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.sendWorkOrdersForWorkers(1, 1);
controllerQueryKernelTester.addPartialKeyStatsInformation(1, 1);
controllerQueryKernelTester.assertStagePhase(1, ControllerStagePhase.MERGING_STATISTICS);
controllerQueryKernelTester.setPartitionBoundaries(1, clusterByStatisticsCollector);
controllerQueryKernelTester.statsBeingFetchedForWorkers(1, 1);
controllerQueryKernelTester.assertStagePhase(1, ControllerStagePhase.MERGING_STATISTICS);
controllerQueryKernelTester.mergeClusterByStatsForAllTimeChunksForWorkers(1, 1);
controllerQueryKernelTester.assertStagePhase(1, ControllerStagePhase.POST_READING);
controllerQueryKernelTester.setResultsCompleteForStageAndWorker(1, 0);
controllerQueryKernelTester.sendPartitionBoundariesForStageAndWorkers(1, 0);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(1, 0);
controllerQueryKernelTester.assertStagePhase(1, ControllerStagePhase.POST_READING);
controllerQueryKernelTester.setResultsCompleteForStageAndWorker(1, 1);
controllerQueryKernelTester.sendPartitionBoundariesForStageAndWorkers(1, 1);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(1, 1);
controllerQueryKernelTester.assertStagePhase(1, ControllerStagePhase.RESULTS_READY);
newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers();
@ -188,7 +193,8 @@ public class ControllerQueryKernelTests extends BaseControllerQueryKernelTest
Assert.assertEquals(ImmutableSet.of(0), effectivelyFinishedStageNumbers);
controllerQueryKernelTester.startStage(2);
controllerQueryKernelTester.assertStagePhase(2, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.setResultsCompleteForStageAndWorker(2, 0);
controllerQueryKernelTester.sendWorkOrdersForWorkers(2, 0);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(2, 0);
controllerQueryKernelTester.assertStagePhase(2, ControllerStagePhase.RESULTS_READY);
controllerQueryKernelTester.finishStage(0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.FINISHED);
@ -199,23 +205,22 @@ public class ControllerQueryKernelTests extends BaseControllerQueryKernelTest
Assert.assertEquals(ImmutableSet.of(1), effectivelyFinishedStageNumbers);
controllerQueryKernelTester.startStage(3);
controllerQueryKernelTester.assertStagePhase(3, ControllerStagePhase.READING_INPUT);
ClusterByStatisticsCollector clusterByStatisticsCollector3 =
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(
3,
0
);
controllerQueryKernelTester.startWorkOrder(3);
controllerQueryKernelTester.addPartialKeyStatsInformation(3, 1);
controllerQueryKernelTester.assertStagePhase(3, ControllerStagePhase.READING_INPUT);
ClusterByStatisticsCollector clusterByStatisticsCollector4 =
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(
3,
1
);
controllerQueryKernelTester.addPartialKeyStatsInformation(3, 0);
controllerQueryKernelTester.assertStagePhase(3, ControllerStagePhase.MERGING_STATISTICS);
controllerQueryKernelTester.setPartitionBoundaries(3, clusterByStatisticsCollector3.addAll(clusterByStatisticsCollector4));
controllerQueryKernelTester.statsBeingFetchedForWorkers(3, 0, 1);
controllerQueryKernelTester.assertStagePhase(3, ControllerStagePhase.MERGING_STATISTICS);
controllerQueryKernelTester.mergeClusterByStatsForAllTimeChunksForWorkers(3, 0, 1);
controllerQueryKernelTester.assertStagePhase(3, ControllerStagePhase.POST_READING);
controllerQueryKernelTester.setResultsCompleteForStageAndWorker(3, 0);
controllerQueryKernelTester.sendPartitionBoundariesForStageAndWorkers(3, 0);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(3, 0);
controllerQueryKernelTester.assertStagePhase(3, ControllerStagePhase.POST_READING);
controllerQueryKernelTester.setResultsCompleteForStageAndWorker(3, 1);
controllerQueryKernelTester.sendPartitionBoundariesForStageAndWorkers(3, 1);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(3, 1);
controllerQueryKernelTester.assertStagePhase(3, ControllerStagePhase.RESULTS_READY);
controllerQueryKernelTester.finishStage(1);
controllerQueryKernelTester.assertStagePhase(1, ControllerStagePhase.FINISHED);
@ -243,27 +248,24 @@ public class ControllerQueryKernelTests extends BaseControllerQueryKernelTest
controllerQueryKernelTester.createAndGetNewStageNumbers();
controllerQueryKernelTester.startStage(0);
ClusterByStatisticsCollector clusterByStatisticsCollector =
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(
0,
0
);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 0);
controllerQueryKernelTester.addPartialKeyStatsInformation(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
clusterByStatisticsCollector.addAll(
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(
0,
1
)
);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 1);
controllerQueryKernelTester.addPartialKeyStatsInformation(0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.MERGING_STATISTICS);
controllerQueryKernelTester.setPartitionBoundaries(0, clusterByStatisticsCollector);
controllerQueryKernelTester.setResultsCompleteForStageAndWorker(0, 0);
controllerQueryKernelTester.statsBeingFetchedForWorkers(0, 0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.MERGING_STATISTICS);
controllerQueryKernelTester.mergeClusterByStatsForAllTimeChunksForWorkers(0, 0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.POST_READING);
controllerQueryKernelTester.setResultsCompleteForStageAndWorker(0, 1);
controllerQueryKernelTester.sendPartitionBoundariesForStageAndWorkers(0, 0);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.POST_READING);
controllerQueryKernelTester.sendPartitionBoundariesForStageAndWorkers(0, 1);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RESULTS_READY);
controllerQueryKernelTester.finishStage(0, false);
@ -287,13 +289,23 @@ public class ControllerQueryKernelTests extends BaseControllerQueryKernelTest
controllerQueryKernelTester.createAndGetNewStageNumbers();
controllerQueryKernelTester.startStage(0);
controllerQueryKernelTester.addResultKeyStatisticsForStageAndWorker(0, 0);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 0);
controllerQueryKernelTester.addPartialKeyStatsInformation(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.setResultsCompleteForStageAndWorker(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
Assert.assertThrows(
StringUtils.format(
"Worker[%d] for stage[%d] expected to be in state[%s]. Found state[%s]",
1,
0,
WorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES,
WorkerStagePhase.NEW
),
ISE.class,
() -> controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 0)
);
}
@Test
@ -414,7 +426,8 @@ public class ControllerQueryKernelTests extends BaseControllerQueryKernelTest
private static void transitionNewToResultsComplete(ControllerQueryKernelTester queryKernelTester, int stageNumber)
{
queryKernelTester.startStage(stageNumber);
queryKernelTester.setResultsCompleteForStageAndWorker(stageNumber, 0);
queryKernelTester.startWorkOrder(stageNumber);
queryKernelTester.setResultsCompleteForStageAndWorkers(stageNumber, 0);
}
}

View File

@ -135,7 +135,9 @@ public class MockQueryDefinitionBuilder
.map(StageInputSpec::new).collect(Collectors.toList());
if (inputSpecs.isEmpty()) {
inputSpecs.add(new ControllerTestInputSpec());
for (int i = 0; i < maxWorkers; i++) {
inputSpecs.add(new ControllerTestInputSpec());
}
}
queryDefinitionBuilder.add(

View File

@ -0,0 +1,334 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.kernel.controller;
import org.junit.Assert;
import org.junit.Test;
import javax.annotation.Nonnull;
public class NonShufflingWorkersWithRetryKernelTest extends BaseControllerQueryKernelTest
{
@Test
public void testWorkerFailedAfterInitialization()
{
ControllerQueryKernelTester controllerQueryKernelTester = getSimpleQueryDefinition(2);
controllerQueryKernelTester.init();
Assert.assertTrue(controllerQueryKernelTester.createAndGetNewStageNumbers().size() == 1
&& controllerQueryKernelTester.createAndGetNewStageNumbers().contains(0));
Assert.assertTrue(controllerQueryKernelTester.getRetriableWorkOrdersAndChangeState(0, RETRIABLE_FAULT).size() == 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.NEW);
Assert.assertTrue(controllerQueryKernelTester.createAndGetNewStageNumbers().size() == 1
&& controllerQueryKernelTester.createAndGetNewStageNumbers().contains(0));
}
@Test
public void testWorkerFailedBeforeAnyWorkOrdersSent()
{
ControllerQueryKernelTester controllerQueryKernelTester = getSimpleQueryDefinition(2);
controllerQueryKernelTester.init();
controllerQueryKernelTester.createAndGetNewStageNumbers();
controllerQueryKernelTester.startStage(0);
controllerQueryKernelTester.getRetriableWorkOrdersAndChangeState(0, RETRIABLE_FAULT);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RETRYING);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RETRYING);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RESULTS_READY);
}
@Test
public void testWorkerFailedBeforeAllWorkOrdersSent()
{
ControllerQueryKernelTester controllerQueryKernelTester = getSimpleQueryDefinition(2);
controllerQueryKernelTester.init();
controllerQueryKernelTester.createAndGetNewStageNumbers();
controllerQueryKernelTester.startStage(0);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.getRetriableWorkOrdersAndChangeState(0, RETRIABLE_FAULT);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RETRYING);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RESULTS_READY);
}
@Test
public void testWorkerFailedBeforeAnyResultsRecieved()
{
ControllerQueryKernelTester controllerQueryKernelTester = getSimpleQueryDefinition(2);
// workorders sent for both stage
controllerQueryKernelTester.setupStage(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.init();
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
// fail one worker
controllerQueryKernelTester.failWorkerAndAssertWorkOrderes(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RETRYING);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
Assert.assertTrue(controllerQueryKernelTester.createAndGetNewStageNumbers().size() == 0);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RESULTS_READY);
Assert.assertTrue(controllerQueryKernelTester.createAndGetNewStageNumbers().size() == 1);
}
@Test
public void testWorkerFailedBeforeAllResultsRecieved()
{
ControllerQueryKernelTester controllerQueryKernelTester = getSimpleQueryDefinition(2);
// workorders sent for both stage
controllerQueryKernelTester.setupStage(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.init();
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
// fail one worker
controllerQueryKernelTester.failWorkerAndAssertWorkOrderes(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RETRYING);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
Assert.assertTrue(controllerQueryKernelTester.createAndGetNewStageNumbers().size() == 0);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RESULTS_READY);
Assert.assertTrue(controllerQueryKernelTester.createAndGetNewStageNumbers().size() == 1);
}
@Test
public void testWorkerFailedBeforeFinished()
{
ControllerQueryKernelTester controllerQueryKernelTester = getSimpleQueryDefinition(2);
controllerQueryKernelTester.setupStage(0, ControllerStagePhase.RESULTS_READY);
controllerQueryKernelTester.init();
Assert.assertEquals(0, controllerQueryKernelTester.getRetriableWorkOrdersAndChangeState(0, RETRIABLE_FAULT).size());
Assert.assertEquals(0, controllerQueryKernelTester.getRetriableWorkOrdersAndChangeState(1, RETRIABLE_FAULT).size());
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RESULTS_READY);
}
@Test
public void testWorkerFailedAfterFinished()
{
ControllerQueryKernelTester controllerQueryKernelTester = getSimpleQueryDefinition(2);
controllerQueryKernelTester.setupStage(0, ControllerStagePhase.FINISHED);
controllerQueryKernelTester.init();
Assert.assertEquals(0, controllerQueryKernelTester.getRetriableWorkOrdersAndChangeState(0, RETRIABLE_FAULT).size());
Assert.assertEquals(0, controllerQueryKernelTester.getRetriableWorkOrdersAndChangeState(1, RETRIABLE_FAULT).size());
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.FINISHED);
}
@Test
public void testMultipleWorkersFailedAfterInitialization()
{
ControllerQueryKernelTester controllerQueryKernelTester = getSimpleQueryDefinition(3);
controllerQueryKernelTester.init();
Assert.assertTrue(controllerQueryKernelTester.createAndGetNewStageNumbers().size() == 1
&& controllerQueryKernelTester.createAndGetNewStageNumbers().contains(0));
Assert.assertTrue(controllerQueryKernelTester.getRetriableWorkOrdersAndChangeState(0, RETRIABLE_FAULT).size() == 0);
Assert.assertTrue(controllerQueryKernelTester.getRetriableWorkOrdersAndChangeState(1, RETRIABLE_FAULT).size() == 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.NEW);
Assert.assertTrue(controllerQueryKernelTester.createAndGetNewStageNumbers().size() == 1
&& controllerQueryKernelTester.createAndGetNewStageNumbers().contains(0));
}
@Test
public void testMultipleWorkersFailedBeforeAnyWorkOrdersSent()
{
ControllerQueryKernelTester controllerQueryKernelTester = getSimpleQueryDefinition(3);
controllerQueryKernelTester.init();
controllerQueryKernelTester.createAndGetNewStageNumbers();
controllerQueryKernelTester.startStage(0);
controllerQueryKernelTester.getRetriableWorkOrdersAndChangeState(0, RETRIABLE_FAULT);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RETRYING);
controllerQueryKernelTester.getRetriableWorkOrdersAndChangeState(1, RETRIABLE_FAULT);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RETRYING);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RETRYING);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 2);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 2);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RESULTS_READY);
}
@Test
public void testMulttipleWorkerFailedBeforeAllWorkOrdersSent()
{
ControllerQueryKernelTester controllerQueryKernelTester = getSimpleQueryDefinition(3);
controllerQueryKernelTester.init();
controllerQueryKernelTester.createAndGetNewStageNumbers();
controllerQueryKernelTester.startStage(0);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.getRetriableWorkOrdersAndChangeState(0, RETRIABLE_FAULT);
controllerQueryKernelTester.getRetriableWorkOrdersAndChangeState(2, RETRIABLE_FAULT);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RETRYING);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RETRYING);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 2);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 1, 2);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RESULTS_READY);
}
@Test
public void testMultipleWorkersFailedBeforeAnyResultsRecieved()
{
ControllerQueryKernelTester controllerQueryKernelTester = getSimpleQueryDefinition(3);
controllerQueryKernelTester.setupStage(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.init();
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.failWorkerAndAssertWorkOrderes(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RETRYING);
controllerQueryKernelTester.failWorkerAndAssertWorkOrderes(1, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RETRYING);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RETRYING);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
Assert.assertTrue(controllerQueryKernelTester.createAndGetNewStageNumbers().size() == 0);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 0, 2);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RESULTS_READY);
Assert.assertTrue(controllerQueryKernelTester.createAndGetNewStageNumbers().size() == 1);
}
@Test
public void testMultipleWorkersFailedBeforeAllResultsRecieved()
{
ControllerQueryKernelTester controllerQueryKernelTester = getSimpleQueryDefinition(3);
// workorders sent for all stages
controllerQueryKernelTester.setupStage(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.init();
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 1);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
controllerQueryKernelTester.failWorkerAndAssertWorkOrderes(0, 0);
controllerQueryKernelTester.failWorkerAndAssertWorkOrderes(2, 0);
// should be no op
Assert.assertTrue(controllerQueryKernelTester.getRetriableWorkOrdersAndChangeState(1, RETRIABLE_FAULT).size() == 0);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RETRYING);
controllerQueryKernelTester.sendWorkOrdersForWorkers(0, 0, 2);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.READING_INPUT);
Assert.assertTrue(controllerQueryKernelTester.createAndGetNewStageNumbers().size() == 0);
controllerQueryKernelTester.setResultsCompleteForStageAndWorkers(0, 0, 2);
controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.RESULTS_READY);
Assert.assertTrue(controllerQueryKernelTester.createAndGetNewStageNumbers().size() == 1);
}
@Nonnull
private ControllerQueryKernelTester getSimpleQueryDefinition(int numWorkers)
{
ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(numWorkers);
// 0 -> 1
controllerQueryKernelTester.queryDefinition(
new MockQueryDefinitionBuilder(2)
.addVertex(0, 1)
.defineStage(0, false, numWorkers)
.defineStage(1, false, numWorkers)
.getQueryDefinitionBuilder()
.build()
);
return controllerQueryKernelTester;
}
}

View File

@ -72,6 +72,7 @@ import org.apache.druid.metadata.input.InputSourceModule;
import org.apache.druid.msq.counters.CounterSnapshots;
import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.counters.QueryCounterSnapshot;
import org.apache.druid.msq.exec.ClusterStatisticsMergeMode;
import org.apache.druid.msq.exec.Controller;
import org.apache.druid.msq.exec.WorkerMemoryParameters;
import org.apache.druid.msq.guice.MSQDurableStorageModule;
@ -85,6 +86,8 @@ import org.apache.druid.msq.indexing.MSQTuningConfig;
import org.apache.druid.msq.indexing.error.InsertLockPreemptedFaultTest;
import org.apache.druid.msq.indexing.error.MSQErrorReport;
import org.apache.druid.msq.indexing.error.MSQFault;
import org.apache.druid.msq.indexing.error.MSQFaultUtils;
import org.apache.druid.msq.indexing.error.TooManyAttemptsForWorker;
import org.apache.druid.msq.indexing.report.MSQResultsReport;
import org.apache.druid.msq.indexing.report.MSQTaskReport;
import org.apache.druid.msq.indexing.report.MSQTaskReportPayload;
@ -214,17 +217,41 @@ public class MSQTestBase extends BaseCalciteQueryTest
{
public static final Map<String, Object> DEFAULT_MSQ_CONTEXT =
ImmutableMap.<String, Object>builder()
.put(MultiStageQueryContext.CTX_ENABLE_DURABLE_SHUFFLE_STORAGE, true)
.put(QueryContexts.CTX_SQL_QUERY_ID, "test-query")
.put(QueryContexts.FINALIZE_KEY, true)
.build();
public static final Map<String, Object> DURABLE_STORAGE_MSQ_CONTEXT =
ImmutableMap.<String, Object>builder()
.putAll(DEFAULT_MSQ_CONTEXT)
.put(MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE, true).build();
public static final Map<String, Object> FAULT_TOLERANCE_MSQ_CONTEXT =
ImmutableMap.<String, Object>builder()
.putAll(DEFAULT_MSQ_CONTEXT)
.put(MultiStageQueryContext.CTX_FAULT_TOLERANCE, true).build();
public static final Map<String, Object> SEQUENTIAL_MERGE_MSQ_CONTEXT =
ImmutableMap.<String, Object>builder()
.putAll(DEFAULT_MSQ_CONTEXT)
.put(
MultiStageQueryContext.CTX_CLUSTER_STATISTICS_MERGE_MODE,
ClusterStatisticsMergeMode.SEQUENTIAL.toString()
)
.build();
public static final Map<String, Object>
ROLLUP_CONTEXT = ImmutableMap.<String, Object>builder()
.putAll(DEFAULT_MSQ_CONTEXT)
.put(MultiStageQueryContext.CTX_FINALIZE_AGGREGATIONS, false)
.put(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, false)
.build();
ROLLUP_CONTEXT_PARAMS = ImmutableMap.<String, Object>builder()
.put(MultiStageQueryContext.CTX_FINALIZE_AGGREGATIONS, false)
.put(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, false)
.build();
public static final String FAULT_TOLERANCE = "fault_tolerance";
public static final String DURABLE_STORAGE = "durable_storage";
public static final String DEFAULT = "default";
public static final String SEQUENTIAL_MERGE = "sequential_merge";
public final boolean useDefault = NullHandling.replaceWithDefault();
@ -480,11 +507,11 @@ public class MSQTestBase extends BaseCalciteQueryTest
* Returns query context expected for a scan query. Same as {@link #DEFAULT_MSQ_CONTEXT}, but
* includes {@link DruidQuery#CTX_SCAN_SIGNATURE}.
*/
protected Map<String, Object> defaultScanQueryContext(final RowSignature signature)
protected Map<String, Object> defaultScanQueryContext(Map<String, Object> context, final RowSignature signature)
{
try {
return ImmutableMap.<String, Object>builder()
.putAll(DEFAULT_MSQ_CONTEXT)
.putAll(context)
.put(
DruidQuery.CTX_SCAN_SIGNATURE,
queryFramework().queryJsonMapper().writeValueAsString(signature)
@ -971,9 +998,12 @@ public class MSQTestBase extends BaseCalciteQueryTest
if (expectedMSQFault != null || expectedMSQFaultClass != null) {
MSQErrorReport msqErrorReport = getErrorReportOrThrow(controllerId);
if (expectedMSQFault != null) {
String errorMessage = msqErrorReport.getFault() instanceof TooManyAttemptsForWorker
? ((TooManyAttemptsForWorker) msqErrorReport.getFault()).getRootErrorMessage()
: MSQFaultUtils.generateMessageWithErrorCode(msqErrorReport.getFault());
Assert.assertEquals(
expectedMSQFault.getCodeWithMessage(),
msqErrorReport.getFault().getCodeWithMessage()
MSQFaultUtils.generateMessageWithErrorCode(expectedMSQFault),
errorMessage
);
}
if (expectedMSQFaultClass != null) {
@ -1141,9 +1171,12 @@ public class MSQTestBase extends BaseCalciteQueryTest
if (expectedMSQFault != null || expectedMSQFaultClass != null) {
MSQErrorReport msqErrorReport = getErrorReportOrThrow(controllerId);
if (expectedMSQFault != null) {
String errorMessage = msqErrorReport.getFault() instanceof TooManyAttemptsForWorker
? ((TooManyAttemptsForWorker) msqErrorReport.getFault()).getRootErrorMessage()
: MSQFaultUtils.generateMessageWithErrorCode(msqErrorReport.getFault());
Assert.assertEquals(
expectedMSQFault.getCodeWithMessage(),
msqErrorReport.getFault().getCodeWithMessage()
MSQFaultUtils.generateMessageWithErrorCode(expectedMSQFault),
errorMessage
);
}
if (expectedMSQFaultClass != null) {

View File

@ -55,10 +55,10 @@ public class MSQTestControllerClient implements ControllerClient
}
@Override
public void postCounters(CounterSnapshotsTree snapshotsTree)
public void postCounters(String workerId, CounterSnapshotsTree snapshotsTree)
{
if (snapshotsTree != null) {
controller.updateCounters(snapshotsTree);
controller.updateCounters(workerId, snapshotsTree);
}
}

View File

@ -111,7 +111,7 @@ public class MSQTestControllerContext implements ControllerContext
WorkerManagerClient workerManagerClient = new WorkerManagerClient()
{
@Override
public String run(String controllerId, MSQWorkerTask task)
public String run(String taskId, MSQWorkerTask task)
{
if (controller == null) {
throw new ISE("Controller needs to be set using the register method");
@ -161,8 +161,8 @@ public class MSQTestControllerContext implements ControllerContext
taskStatus.getId(),
taskStatus.getStatusCode(),
taskStatus.getDuration(),
null,
null
taskStatus.getErrorMsg(),
taskStatus.getLocation()
)
);
}

View File

@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.util;
import org.apache.druid.msq.indexing.error.MSQFaultUtils;
import org.apache.druid.msq.indexing.error.UnknownFault;
import org.apache.druid.msq.indexing.error.WorkerFailedFault;
import org.junit.Assert;
import org.junit.Test;
public class MSQFaultUtilsTest
{
@Test
public void testGetErrorCodeFromMessage()
{
Assert.assertEquals(UnknownFault.CODE, MSQFaultUtils.getErrorCodeFromMessage(
"Task execution process exited unsuccessfully with code[137]. See middleManager logs for more details..."));
Assert.assertEquals(UnknownFault.CODE, MSQFaultUtils.getErrorCodeFromMessage(""));
Assert.assertEquals(UnknownFault.CODE, MSQFaultUtils.getErrorCodeFromMessage(null));
Assert.assertEquals("ABC", MSQFaultUtils.getErrorCodeFromMessage("ABC: xyz xyz : xyz"));
Assert.assertEquals(
WorkerFailedFault.CODE,
MSQFaultUtils.getErrorCodeFromMessage(MSQFaultUtils.generateMessageWithErrorCode(new WorkerFailedFault(
"123",
"error"
)))
);
}
}

View File

@ -34,13 +34,13 @@ import org.junit.Test;
import org.junit.internal.matchers.ThrowableMessageMatcher;
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static org.apache.druid.msq.util.MultiStageQueryContext.CTX_DESTINATION;
import static org.apache.druid.msq.util.MultiStageQueryContext.CTX_ENABLE_DURABLE_SHUFFLE_STORAGE;
import static org.apache.druid.msq.util.MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE;
import static org.apache.druid.msq.util.MultiStageQueryContext.CTX_FAULT_TOLERANCE;
import static org.apache.druid.msq.util.MultiStageQueryContext.CTX_FINALIZE_AGGREGATIONS;
import static org.apache.druid.msq.util.MultiStageQueryContext.CTX_MAX_NUM_TASKS;
import static org.apache.druid.msq.util.MultiStageQueryContext.CTX_MSQ_MODE;
@ -61,10 +61,23 @@ public class MultiStageQueryContextTest
@Test
public void isDurableStorageEnabled_parameterSetReturnsCorrectValue()
{
Map<String, Object> propertyMap = ImmutableMap.of(CTX_ENABLE_DURABLE_SHUFFLE_STORAGE, "true");
Map<String, Object> propertyMap = ImmutableMap.of(CTX_DURABLE_SHUFFLE_STORAGE, "true");
Assert.assertTrue(MultiStageQueryContext.isDurableStorageEnabled(QueryContext.of(propertyMap)));
}
@Test
public void isFaultToleranceEnabled_noParameterSetReturnsDefaultValue()
{
Assert.assertFalse(MultiStageQueryContext.isFaultToleranceEnabled(QueryContext.empty()));
}
@Test
public void isFaultToleranceEnabled_parameterSetReturnsCorrectValue()
{
Map<String, Object> propertyMap = ImmutableMap.of(CTX_FAULT_TOLERANCE, "true");
Assert.assertTrue(MultiStageQueryContext.isFaultToleranceEnabled(QueryContext.of(propertyMap)));
}
@Test
public void isFinalizeAggregations_noParameterSetReturnsDefaultValue()
{

View File

@ -178,7 +178,7 @@ public interface Task
boolean isReady(TaskActionClient taskActionClient) throws Exception;
/**
* Returns whether or not this task can restore its progress from its on-disk working directory. Restorable tasks
* Returns whether this task can restore its progress from its on-disk working directory. Restorable tasks
* may be started with a non-empty working directory. Tasks that exit uncleanly may still have a chance to attempt
* restores, meaning that restorable tasks should be able to deal with potentially partially written on-disk state.
*/

View File

@ -88,6 +88,9 @@ services:
service: indexer
environment:
- DRUID_INTEGRATION_TEST_GROUP=${DRUID_INTEGRATION_TEST_GROUP}
- druid_msq_intermediate_storage_enable=true
- druid_msq_intermediate_storage_type=local
- druid_msq_intermediate_storage_basePath=/shared/durablestorage/
volumes:
# Test data
- ../../resources:/resources

View File

@ -0,0 +1,98 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
networks:
druid-it-net:
name: druid-it-net
ipam:
config:
- subnet: 172.172.172.0/24
services:
zookeeper:
extends:
file: ../Common/dependencies.yaml
service: zookeeper
metadata:
extends:
file: ../Common/dependencies.yaml
service: metadata
coordinator:
extends:
file: ../Common/druid.yaml
service: coordinator
container_name: coordinator
environment:
- DRUID_INTEGRATION_TEST_GROUP=${DRUID_INTEGRATION_TEST_GROUP}
- druid_manager_segments_pollDuration=PT5S
- druid_coordinator_period=PT10S
depends_on:
- zookeeper
- metadata
overlord:
extends:
file: ../Common/druid.yaml
service: overlord
container_name: overlord
environment:
- DRUID_INTEGRATION_TEST_GROUP=${DRUID_INTEGRATION_TEST_GROUP}
depends_on:
- zookeeper
- metadata
broker:
extends:
file: ../Common/druid.yaml
service: broker
environment:
- DRUID_INTEGRATION_TEST_GROUP=${DRUID_INTEGRATION_TEST_GROUP}
depends_on:
- zookeeper
router:
extends:
file: ../Common/druid.yaml
service: router
environment:
- DRUID_INTEGRATION_TEST_GROUP=${DRUID_INTEGRATION_TEST_GROUP}
depends_on:
- zookeeper
historical:
extends:
file: ../Common/druid.yaml
service: historical
environment:
- DRUID_INTEGRATION_TEST_GROUP=${DRUID_INTEGRATION_TEST_GROUP}
depends_on:
- zookeeper
middlemanager:
extends:
file: ../Common/druid.yaml
service: middlemanager
environment:
- DRUID_INTEGRATION_TEST_GROUP=${DRUID_INTEGRATION_TEST_GROUP}
- druid_msq_intermediate_storage_enable=true
- druid_msq_intermediate_storage_type=local
- druid_msq_intermediate_storage_basePath=/shared/durablestorage/
volumes:
# Test data
- ../../resources:/resources
depends_on:
- zookeeper

View File

@ -328,6 +328,15 @@
<it.category>MultiStageQuery</it.category>
</properties>
</profile>
<profile>
<id>IT-MultiStageQueryWithMM</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<properties>
<it.category>MultiStageQueryWithMM</it.category>
</properties>
</profile>
<profile>
<id>IT-Catalog</id>
<activation>

View File

@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.testsEx.categories;
public class MultiStageQueryWithMM
{
}

View File

@ -128,7 +128,7 @@ public class ITKeyStatisticsSketchMergeMode
));
}
msqHelper.pollTaskIdForCompletion(sqlTaskStatus.getTaskId());
msqHelper.pollTaskIdForSuccess(sqlTaskStatus.getTaskId());
dataLoaderHelper.waitUntilDatasourceIsReady(datasource);
msqHelper.testQueriesFromFile(QUERY_FILE, datasource);
@ -198,7 +198,7 @@ public class ITKeyStatisticsSketchMergeMode
));
}
msqHelper.pollTaskIdForCompletion(sqlTaskStatus.getTaskId());
msqHelper.pollTaskIdForSuccess(sqlTaskStatus.getTaskId());
dataLoaderHelper.waitUntilDatasourceIsReady(datasource);
msqHelper.testQueriesFromFile(QUERY_FILE, datasource);

View File

@ -117,7 +117,7 @@ public class ITMultiStageQuery
));
}
msqHelper.pollTaskIdForCompletion(sqlTaskStatus.getTaskId());
msqHelper.pollTaskIdForSuccess(sqlTaskStatus.getTaskId());
dataLoaderHelper.waitUntilDatasourceIsReady(datasource);
msqHelper.testQueriesFromFile(QUERY_FILE, datasource);

View File

@ -0,0 +1,190 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.testsEx.msq;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableMap;
import com.google.inject.Inject;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.sql.SqlTaskStatus;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.testing.IntegrationTestingConfig;
import org.apache.druid.testing.clients.CoordinatorResourceTestClient;
import org.apache.druid.testing.clients.SqlResourceTestClient;
import org.apache.druid.testing.utils.DataLoaderHelper;
import org.apache.druid.testing.utils.ITRetryUtil;
import org.apache.druid.testing.utils.MsqTestQueryHelper;
import org.apache.druid.testsEx.categories.MultiStageQueryWithMM;
import org.apache.druid.testsEx.config.DruidTestRunner;
import org.apache.druid.testsEx.utils.DruidClusterAdminClient;
import org.junit.Assert;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
/**
* As we need to kill the PID of the launched task, these tests should be run with middle manager only.
*/
@RunWith(DruidTestRunner.class)
@Category(MultiStageQueryWithMM.class)
public class ITMultiStageQueryWorkerFaultTolerance
{
private static final Logger LOG = new Logger(ITMultiStageQueryWorkerFaultTolerance.class);
@Inject
private MsqTestQueryHelper msqHelper;
@Inject
private SqlResourceTestClient msqClient;
@Inject
private IntegrationTestingConfig config;
@Inject
private ObjectMapper jsonMapper;
@Inject
private DataLoaderHelper dataLoaderHelper;
@Inject
private CoordinatorResourceTestClient coordinatorClient;
@Inject
private DruidClusterAdminClient druidClusterAdminClient;
private static final String QUERY_FILE = "/multi-stage-query/wikipedia_msq_select_query_ha.json";
@Test
public void testMsqIngestionAndQuerying() throws Exception
{
String datasource = "dst";
// Clear up the datasource from the previous runs
coordinatorClient.unloadSegmentsForDataSource(datasource);
String queryLocal =
StringUtils.format(
"INSERT INTO %s\n"
+ "SELECT\n"
+ " TIME_PARSE(\"timestamp\") AS __time,\n"
+ " isRobot,\n"
+ " diffUrl,\n"
+ " added,\n"
+ " countryIsoCode,\n"
+ " regionName,\n"
+ " channel,\n"
+ " flags,\n"
+ " delta,\n"
+ " isUnpatrolled,\n"
+ " isNew,\n"
+ " deltaBucket,\n"
+ " isMinor,\n"
+ " isAnonymous,\n"
+ " deleted,\n"
+ " cityName,\n"
+ " metroCode,\n"
+ " namespace,\n"
+ " comment,\n"
+ " page,\n"
+ " commentLength,\n"
+ " countryName,\n"
+ " user,\n"
+ " regionIsoCode\n"
+ "FROM TABLE(\n"
+ " EXTERN(\n"
+ " '{\"type\":\"local\",\"files\":[\"/resources/data/batch_index/json/wikipedia_index_data1.json\",\"/resources/data/batch_index/json/wikipedia_index_data1.json\"]}',\n"
+ " '{\"type\":\"json\"}',\n"
+ " '[{\"type\":\"string\",\"name\":\"timestamp\"},{\"type\":\"string\",\"name\":\"isRobot\"},{\"type\":\"string\",\"name\":\"diffUrl\"},{\"type\":\"long\",\"name\":\"added\"},{\"type\":\"string\",\"name\":\"countryIsoCode\"},{\"type\":\"string\",\"name\":\"regionName\"},{\"type\":\"string\",\"name\":\"channel\"},{\"type\":\"string\",\"name\":\"flags\"},{\"type\":\"long\",\"name\":\"delta\"},{\"type\":\"string\",\"name\":\"isUnpatrolled\"},{\"type\":\"string\",\"name\":\"isNew\"},{\"type\":\"double\",\"name\":\"deltaBucket\"},{\"type\":\"string\",\"name\":\"isMinor\"},{\"type\":\"string\",\"name\":\"isAnonymous\"},{\"type\":\"long\",\"name\":\"deleted\"},{\"type\":\"string\",\"name\":\"cityName\"},{\"type\":\"long\",\"name\":\"metroCode\"},{\"type\":\"string\",\"name\":\"namespace\"},{\"type\":\"string\",\"name\":\"comment\"},{\"type\":\"string\",\"name\":\"page\"},{\"type\":\"long\",\"name\":\"commentLength\"},{\"type\":\"string\",\"name\":\"countryName\"},{\"type\":\"string\",\"name\":\"user\"},{\"type\":\"string\",\"name\":\"regionIsoCode\"}]'\n"
+ " )\n"
+ ")\n"
+ "PARTITIONED BY DAY\n"
+ "CLUSTERED BY \"__time\"",
datasource
);
// Submit the task and wait for the datasource to get loaded
SqlTaskStatus sqlTaskStatus = msqHelper.submitMsqTask(
queryLocal,
ImmutableMap.of(
MultiStageQueryContext.CTX_FAULT_TOLERANCE,
"true",
MultiStageQueryContext.CTX_MAX_NUM_TASKS,
3
)
);
if (sqlTaskStatus.getState().isFailure()) {
Assert.fail(StringUtils.format(
"Unable to start the task successfully.\nPossible exception: %s",
sqlTaskStatus.getError()
));
}
String taskIdToKill = sqlTaskStatus.getTaskId() + "-worker1_0";
killTaskAbruptly(taskIdToKill);
msqHelper.pollTaskIdForSuccess(sqlTaskStatus.getTaskId());
dataLoaderHelper.waitUntilDatasourceIsReady(datasource);
msqHelper.testQueriesFromFile(QUERY_FILE, datasource);
}
private void killTaskAbruptly(String taskIdToKill)
{
String command = "jps -mlv | grep -i peon | grep -i " + taskIdToKill + " |awk '{print $1}'";
ITRetryUtil.retryUntil(() -> {
Pair<String, String> stdOut = druidClusterAdminClient.runCommandInMiddleManagerContainer("/bin/bash", "-c",
command
);
LOG.info(StringUtils.format(
"command %s \nstdout: %s\nstderr: %s",
command,
stdOut.lhs,
stdOut.rhs
));
if (stdOut.rhs != null && stdOut.rhs.length() != 0) {
throw new ISE("Bad command");
}
String pidToKill = stdOut.lhs.trim();
if (pidToKill.length() != 0) {
LOG.info("Found PID to kill %s", pidToKill);
// kill worker after 5 seconds
Thread.sleep(5000);
LOG.info("Killing pid %s", pidToKill);
druidClusterAdminClient.runCommandInMiddleManagerContainer(
"/bin/bash",
"-c",
"kill -9 " + pidToKill
);
return true;
} else {
return false;
}
}, true, 6000, 50, StringUtils.format("Figuring out PID for task[%s] to kill abruptly", taskIdToKill));
}
}

View File

@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-------------------------------------------------------------------------
# Definition of the multi stage query test cluster.
# See https://yaml.org/spec/1.2.2 for more about YAML
include:
- /cluster/Common/zk-metastore.yaml
druid:
coordinator:
instances:
- port: 8081
overlord:
instances:
- port: 8090
broker:
instances:
- port: 8082
router:
instances:
- port: 8888
historical:
instances:
- port: 8083
middlemanager:
instances:
- port: 8091

View File

@ -0,0 +1,55 @@
[
{
"query": "SELECT __time, isRobot, added, delta, deleted, namespace FROM %%DATASOURCE%%",
"expectedResults": [
{
"__time": 1377910953000,
"isRobot": null,
"added": 57,
"delta": -143,
"deleted": 200,
"namespace": "article"
},
{
"__time": 1377910953000,
"isRobot": null,
"added": 57,
"delta": -143,
"deleted": 200,
"namespace": "article"
},
{
"__time": 1377919965000,
"isRobot": null,
"added": 459,
"delta": 330,
"deleted": 129,
"namespace": "wikipedia"
},
{
"__time": 1377919965000,
"isRobot": null,
"added": 459,
"delta": 330,
"deleted": 129,
"namespace": "wikipedia"
},
{
"__time": 1377933081000,
"isRobot": null,
"added": 123,
"delta": 111,
"deleted": 12,
"namespace": "article"
},
{
"__time": 1377933081000,
"isRobot": null,
"added": 123,
"delta": 111,
"deleted": 12,
"namespace": "article"
}
]
}
]

View File

@ -168,6 +168,11 @@ public class MsqTestQueryHelper extends AbstractTestQueryHelper<MsqQueryWithResu
);
}
public void pollTaskIdForSuccess(String taskId) throws Exception
{
Assert.assertEquals(pollTaskIdForCompletion(taskId), TaskState.SUCCESS);
}
/**
* Fetches status reports for a given task
*/
@ -255,7 +260,7 @@ public class MsqTestQueryHelper extends AbstractTestQueryHelper<MsqQueryWithResu
);
}
String taskId = sqlTaskStatus.getTaskId();
pollTaskIdForCompletion(taskId);
pollTaskIdForSuccess(taskId);
compareResults(taskId, queryWithResults);
}
}