Multiple fixes for the MSQ stats merging piece (#13463)

* Add validation checks to worker chat handler apis

* Merging things and polishing the error messages.

* Minor error message change

* Fixing race and adding some tests

* Fixing the controller fetching stats from the wrong workers.
Fixing a race.
Changing the default mode to PARALLEL.
Adding logging.
Fixing exceptions not being propagated properly.

* Changing to kernel worker count

* Added better logic to figure out the assigned workers for a stage.

* Nits

* Moving to existing kernel methods

* Adding more coverage

Co-authored-by: cryptoe <karankumar1100@gmail.com>
Adarsh Sanjeev 2022-12-15 09:35:11 +05:30 committed by GitHub
parent 089d8da561
commit 2b605aa9cf
14 changed files with 529 additions and 50 deletions

View File

@ -325,7 +325,7 @@ The following table lists the context parameters for the MSQ task engine:
| `maxParseExceptions`| SELECT, INSERT, REPLACE<br /><br />Maximum number of parse exceptions that are ignored while executing the query before it stops with `TooManyWarningsFault`. To ignore all the parse exceptions, set the value to -1.| 0 |
| `rowsPerSegment` | INSERT or REPLACE<br /><br />The number of rows per segment to target. The actual number of rows per segment may be somewhat higher or lower than this number. In most cases, use the default. For general information about sizing rows per segment, see [Segment Size Optimization](../operations/segment-optimization.md). | 3,000,000 |
| `indexSpec` | INSERT or REPLACE<br /><br />An [`indexSpec`](../ingestion/ingestion-spec.md#indexspec) to use when generating segments. May be a JSON string or object. See [Front coding](../ingestion/ingestion-spec.md#front-coding) for details on configuring an `indexSpec` with front coding. | See [`indexSpec`](../ingestion/ingestion-spec.md#indexspec). |
| `clusterStatisticsMergeMode` | Whether to use parallel or sequential mode for merging of the worker sketches. Can be `PARALLEL`, `SEQUENTIAL` or `AUTO`. See [Sketch Merging Mode](#sketch-merging-mode) for more information. | `AUTO` |
| `clusterStatisticsMergeMode` | Whether to use parallel or sequential mode for merging of the worker sketches. Can be `PARALLEL`, `SEQUENTIAL` or `AUTO`. See [Sketch Merging Mode](#sketch-merging-mode) for more information. | `PARALLEL` |
## Sketch Merging Mode
This section details the advantages and performance characteristics of the various cluster by statistics merge modes.
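As a usage sketch, the mode can be pinned per query through this context key. The snippet below mirrors the `MSQInsertTest` pattern added later in this commit; the helper name and `baseContext` are hypothetical, the two MSQ constants are real identifiers from this diff (their imports are elided).

```java
import com.google.common.collect.ImmutableMap;
import java.util.Map;

// A minimal sketch: pinning SEQUENTIAL merge mode via the query context,
// mirroring the MSQInsertTest pattern added in this commit. baseContext is a
// hypothetical stand-in for whatever defaults the caller already passes.
final class MergeModeContexts
{
  static Map<String, Object> withSequentialMergeMode(Map<String, Object> baseContext)
  {
    return ImmutableMap.<String, Object>builder()
        .putAll(baseContext)
        .put(
            MultiStageQueryContext.CTX_CLUSTER_STATISTICS_MERGE_MODE,
            ClusterStatisticsMergeMode.SEQUENTIAL.toString()
        )
        .build();
  }
}
```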

View File

@ -263,6 +263,7 @@ public class ControllerImpl implements Controller
// For live reports. Written by the main controller thread, read by HTTP threads.
private final ConcurrentHashMap<Integer, Integer> stagePartitionCountsForLiveReports = new ConcurrentHashMap<>();
private WorkerSketchFetcher workerSketchFetcher;
// Time at which the query started.
// For live reports. Written by the main controller thread, read by HTTP threads.
@ -624,14 +625,21 @@ public class ControllerImpl implements Controller
workerSketchFetcher.submitFetcherTask(
completeKeyStatisticsInformation,
workerTaskIds,
stageDef
stageDef,
queryKernel.getWorkerInputsForStage(stageId).workers()
// We only need the workers that are active for this stage.
);
// Add the listener to handle completion.
clusterByPartitionsCompletableFuture.whenComplete((clusterByPartitionsEither, throwable) -> {
addToKernelManipulationQueue(holder -> {
if (throwable != null) {
holder.failStageForReason(stageId, UnknownFault.forException(throwable));
log.error("Error while fetching stats for stageId[%s]", stageId);
if (throwable instanceof MSQException) {
holder.failStageForReason(stageId, ((MSQException) throwable).getFault());
} else {
holder.failStageForReason(stageId, UnknownFault.forException(throwable));
}
} else if (clusterByPartitionsEither.isError()) {
holder.failStageForReason(stageId, new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
} else {

View File

@ -57,9 +57,13 @@ public class ExceptionWrappingWorkerClient implements WorkerClient
}
@Override
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(String workerTaskId, String queryId, int stageNumber)
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
String workerTaskId,
String queryId,
int stageNumber
)
{
return client.fetchClusterByStatisticsSnapshot(workerTaskId, queryId, stageNumber);
return wrap(workerTaskId, client, c -> c.fetchClusterByStatisticsSnapshot(workerTaskId, queryId, stageNumber));
}
@Override
@ -70,7 +74,11 @@ public class ExceptionWrappingWorkerClient implements WorkerClient
long timeChunk
)
{
return client.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, queryId, stageNumber, timeChunk);
return wrap(
workerTaskId,
client,
c -> c.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, queryId, stageNumber, timeChunk)
);
}
@Override

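The `wrap(...)` helper itself is outside this hunk; as a rough illustration of the idea (hypothetical names and a plain `CompletableFuture`, not Druid's Guava-based signature), it attributes a failed client call to the worker that made it:

```java
import java.util.concurrent.CompletableFuture;

// Illustrative sketch only: re-throw a failed worker call tagged with the
// offending workerTaskId. The real wrap(...) returns a ListenableFuture and
// translates failures into MSQ fault types.
final class WorkerCallWrapping
{
  static <T> CompletableFuture<T> wrap(String workerTaskId, CompletableFuture<T> call)
  {
    return call.handle((result, t) -> {
      if (t != null) {
        throw new RuntimeException("Call to worker[" + workerTaskId + "] failed", t);
      }
      return result;
    });
  }
}
```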
View File

@ -571,16 +571,37 @@ public class WorkerImpl implements Worker
@Override
public ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId)
{
return stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();
if (stageKernelMap.get(stageId) == null) {
throw new ISE("Requested statistics snapshot for non-existent stageId %s.", stageId);
} else if (stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot() == null) {
throw new ISE(
"Requested statistics snapshot is not generated yet for stageId[%s]",
stageId
);
} else {
return stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();
}
}
@Override
public ClusterByStatisticsSnapshot fetchStatisticsSnapshotForTimeChunk(StageId stageId, long timeChunk)
{
ClusterByStatisticsSnapshot snapshot = stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();
return snapshot.getSnapshotForTimeChunk(timeChunk);
if (stageKernelMap.get(stageId) == null) {
throw new ISE("Requested statistics snapshot for non-existent stageId[%s].", stageId);
} else if (stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot() == null) {
throw new ISE(
"Requested statistics snapshot is not generated yet for stageId[%s]",
stageId
);
} else {
return stageKernelMap.get(stageId)
.getResultKeyStatisticsSnapshot()
.getSnapshotForTimeChunk(timeChunk);
}
}
@Override
public CounterSnapshotsTree getCounters()
{
@ -643,7 +664,7 @@ public class WorkerImpl implements Worker
/**
* Decorates the server-wide {@link QueryProcessingPool} such that any Callables and Runnables, not just
* {@link PrioritizedCallable} and {@link PrioritizedRunnable}, may be added to it.
*
* <p>
* In production, the underlying {@link QueryProcessingPool} pool is set up by
* {@link org.apache.druid.guice.DruidProcessingModule}.
*/
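A generic sketch of the decoration this Javadoc describes (hypothetical types, not Druid's `PrioritizedCallable`): wrap plain tasks with a default priority so a pool that only runs prioritized tasks can still accept them.

```java
import java.util.concurrent.Callable;

// Hypothetical stand-in for a prioritized task interface.
interface PrioritizedTask<V> extends Callable<V>
{
  int getPriority();
}

final class PrioritizingDecorator
{
  // Wrap an arbitrary Callable with a default priority of 0.
  static <V> PrioritizedTask<V> withDefaultPriority(Callable<V> task)
  {
    return new PrioritizedTask<V>()
    {
      @Override
      public int getPriority()
      {
        return 0;
      }

      @Override
      public V call() throws Exception
      {
        return task.call();
      }
    };
  }
}
```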

View File

@ -20,6 +20,7 @@
package org.apache.druid.msq.exec;
import com.google.common.util.concurrent.ListenableFuture;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
@ -40,7 +41,7 @@ import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.stream.IntStream;
import java.util.stream.Collectors;
/**
* Queues up fetching sketches from workers and progressively generates partition boundaries.
@ -78,7 +79,8 @@ public class WorkerSketchFetcher implements AutoCloseable
public CompletableFuture<Either<Long, ClusterByPartitions>> submitFetcherTask(
CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
List<String> workerTaskIds,
StageDefinition stageDefinition
StageDefinition stageDefinition,
IntSet workersForStage
)
{
ClusterBy clusterBy = stageDefinition.getClusterBy();
@ -87,18 +89,31 @@ public class WorkerSketchFetcher implements AutoCloseable
case SEQUENTIAL:
return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
case PARALLEL:
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
case AUTO:
if (clusterBy.getBucketByCount() == 0) {
log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
log.info(
"Query[%s] stage[%d] for AUTO mode: chose PARALLEL mode to merge key statistics",
stageDefinition.getId().getQueryId(),
stageDefinition.getStageNumber()
);
// If there is no time clustering, there is no scope for sequential merge
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
} else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD || completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
log.info("Query [%s] AUTO mode: chose SEQUENTIAL mode to merge key statistics", stageDefinition.getId().getQueryId());
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
} else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD
|| completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
log.info(
"Query[%s] stage[%d] for AUTO mode: chose SEQUENTIAL mode to merge key statistics",
stageDefinition.getId().getQueryId(),
stageDefinition.getStageNumber()
);
return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
}
log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
log.info(
"Query[%s] stage[%d] for AUTO mode: chose PARALLEL mode to merge key statistics",
stageDefinition.getId().getQueryId(),
stageDefinition.getStageNumber()
);
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
default:
throw new IllegalStateException("No fetching strategy found for mode: " + clusterStatisticsMergeMode);
}
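Because the hunk above interleaves removed and added lines, here is a compact, self-contained restatement of the resulting AUTO decision; the threshold values below are placeholders, not the real `WorkerSketchFetcher.WORKER_THRESHOLD` and `BYTES_THRESHOLD` constants.

```java
// Sketch of the AUTO-mode decision after this change. Threshold values are
// placeholders, not the real WorkerSketchFetcher constants.
final class AutoModeSketch
{
  enum MergeMode { PARALLEL, SEQUENTIAL }

  static final int WORKER_THRESHOLD = 100;            // placeholder
  static final long BYTES_THRESHOLD = 1_000_000_000L; // placeholder

  static MergeMode choose(int bucketByCount, int maxWorkerCount, long bytesRetained)
  {
    if (bucketByCount == 0) {
      // No time clustering, so there is no scope for a sequential merge.
      return MergeMode.PARALLEL;
    }
    if (maxWorkerCount > WORKER_THRESHOLD || bytesRetained > BYTES_THRESHOLD) {
      // Too many workers or too much retained sketch data to merge in memory at once.
      return MergeMode.SEQUENTIAL;
    }
    return MergeMode.PARALLEL;
  }
}
```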
@ -111,7 +126,8 @@ public class WorkerSketchFetcher implements AutoCloseable
*/
CompletableFuture<Either<Long, ClusterByPartitions>> inMemoryFullSketchMerging(
StageDefinition stageDefinition,
List<String> workerTaskIds
List<String> workerTaskIds,
IntSet workersForStage
)
{
CompletableFuture<Either<Long, ClusterByPartitions>> partitionFuture = new CompletableFuture<>();
@ -119,12 +135,19 @@ public class WorkerSketchFetcher implements AutoCloseable
// Create a new key statistics collector to merge worker sketches into
final ClusterByStatisticsCollector mergedStatisticsCollector =
stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes);
final int workerCount = workerTaskIds.size();
final int workerCount = workersForStage.size();
// Guarded by synchronized mergedStatisticsCollector
final Set<Integer> finishedWorkers = new HashSet<>();
log.info(
"Fetching stats using %s for stage[%d] for workers[%s] ",
ClusterStatisticsMergeMode.PARALLEL,
stageDefinition.getStageNumber(),
workersForStage.stream().map(Object::toString).collect(Collectors.joining(","))
);
// Submit a task for each worker to fetch statistics
IntStream.range(0, workerCount).forEach(workerNo -> {
workersForStage.forEach(workerNo -> {
executorService.submit(() -> {
ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture =
workerClient.fetchClusterByStatisticsSnapshot(
@ -177,6 +200,13 @@ public class WorkerSketchFetcher implements AutoCloseable
workerTaskIds,
completeKeyStatisticsInformation.getTimeSegmentVsWorkerMap().entrySet().iterator()
);
log.info(
"Fetching stats using %s for stage[%d] for tasks[%s]",
ClusterStatisticsMergeMode.SEQUENTIAL,
stageDefinition.getStageNumber(),
String.join("", workerTaskIds)
);
sequentialFetchStage.submitFetchingTasksForNextTimeChunk();
return sequentialFetchStage.getPartitionFuture();
}
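For context on the sequential path: it walks `getTimeSegmentVsWorkerMap()` one time chunk at a time, querying only the workers recorded for each chunk (the test later in this diff uses, for example, `{1 -> {0, 1, 2}, 2 -> {0, 1, 4}}`). Below is a rough sketch of that iteration order with hypothetical names; the real `SequentialFetchStage` is asynchronous and merges `ClusterByStatisticsSnapshot`s rather than printing.

```java
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
import java.util.Set;
import java.util.SortedMap;

final class SequentialOrderSketch
{
  public static void main(String[] args)
  {
    // Shape of the map produced by CompleteKeyStatisticsInformation.
    SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap =
        ImmutableSortedMap.of(1L, ImmutableSet.of(0, 1, 2), 2L, ImmutableSet.of(0, 1, 4));

    // One chunk at a time: fetch from only the workers that handled the chunk,
    // merge their snapshots, emit boundaries, then move to the next chunk.
    timeSegmentVsWorkerMap.forEach(
        (timeChunk, workers) ->
            System.out.println("timeChunk[" + timeChunk + "] -> fetch from workers" + workers)
    );
  }
}
```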

View File

@ -19,11 +19,13 @@
package org.apache.druid.msq.indexing;
import com.google.common.collect.ImmutableMap;
import it.unimi.dsi.fastutil.bytes.ByteArrays;
import org.apache.commons.lang.mutable.MutableLong;
import org.apache.druid.frame.file.FrameFileHttpResponseHandler;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.indexing.common.TaskToolbox;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.exec.Worker;
import org.apache.druid.msq.kernel.StageId;
@ -71,7 +73,7 @@ public class WorkerChatHandler implements ChatHandler
/**
* Returns up to {@link #CHANNEL_DATA_CHUNK_SIZE} bytes of stage output data.
*
* <p>
* See {@link org.apache.druid.msq.exec.WorkerClient#fetchChannelData} for the client-side code that calls this API.
*/
@GET
@ -193,17 +195,30 @@ public class WorkerChatHandler implements ChatHandler
ChatHandlers.authorizationCheck(req, Action.READ, task.getDataSource(), toolbox.getAuthorizerMapper());
ClusterByStatisticsSnapshot clusterByStatisticsSnapshot;
StageId stageId = new StageId(queryId, stageNumber);
clusterByStatisticsSnapshot = worker.fetchStatisticsSnapshot(stageId);
return Response.status(Response.Status.ACCEPTED)
.entity(clusterByStatisticsSnapshot)
.build();
try {
clusterByStatisticsSnapshot = worker.fetchStatisticsSnapshot(stageId);
return Response.status(Response.Status.ACCEPTED)
.entity(clusterByStatisticsSnapshot)
.build();
}
catch (Exception e) {
String errorMessage = StringUtils.format(
"Invalid request for key statistics for query[%s] and stage[%d]",
queryId,
stageNumber
);
log.error(e, errorMessage);
return Response.status(Response.Status.BAD_REQUEST)
.entity(ImmutableMap.<String, Object>of("error", errorMessage))
.build();
}
}
@POST
@Path("/keyStatisticsForTimeChunk/{queryId}/{stageNumber}/{timeChunk}")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public Response httpSketch(
public Response httpFetchKeyStatisticsWithSnapshot(
@PathParam("queryId") final String queryId,
@PathParam("stageNumber") final int stageNumber,
@PathParam("timeChunk") final long timeChunk,
@ -213,10 +228,24 @@ public class WorkerChatHandler implements ChatHandler
ChatHandlers.authorizationCheck(req, Action.READ, task.getDataSource(), toolbox.getAuthorizerMapper());
ClusterByStatisticsSnapshot snapshotForTimeChunk;
StageId stageId = new StageId(queryId, stageNumber);
snapshotForTimeChunk = worker.fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk);
return Response.status(Response.Status.ACCEPTED)
.entity(snapshotForTimeChunk)
.build();
try {
snapshotForTimeChunk = worker.fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk);
return Response.status(Response.Status.ACCEPTED)
.entity(snapshotForTimeChunk)
.build();
}
catch (Exception e) {
String errorMessage = StringUtils.format(
"Invalid request for key statistics for query[%s], stage[%d] and timeChunk[%d]",
queryId,
stageNumber,
timeChunk
);
log.error(e, errorMessage);
return Response.status(Response.Status.BAD_REQUEST)
.entity(ImmutableMap.<String, Object>of("error", errorMessage))
.build();
}
}
/**

View File

@ -48,8 +48,9 @@ public class WorkerStageKernel
private WorkerStagePhase phase = WorkerStagePhase.NEW;
// We read this variable in the main thread and the netty threads
@Nullable
private ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot;
private volatile ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot;
@Nullable
private ClusterByPartitions resultPartitionBoundaries;
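This `volatile` is the race fix from the commit message: the snapshot is written by the main worker thread but read by netty/HTTP threads. A minimal illustration of the visibility hazard, with hypothetical names:

```java
// Without volatile, the Java memory model permits the reader thread to keep
// seeing a stale null even after publish() has run; volatile guarantees the
// write becomes visible.
final class SnapshotHolder
{
  private volatile Object snapshot;

  void publish(Object s)
  {
    snapshot = s; // main worker thread
  }

  Object read()
  {
    return snapshot; // netty/HTTP thread
  }
}
```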

View File

@ -25,6 +25,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.java.util.common.ISE;
import javax.annotation.Nullable;
import java.util.Collections;
@ -61,6 +62,9 @@ public class ClusterByStatisticsSnapshot
public ClusterByStatisticsSnapshot getSnapshotForTimeChunk(long timeChunk)
{
Bucket bucket = buckets.get(timeChunk);
if (bucket == null) {
throw new ISE("ClusterByStatistics not present for requested timechunk %s", timeChunk);
}
return new ClusterByStatisticsSnapshot(ImmutableMap.of(timeChunk, bucket), null);
}

View File

@ -60,7 +60,7 @@ public class MultiStageQueryContext
public static final String CTX_ENABLE_DURABLE_SHUFFLE_STORAGE = "durableShuffleStorage";
public static final String CTX_CLUSTER_STATISTICS_MERGE_MODE = "clusterStatisticsMergeMode";
public static final String DEFAULT_CLUSTER_STATISTICS_MERGE_MODE = ClusterStatisticsMergeMode.AUTO.toString();
public static final String DEFAULT_CLUSTER_STATISTICS_MERGE_MODE = ClusterStatisticsMergeMode.PARALLEL.toString();
private static final boolean DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE = false;
public static final String CTX_DESTINATION = "destination";

View File

@ -128,6 +128,32 @@ public class MSQInsertTest extends MSQTestBase
}
@Test
public void testInsertOnFoo1WithTimeFunctionWithSequential()
{
RowSignature rowSignature = RowSignature.builder()
.add("__time", ColumnType.LONG)
.add("dim1", ColumnType.STRING)
.add("cnt", ColumnType.LONG).build();
Map<String, Object> context = ImmutableMap.<String, Object>builder()
.putAll(DEFAULT_MSQ_CONTEXT)
.put(
MultiStageQueryContext.CTX_CLUSTER_STATISTICS_MERGE_MODE,
ClusterStatisticsMergeMode.SEQUENTIAL.toString()
)
.build();
testIngestQuery().setSql(
"insert into foo1 select floor(__time to day) as __time , dim1 , count(*) as cnt from foo where dim1 is not null group by 1, 2 PARTITIONED by day clustered by dim1")
.setQueryContext(context)
.setExpectedDataSource("foo1")
.setExpectedRowSignature(rowSignature)
.setExpectedSegment(expectedFooSegments())
.setExpectedResultRows(expectedFooRows())
.verifyResults();
}
@Test
public void testInsertOnFoo1WithMultiValueDim()
{

View File

@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.exec;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.indexing.MSQWorkerTask;
import org.apache.druid.msq.kernel.StageId;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import java.util.HashMap;
@RunWith(MockitoJUnitRunner.class)
public class WorkerImplTest
{
@Mock
WorkerContext workerContext;
@Test
public void testFetchStatsThrows()
{
WorkerImpl worker = new WorkerImpl(new MSQWorkerTask("controller", "ds", 1, new HashMap<>()), workerContext);
Assert.assertThrows(ISE.class, () -> worker.fetchStatisticsSnapshot(new StageId("xx", 1)));
}
@Test
public void testFetchStatsWithTimeChunkThrows()
{
WorkerImpl worker = new WorkerImpl(new MSQWorkerTask("controller", "ds", 1, new HashMap<>()), workerContext);
Assert.assertThrows(ISE.class, () -> worker.fetchStatisticsSnapshotForTimeChunk(new StageId("xx", 1), 1L));
}
}

View File

@ -19,6 +19,7 @@
package org.apache.druid.msq.exec;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
@ -56,7 +57,7 @@ public class WorkerSketchFetcherAutoModeTest
target = spy(new WorkerSketchFetcher(mock(WorkerClient.class), ClusterStatisticsMergeMode.AUTO, 300_000_000));
// Don't actually try to fetch sketches
doReturn(null).when(target).inMemoryFullSketchMerging(any(), any());
doReturn(null).when(target).inMemoryFullSketchMerging(any(), any(), any());
doReturn(null).when(target).sequentialTimeChunkMerging(any(), any(), any());
doReturn(StageId.fromString("1_1")).when(stageDefinition).getId();
@ -81,8 +82,13 @@ public class WorkerSketchFetcherAutoModeTest
// Worker count below threshold
doReturn(1).when(stageDefinition).getMaxWorkerCount();
target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
verify(target, times(1)).inMemoryFullSketchMerging(any(), any());
target.submitFetcherTask(
completeKeyStatisticsInformation,
Collections.emptyList(),
stageDefinition,
IntSet.of()
);
verify(target, times(1)).inMemoryFullSketchMerging(any(), any(), any());
verify(target, times(0)).sequentialTimeChunkMerging(any(), any(), any());
}
@ -98,8 +104,13 @@ public class WorkerSketchFetcherAutoModeTest
// Worker count above threshold
doReturn((int) WorkerSketchFetcher.WORKER_THRESHOLD + 1).when(stageDefinition).getMaxWorkerCount();
target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
verify(target, times(0)).inMemoryFullSketchMerging(any(), any());
target.submitFetcherTask(
completeKeyStatisticsInformation,
Collections.emptyList(),
stageDefinition,
IntSet.of()
);
verify(target, times(0)).inMemoryFullSketchMerging(any(), any(), any());
verify(target, times(1)).sequentialTimeChunkMerging(any(), any(), any());
}
@ -115,8 +126,13 @@ public class WorkerSketchFetcherAutoModeTest
// Worker count above threshold
doReturn((int) WorkerSketchFetcher.WORKER_THRESHOLD + 1).when(stageDefinition).getMaxWorkerCount();
target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
verify(target, times(1)).inMemoryFullSketchMerging(any(), any());
target.submitFetcherTask(
completeKeyStatisticsInformation,
Collections.emptyList(),
stageDefinition,
IntSet.of()
);
verify(target, times(1)).inMemoryFullSketchMerging(any(), any(), any());
verify(target, times(0)).sequentialTimeChunkMerging(any(), any(), any());
}
@ -132,8 +148,13 @@ public class WorkerSketchFetcherAutoModeTest
// Worker count below threshold
doReturn(1).when(stageDefinition).getMaxWorkerCount();
target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
verify(target, times(0)).inMemoryFullSketchMerging(any(), any());
target.submitFetcherTask(
completeKeyStatisticsInformation,
Collections.emptyList(),
stageDefinition,
IntSet.of()
);
verify(target, times(0)).inMemoryFullSketchMerging(any(), any(), any());
verify(target, times(1)).sequentialTimeChunkMerging(any(), any(), any());
}
}

View File

@ -23,6 +23,8 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.util.concurrent.Futures;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
@ -88,11 +90,19 @@ public class WorkerSketchFetcherTest
doReturn(clusterBy).when(stageDefinition).getClusterBy();
doReturn(25_000).when(stageDefinition).getMaxPartitionCount();
expectedPartitions1 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(mock(RowKey.class), mock(RowKey.class))));
expectedPartitions2 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(mock(RowKey.class), mock(RowKey.class))));
expectedPartitions1 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(
mock(RowKey.class),
mock(RowKey.class)
)));
expectedPartitions2 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(
mock(RowKey.class),
mock(RowKey.class)
)));
doReturn(Either.value(expectedPartitions1)).when(stageDefinition).generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector1));
doReturn(Either.value(expectedPartitions2)).when(stageDefinition).generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector2));
doReturn(Either.value(expectedPartitions1)).when(stageDefinition)
.generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector1));
doReturn(Either.value(expectedPartitions2)).when(stageDefinition)
.generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector2));
doReturn(
mergedClusterByStatisticsCollector1,
@ -128,10 +138,14 @@ public class WorkerSketchFetcherTest
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt());
IntSet workersForStage = new IntAVLTreeSet();
workersForStage.addAll(ImmutableSet.of(0, 1, 2, 3, 4));
CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
completeKeyStatisticsInformation,
workerIds,
stageDefinition
stageDefinition,
workersForStage
);
// Assert that the final result is complete and all other sketches returned have been merged.
@ -154,7 +168,12 @@ public class WorkerSketchFetcherTest
// Store snapshots in a queue
final Queue<ClusterByStatisticsSnapshot> snapshotQueue = new ConcurrentLinkedQueue<>();
SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap = ImmutableSortedMap.of(1L, ImmutableSet.of(0, 1, 2), 2L, ImmutableSet.of(0, 1, 4));
SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap = ImmutableSortedMap.of(
1L,
ImmutableSet.of(0, 1, 2),
2L,
ImmutableSet.of(0, 1, 4)
);
doReturn(timeSegmentVsWorkerMap).when(completeKeyStatisticsInformation).getTimeSegmentVsWorkerMap();
final CyclicBarrier barrier = new CyclicBarrier(3);
@ -168,10 +187,14 @@ public class WorkerSketchFetcherTest
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyInt(), anyLong());
IntSet workersForStage = new IntAVLTreeSet();
workersForStage.addAll(ImmutableSet.of(0, 1, 2, 3, 4));
CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
completeKeyStatisticsInformation,
ImmutableList.of("0", "1", "2", "3", "4"),
stageDefinition
stageDefinition,
workersForStage
);
// Assert that the final result is complete and all other sketches returned have been merged.

View File

@ -0,0 +1,254 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.indexing;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.indexer.TaskStatus;
import org.apache.druid.indexing.common.TaskReport;
import org.apache.druid.indexing.common.TaskReportFileWriter;
import org.apache.druid.indexing.common.TaskToolbox;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.exec.Worker;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.segment.IndexIO;
import org.apache.druid.segment.IndexMergerV9;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.server.security.AuthConfig;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.core.Response;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
public class WorkerChatHandlerTest
{
private static final StageId TEST_STAGE = new StageId("123", 0);
@Mock
private HttpServletRequest req;
private TaskToolbox toolbox;
private AutoCloseable mocks;
private final TestWorker worker = new TestWorker();
@Before
public void setUp()
{
ObjectMapper mapper = new DefaultObjectMapper();
IndexIO indexIO = new IndexIO(mapper, () -> 0);
IndexMergerV9 indexMerger = new IndexMergerV9(
mapper,
indexIO,
OffHeapMemorySegmentWriteOutMediumFactory.instance()
);
mocks = MockitoAnnotations.openMocks(this);
Mockito.when(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT))
.thenReturn(new AuthenticationResult("druid", "druid", null, null));
TaskToolbox.Builder builder = new TaskToolbox.Builder();
toolbox = builder.authorizerMapper(CalciteTests.TEST_AUTHORIZER_MAPPER)
.indexIO(indexIO)
.indexMergerV9(indexMerger)
.taskReportFileWriter(
new TaskReportFileWriter()
{
@Override
public void write(String taskId, Map<String, TaskReport> reports)
{
}
@Override
public void setObjectMapper(ObjectMapper objectMapper)
{
}
}
)
.build();
}
@Test
public void testFetchSnapshot()
{
WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
Assert.assertEquals(
ClusterByStatisticsSnapshot.empty(),
chatHandler.httpFetchKeyStatistics(TEST_STAGE.getQueryId(), TEST_STAGE.getStageNumber(), req)
.getEntity()
);
}
@Test
public void testFetchSnapshotBadRequest()
{
WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
Assert.assertEquals(
Response.Status.BAD_REQUEST.getStatusCode(),
chatHandler.httpFetchKeyStatistics("123", 2, req)
.getStatus()
);
}
@Test
public void testFetchSnapshotWithTimeChunk()
{
WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
Assert.assertEquals(
ClusterByStatisticsSnapshot.empty(),
chatHandler.httpFetchKeyStatisticsWithSnapshot(TEST_STAGE.getQueryId(), TEST_STAGE.getStageNumber(), 1, req)
.getEntity()
);
}
@Test
public void testFetchSnapshotWithTimeChunkBadRequest()
{
WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
Assert.assertEquals(
Response.Status.BAD_REQUEST.getStatusCode(),
chatHandler.httpFetchKeyStatisticsWithSnapshot("123", 2, 1, req)
.getStatus()
);
}
private static class TestWorker implements Worker
{
@Override
public String id()
{
return TEST_STAGE.getQueryId() + "task";
}
@Override
public MSQWorkerTask task()
{
return new MSQWorkerTask("controller", "ds", 1, new HashMap<>());
}
@Override
public TaskStatus run()
{
return null;
}
@Override
public void stopGracefully()
{
}
@Override
public void controllerFailed()
{
}
@Override
public void postWorkOrder(WorkOrder workOrder)
{
}
@Override
public ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId)
{
if (TEST_STAGE.equals(stageId)) {
return ClusterByStatisticsSnapshot.empty();
} else {
throw new ISE("stage not found %s", stageId);
}
}
@Override
public ClusterByStatisticsSnapshot fetchStatisticsSnapshotForTimeChunk(StageId stageId, long timeChunk)
{
if (TEST_STAGE.equals(stageId)) {
return ClusterByStatisticsSnapshot.empty();
} else {
throw new ISE("stage not found %s", stageId);
}
}
@Override
public boolean postResultPartitionBoundaries(
ClusterByPartitions stagePartitionBoundaries,
String queryId,
int stageNumber
)
{
return false;
}
@Nullable
@Override
public InputStream readChannel(String queryId, int stageNumber, int partitionNumber, long offset)
{
return null;
}
@Override
public CounterSnapshotsTree getCounters()
{
return null;
}
@Override
public void postCleanupStage(StageId stageId)
{
}
@Override
public void postFinish()
{
}
}
@After
public void tearDown()
{
try {
mocks.close();
}
catch (Exception ignored) {
// ignore tear down exceptions
}
}
}