fetchClusterByStatisticsSnapshot(
+ String workerTaskId,
+ String queryId,
+ int stageNumber
+ )
{
- return client.fetchClusterByStatisticsSnapshot(workerTaskId, queryId, stageNumber);
+ return wrap(workerTaskId, client, c -> c.fetchClusterByStatisticsSnapshot(workerTaskId, queryId, stageNumber));
}
@Override
@@ -70,7 +74,11 @@ public class ExceptionWrappingWorkerClient implements WorkerClient
long timeChunk
)
{
- return client.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, queryId, stageNumber, timeChunk);
+ return wrap(
+ workerTaskId,
+ client,
+ c -> c.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, queryId, stageNumber, timeChunk)
+ );
}
@Override
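
For context, the wrap(...) helper these call sites delegate to has roughly the following shape. This is a sketch under assumed names (the Function-typed parameter and the message format are illustrative, not copied from ExceptionWrappingWorkerClient); the real helper also maps asynchronous failures of the returned future to a worker RPC fault.

import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import java.util.function.Function;

// Sketch: run the client call and attribute any synchronous failure to the
// worker task being contacted, instead of letting the raw exception propagate.
private static <T> ListenableFuture<T> wrap(
    final String workerTaskId,
    final WorkerClient client,
    final Function<WorkerClient, ListenableFuture<T>> call
)
{
  try {
    return call.apply(client);
  }
  catch (Exception e) {
    return Futures.immediateFailedFuture(
        new RuntimeException("Call to worker[" + workerTaskId + "] failed", e)
    );
  }
}
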
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java
index 49d6f9080d7..8c5a782f536 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java
@@ -571,16 +571,37 @@ public class WorkerImpl implements Worker
@Override
public ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId)
{
- return stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();
+ final WorkerStageKernel stageKernel = stageKernelMap.get(stageId);
+ if (stageKernel == null) {
+ throw new ISE("Requested statistics snapshot for non-existent stageId[%s]", stageId);
+ }
+ final ClusterByStatisticsSnapshot snapshot = stageKernel.getResultKeyStatisticsSnapshot();
+ if (snapshot == null) {
+ throw new ISE("Requested statistics snapshot is not generated yet for stageId[%s]", stageId);
+ }
+ return snapshot;
}
@Override
public ClusterByStatisticsSnapshot fetchStatisticsSnapshotForTimeChunk(StageId stageId, long timeChunk)
{
- ClusterByStatisticsSnapshot snapshot = stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();
- return snapshot.getSnapshotForTimeChunk(timeChunk);
+ final WorkerStageKernel stageKernel = stageKernelMap.get(stageId);
+ if (stageKernel == null) {
+ throw new ISE("Requested statistics snapshot for non-existent stageId[%s]", stageId);
+ }
+ final ClusterByStatisticsSnapshot snapshot = stageKernel.getResultKeyStatisticsSnapshot();
+ if (snapshot == null) {
+ throw new ISE("Requested statistics snapshot is not generated yet for stageId[%s]", stageId);
+ }
+ return snapshot.getSnapshotForTimeChunk(timeChunk);
}
+
@Override
public CounterSnapshotsTree getCounters()
{
@@ -643,7 +664,7 @@ public class WorkerImpl implements Worker
/**
* Decorates the server-wide {@link QueryProcessingPool} such that any Callables and Runnables, not just
* {@link PrioritizedCallable} and {@link PrioritizedRunnable}, may be added to it.
- *
+ *
* In production, the underlying {@link QueryProcessingPool} pool is set up by
* {@link org.apache.druid.guice.DruidProcessingModule}.
*/
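
For context, the decoration described in this javadoc can be sketched as below, assuming only that PrioritizedCallable adds getPriority() on top of Callable (the PR's actual decorator inside WorkerImpl is not part of this hunk):

// Wrap an arbitrary Callable as a PrioritizedCallable with a default priority,
// so the underlying QueryProcessingPool accepts it alongside prioritized tasks.
static <T> PrioritizedCallable<T> withDefaultPriority(final Callable<T> task)
{
  return new PrioritizedCallable<T>()
  {
    @Override
    public int getPriority()
    {
      return 0; // default priority for MSQ work (assumed)
    }

    @Override
    public T call() throws Exception
    {
      return task.call();
    }
  };
}
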
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java
index dc6f2199058..2eba0c409d2 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java
@@ -20,6 +20,7 @@
package org.apache.druid.msq.exec;
import com.google.common.util.concurrent.ListenableFuture;
+import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
@@ -40,7 +41,7 @@ import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
-import java.util.stream.IntStream;
+import java.util.stream.Collectors;
/**
* Queues up fetching sketches from workers and progressively generates partitions boundaries.
@@ -78,7 +79,8 @@ public class WorkerSketchFetcher implements AutoCloseable
public CompletableFuture<Either<Long, ClusterByPartitions>> submitFetcherTask(
CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
List<String> workerTaskIds,
- StageDefinition stageDefinition
+ StageDefinition stageDefinition,
+ IntSet workersForStage
)
{
ClusterBy clusterBy = stageDefinition.getClusterBy();
@@ -87,18 +89,31 @@ public class WorkerSketchFetcher implements AutoCloseable
case SEQUENTIAL:
return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
case PARALLEL:
- return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+ return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
case AUTO:
if (clusterBy.getBucketByCount() == 0) {
- log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
+ log.info(
+ "Query[%s] stage[%d] for AUTO mode: chose PARALLEL mode to merge key statistics",
+ stageDefinition.getId().getQueryId(),
+ stageDefinition.getStageNumber()
+ );
// If there is no time clustering, there is no scope for sequential merge
- return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
- } else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD || completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
- log.info("Query [%s] AUTO mode: chose SEQUENTIAL mode to merge key statistics", stageDefinition.getId().getQueryId());
+ return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
+ } else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD
+ || completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
+ log.info(
+ "Query[%s] stage[%d] for AUTO mode: chose SEQUENTIAL mode to merge key statistics",
+ stageDefinition.getId().getQueryId(),
+ stageDefinition.getStageNumber()
+ );
return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
}
- log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
- return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+ log.info(
+ "Query[%s] stage[%d] for AUTO mode: chose PARALLEL mode to merge key statistics",
+ stageDefinition.getId().getQueryId(),
+ stageDefinition.getStageNumber()
+ );
+ return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
default:
throw new IllegalStateException("No fetching strategy found for mode: " + clusterStatisticsMergeMode);
}
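
The AUTO branch above reduces to a simple rule; a compact restatement (a sketch, with WORKER_THRESHOLD and BYTES_THRESHOLD being the constants defined on WorkerSketchFetcher):

// SEQUENTIAL merging only has scope when rows are clustered by time, and is
// only worthwhile when the stage is large by worker count or retained bytes.
static boolean shouldUseSequentialMerge(
    final int bucketByCount,   // clusterBy.getBucketByCount()
    final int maxWorkerCount,  // stageDefinition.getMaxWorkerCount()
    final long bytesRetained,  // completeKeyStatisticsInformation.getBytesRetained()
    final long workerThreshold,
    final long bytesThreshold
)
{
  if (bucketByCount == 0) {
    return false; // no time clustering, so no scope for sequential merge
  }
  return maxWorkerCount > workerThreshold || bytesRetained > bytesThreshold;
}
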
@@ -111,7 +126,8 @@ public class WorkerSketchFetcher implements AutoCloseable
*/
CompletableFuture<Either<Long, ClusterByPartitions>> inMemoryFullSketchMerging(
StageDefinition stageDefinition,
- List<String> workerTaskIds
+ List<String> workerTaskIds,
+ IntSet workersForStage
)
{
CompletableFuture<Either<Long, ClusterByPartitions>> partitionFuture = new CompletableFuture<>();
@@ -119,12 +135,19 @@ public class WorkerSketchFetcher implements AutoCloseable
// Create a new key statistics collector to merge worker sketches into
final ClusterByStatisticsCollector mergedStatisticsCollector =
stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes);
- final int workerCount = workerTaskIds.size();
+ final int workerCount = workersForStage.size();
// Guarded by synchronized mergedStatisticsCollector
final Set<Integer> finishedWorkers = new HashSet<>();
+ log.info(
+ "Fetching stats using %s for stage[%d] for workers[%s] ",
+ ClusterStatisticsMergeMode.PARALLEL,
+ stageDefinition.getStageNumber(),
+ workersForStage.stream().map(Object::toString).collect(Collectors.joining(","))
+ );
+
// Submit a task for each worker to fetch statistics
- IntStream.range(0, workerCount).forEach(workerNo -> {
+ workersForStage.forEach(workerNo -> {
executorService.submit(() -> {
ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture =
workerClient.fetchClusterByStatisticsSnapshot(
@@ -177,6 +200,13 @@ public class WorkerSketchFetcher implements AutoCloseable
workerTaskIds,
completeKeyStatisticsInformation.getTimeSegmentVsWorkerMap().entrySet().iterator()
);
+
+ log.info(
+ "Fetching stats using %s for stage[%d] for tasks[%s]",
+ ClusterStatisticsMergeMode.SEQUENTIAL,
+ stageDefinition.getStageNumber(),
+ String.join("", workerTaskIds)
+ );
sequentialFetchStage.submitFetchingTasksForNextTimeChunk();
return sequentialFetchStage.getPartitionFuture();
}
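
For a sense of the input driving this sequential path: timeSegmentVsWorkerMap maps each time chunk to the worker numbers holding data for it, so each fetch round contacts only those workers. Its shape, as constructed in WorkerSketchFetcherTest later in this patch:

SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap = ImmutableSortedMap.of(
    1L, ImmutableSet.of(0, 1, 2), // time chunk 1 has data on workers 0, 1, 2
    2L, ImmutableSet.of(0, 1, 4)  // time chunk 2 has data on workers 0, 1, 4
);
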
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/WorkerChatHandler.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/WorkerChatHandler.java
index dd6ea7cb712..3eae3b05ccf 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/WorkerChatHandler.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/WorkerChatHandler.java
@@ -19,11 +19,13 @@
package org.apache.druid.msq.indexing;
+import com.google.common.collect.ImmutableMap;
import it.unimi.dsi.fastutil.bytes.ByteArrays;
import org.apache.commons.lang.mutable.MutableLong;
import org.apache.druid.frame.file.FrameFileHttpResponseHandler;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.indexing.common.TaskToolbox;
+import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.exec.Worker;
import org.apache.druid.msq.kernel.StageId;
@@ -71,7 +73,7 @@ public class WorkerChatHandler implements ChatHandler
/**
* Returns up to {@link #CHANNEL_DATA_CHUNK_SIZE} bytes of stage output data.
- *
+ *
* See {@link org.apache.druid.msq.exec.WorkerClient#fetchChannelData} for the client-side code that calls this API.
*/
@GET
@@ -193,17 +195,30 @@ public class WorkerChatHandler implements ChatHandler
ChatHandlers.authorizationCheck(req, Action.READ, task.getDataSource(), toolbox.getAuthorizerMapper());
ClusterByStatisticsSnapshot clusterByStatisticsSnapshot;
StageId stageId = new StageId(queryId, stageNumber);
- clusterByStatisticsSnapshot = worker.fetchStatisticsSnapshot(stageId);
- return Response.status(Response.Status.ACCEPTED)
- .entity(clusterByStatisticsSnapshot)
- .build();
+ try {
+ clusterByStatisticsSnapshot = worker.fetchStatisticsSnapshot(stageId);
+ return Response.status(Response.Status.ACCEPTED)
+ .entity(clusterByStatisticsSnapshot)
+ .build();
+ }
+ catch (Exception e) {
+ String errorMessage = StringUtils.format(
+ "Invalid request for key statistics for query[%s] and stage[%d]",
+ queryId,
+ stageNumber
+ );
+ log.error(e, errorMessage);
+ return Response.status(Response.Status.BAD_REQUEST)
+ .entity(ImmutableMap.of("error", errorMessage))
+ .build();
+ }
}
@POST
@Path("/keyStatisticsForTimeChunk/{queryId}/{stageNumber}/{timeChunk}")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
- public Response httpSketch(
+ public Response httpFetchKeyStatisticsWithSnapshot(
@PathParam("queryId") final String queryId,
@PathParam("stageNumber") final int stageNumber,
@PathParam("timeChunk") final long timeChunk,
@@ -213,10 +228,24 @@ public class WorkerChatHandler implements ChatHandler
ChatHandlers.authorizationCheck(req, Action.READ, task.getDataSource(), toolbox.getAuthorizerMapper());
ClusterByStatisticsSnapshot snapshotForTimeChunk;
StageId stageId = new StageId(queryId, stageNumber);
- snapshotForTimeChunk = worker.fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk);
- return Response.status(Response.Status.ACCEPTED)
- .entity(snapshotForTimeChunk)
- .build();
+ try {
+ snapshotForTimeChunk = worker.fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk);
+ return Response.status(Response.Status.ACCEPTED)
+ .entity(snapshotForTimeChunk)
+ .build();
+ }
+ catch (Exception e) {
+ String errorMessage = StringUtils.format(
+ "Invalid request for key statistics for query[%s], stage[%d] and timeChunk[%d]",
+ queryId,
+ stageNumber,
+ timeChunk
+ );
+ log.error(e, errorMessage);
+ return Response.status(Response.Status.BAD_REQUEST)
+ .entity(ImmutableMap.of("error", errorMessage))
+ .build();
+ }
}
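
Net behavior of both endpoints after this change, summarized (illustrative; matches the tests added below):

// valid stageId (and timeChunk) -> 202 ACCEPTED, entity = ClusterByStatisticsSnapshot
// unknown stageId or timeChunk  -> 400 BAD_REQUEST, entity = {"error": "Invalid request ..."}
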
/**
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java
index b0ed8e5c19d..00a49656be4 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java
@@ -48,8 +48,9 @@ public class WorkerStageKernel
private WorkerStagePhase phase = WorkerStagePhase.NEW;
+ // Read by both the main thread and the Netty threads, hence volatile.
@Nullable
- private ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot;
+ private volatile ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot;
@Nullable
private ClusterByPartitions resultPartitionBoundaries;
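
Making this field volatile guarantees that a snapshot published by the worker's processing thread is visible to the Netty threads serving the keyStatistics endpoints. A minimal illustration of the visibility pattern (hypothetical code, not part of this patch):

class SnapshotHolder
{
  // Without volatile, a reader thread could indefinitely observe a stale null.
  private volatile Object snapshot;

  void publish(Object s)
  {
    snapshot = s;    // writer: worker processing thread
  }

  Object read()
  {
    return snapshot; // readers: Netty HTTP threads
  }
}
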
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/ClusterByStatisticsSnapshot.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/ClusterByStatisticsSnapshot.java
index e54253ad218..16a4c1656b0 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/ClusterByStatisticsSnapshot.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/ClusterByStatisticsSnapshot.java
@@ -25,6 +25,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.frame.key.RowKey;
+import org.apache.druid.java.util.common.ISE;
import javax.annotation.Nullable;
import java.util.Collections;
@@ -61,6 +62,9 @@ public class ClusterByStatisticsSnapshot
public ClusterByStatisticsSnapshot getSnapshotForTimeChunk(long timeChunk)
{
Bucket bucket = buckets.get(timeChunk);
+ if (bucket == null) {
+ throw new ISE("ClusterByStatistics not present for requested timechunk %s", timeChunk);
+ }
return new ClusterByStatisticsSnapshot(ImmutableMap.of(timeChunk, bucket), null);
}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java
index 7c589f2326f..3dc622870f4 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java
@@ -60,7 +60,7 @@ public class MultiStageQueryContext
public static final String CTX_ENABLE_DURABLE_SHUFFLE_STORAGE = "durableShuffleStorage";
public static final String CTX_CLUSTER_STATISTICS_MERGE_MODE = "clusterStatisticsMergeMode";
- public static final String DEFAULT_CLUSTER_STATISTICS_MERGE_MODE = ClusterStatisticsMergeMode.AUTO.toString();
+ public static final String DEFAULT_CLUSTER_STATISTICS_MERGE_MODE = ClusterStatisticsMergeMode.PARALLEL.toString();
private static final boolean DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE = false;
public static final String CTX_DESTINATION = "destination";
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java
index f54d2fa880c..cf4e4052d35 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java
@@ -128,6 +128,32 @@ public class MSQInsertTest extends MSQTestBase
}
+ @Test
+ public void testInsertOnFoo1WithTimeFunctionWithSequential()
+ {
+ RowSignature rowSignature = RowSignature.builder()
+ .add("__time", ColumnType.LONG)
+ .add("dim1", ColumnType.STRING)
+ .add("cnt", ColumnType.LONG).build();
+ Map<String, Object> context = ImmutableMap.<String, Object>builder()
+ .putAll(DEFAULT_MSQ_CONTEXT)
+ .put(
+ MultiStageQueryContext.CTX_CLUSTER_STATISTICS_MERGE_MODE,
+ ClusterStatisticsMergeMode.SEQUENTIAL.toString()
+ )
+ .build();
+
+ testIngestQuery().setSql(
+ "insert into foo1 select floor(__time to day) as __time , dim1 , count(*) as cnt from foo where dim1 is not null group by 1, 2 PARTITIONED by day clustered by dim1")
+ .setQueryContext(context)
+ .setExpectedDataSource("foo1")
+ .setExpectedRowSignature(rowSignature)
+ .setExpectedSegment(expectedFooSegments())
+ .setExpectedResultRows(expectedFooRows())
+ .verifyResults();
+
+ }
+
@Test
public void testInsertOnFoo1WithMultiValueDim()
{
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerImplTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerImplTest.java
new file mode 100644
index 00000000000..52231a116b6
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerImplTest.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.exec;
+
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.msq.indexing.MSQWorkerTask;
+import org.apache.druid.msq.kernel.StageId;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.util.HashMap;
+
+@RunWith(MockitoJUnitRunner.class)
+public class WorkerImplTest
+{
+ @Mock
+ WorkerContext workerContext;
+
+ @Test
+ public void testFetchStatsThrows()
+ {
+ WorkerImpl worker = new WorkerImpl(new MSQWorkerTask("controller", "ds", 1, new HashMap<>()), workerContext);
+ Assert.assertThrows(ISE.class, () -> worker.fetchStatisticsSnapshot(new StageId("xx", 1)));
+ }
+
+ @Test
+ public void testFetchStatsWithTimeChunkThrows()
+ {
+ WorkerImpl worker = new WorkerImpl(new MSQWorkerTask("controller", "ds", 1, new HashMap<>()), workerContext);
+ Assert.assertThrows(ISE.class, () -> worker.fetchStatisticsSnapshotForTimeChunk(new StageId("xx", 1), 1L));
+ }
+
+}
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherAutoModeTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherAutoModeTest.java
index 42f6f0437f5..02be2876f98 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherAutoModeTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherAutoModeTest.java
@@ -19,6 +19,7 @@
package org.apache.druid.msq.exec;
+import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
@@ -56,7 +57,7 @@ public class WorkerSketchFetcherAutoModeTest
target = spy(new WorkerSketchFetcher(mock(WorkerClient.class), ClusterStatisticsMergeMode.AUTO, 300_000_000));
// Don't actually try to fetch sketches
- doReturn(null).when(target).inMemoryFullSketchMerging(any(), any());
+ doReturn(null).when(target).inMemoryFullSketchMerging(any(), any(), any());
doReturn(null).when(target).sequentialTimeChunkMerging(any(), any(), any());
doReturn(StageId.fromString("1_1")).when(stageDefinition).getId();
@@ -81,8 +82,13 @@ public class WorkerSketchFetcherAutoModeTest
// Worker count below threshold
doReturn(1).when(stageDefinition).getMaxWorkerCount();
- target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
- verify(target, times(1)).inMemoryFullSketchMerging(any(), any());
+ target.submitFetcherTask(
+ completeKeyStatisticsInformation,
+ Collections.emptyList(),
+ stageDefinition,
+ IntSet.of()
+ );
+ verify(target, times(1)).inMemoryFullSketchMerging(any(), any(), any());
verify(target, times(0)).sequentialTimeChunkMerging(any(), any(), any());
}
@@ -98,8 +104,13 @@ public class WorkerSketchFetcherAutoModeTest
// Worker count above threshold
doReturn((int) WorkerSketchFetcher.WORKER_THRESHOLD + 1).when(stageDefinition).getMaxWorkerCount();
- target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
- verify(target, times(0)).inMemoryFullSketchMerging(any(), any());
+ target.submitFetcherTask(
+ completeKeyStatisticsInformation,
+ Collections.emptyList(),
+ stageDefinition,
+ IntSet.of()
+ );
+ verify(target, times(0)).inMemoryFullSketchMerging(any(), any(), any());
verify(target, times(1)).sequentialTimeChunkMerging(any(), any(), any());
}
@@ -115,8 +126,13 @@ public class WorkerSketchFetcherAutoModeTest
// Worker count above threshold
doReturn((int) WorkerSketchFetcher.WORKER_THRESHOLD + 1).when(stageDefinition).getMaxWorkerCount();
- target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
- verify(target, times(1)).inMemoryFullSketchMerging(any(), any());
+ target.submitFetcherTask(
+ completeKeyStatisticsInformation,
+ Collections.emptyList(),
+ stageDefinition,
+ IntSet.of()
+ );
+ verify(target, times(1)).inMemoryFullSketchMerging(any(), any(), any());
verify(target, times(0)).sequentialTimeChunkMerging(any(), any(), any());
}
@@ -132,8 +148,13 @@ public class WorkerSketchFetcherAutoModeTest
// Worker count below threshold
doReturn(1).when(stageDefinition).getMaxWorkerCount();
- target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
- verify(target, times(0)).inMemoryFullSketchMerging(any(), any());
+ target.submitFetcherTask(
+ completeKeyStatisticsInformation,
+ Collections.emptyList(),
+ stageDefinition,
+ IntSet.of()
+ );
+ verify(target, times(0)).inMemoryFullSketchMerging(any(), any(), any());
verify(target, times(1)).sequentialTimeChunkMerging(any(), any(), any());
}
}
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java
index 83fb73043bd..fc244900361 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java
@@ -23,6 +23,8 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.util.concurrent.Futures;
+import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
+import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
@@ -88,11 +90,19 @@ public class WorkerSketchFetcherTest
doReturn(clusterBy).when(stageDefinition).getClusterBy();
doReturn(25_000).when(stageDefinition).getMaxPartitionCount();
- expectedPartitions1 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(mock(RowKey.class), mock(RowKey.class))));
- expectedPartitions2 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(mock(RowKey.class), mock(RowKey.class))));
+ expectedPartitions1 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(
+ mock(RowKey.class),
+ mock(RowKey.class)
+ )));
+ expectedPartitions2 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(
+ mock(RowKey.class),
+ mock(RowKey.class)
+ )));
- doReturn(Either.value(expectedPartitions1)).when(stageDefinition).generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector1));
- doReturn(Either.value(expectedPartitions2)).when(stageDefinition).generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector2));
+ doReturn(Either.value(expectedPartitions1)).when(stageDefinition)
+ .generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector1));
+ doReturn(Either.value(expectedPartitions2)).when(stageDefinition)
+ .generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector2));
doReturn(
mergedClusterByStatisticsCollector1,
@@ -128,10 +138,14 @@ public class WorkerSketchFetcherTest
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt());
+ IntSet workersForStage = new IntAVLTreeSet();
+ workersForStage.addAll(ImmutableSet.of(0, 1, 2, 3, 4));
+
CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
completeKeyStatisticsInformation,
workerIds,
- stageDefinition
+ stageDefinition,
+ workersForStage
);
// Assert that the final result is complete and all other sketches returned have been merged.
@@ -154,7 +168,12 @@ public class WorkerSketchFetcherTest
// Store snapshots in a queue
final Queue<ClusterByStatisticsSnapshot> snapshotQueue = new ConcurrentLinkedQueue<>();
- SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap = ImmutableSortedMap.of(1L, ImmutableSet.of(0, 1, 2), 2L, ImmutableSet.of(0, 1, 4));
+ SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap = ImmutableSortedMap.of(
+ 1L,
+ ImmutableSet.of(0, 1, 2),
+ 2L,
+ ImmutableSet.of(0, 1, 4)
+ );
doReturn(timeSegmentVsWorkerMap).when(completeKeyStatisticsInformation).getTimeSegmentVsWorkerMap();
final CyclicBarrier barrier = new CyclicBarrier(3);
@@ -168,10 +187,14 @@ public class WorkerSketchFetcherTest
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyInt(), anyLong());
+ IntSet workersForStage = new IntAVLTreeSet();
+ workersForStage.addAll(ImmutableSet.of(0, 1, 2, 3, 4));
+
CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
completeKeyStatisticsInformation,
ImmutableList.of("0", "1", "2", "3", "4"),
- stageDefinition
+ stageDefinition,
+ workersForStage
);
// Assert that the final result is complete and all other sketches returned have been merged.
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java
new file mode 100644
index 00000000000..5b9d6e497aa
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.indexing;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.indexer.TaskStatus;
+import org.apache.druid.indexing.common.TaskReport;
+import org.apache.druid.indexing.common.TaskReportFileWriter;
+import org.apache.druid.indexing.common.TaskToolbox;
+import org.apache.druid.jackson.DefaultObjectMapper;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.msq.counters.CounterSnapshotsTree;
+import org.apache.druid.msq.exec.Worker;
+import org.apache.druid.msq.kernel.StageId;
+import org.apache.druid.msq.kernel.WorkOrder;
+import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
+import org.apache.druid.segment.IndexIO;
+import org.apache.druid.segment.IndexMergerV9;
+import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
+import org.apache.druid.server.security.AuthConfig;
+import org.apache.druid.server.security.AuthenticationResult;
+import org.apache.druid.sql.calcite.util.CalciteTests;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+
+import javax.annotation.Nullable;
+import javax.servlet.http.HttpServletRequest;
+import javax.ws.rs.core.Response;
+import java.io.InputStream;
+import java.util.HashMap;
+import java.util.Map;
+
+public class WorkerChatHandlerTest
+{
+ private static final StageId TEST_STAGE = new StageId("123", 0);
+ @Mock
+ private HttpServletRequest req;
+
+ private TaskToolbox toolbox;
+ private AutoCloseable mocks;
+
+ private final TestWorker worker = new TestWorker();
+
+ @Before
+ public void setUp()
+ {
+ ObjectMapper mapper = new DefaultObjectMapper();
+ IndexIO indexIO = new IndexIO(mapper, () -> 0);
+ IndexMergerV9 indexMerger = new IndexMergerV9(
+ mapper,
+ indexIO,
+ OffHeapMemorySegmentWriteOutMediumFactory.instance()
+ );
+
+ mocks = MockitoAnnotations.openMocks(this);
+ Mockito.when(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT))
+ .thenReturn(new AuthenticationResult("druid", "druid", null, null));
+ TaskToolbox.Builder builder = new TaskToolbox.Builder();
+ toolbox = builder.authorizerMapper(CalciteTests.TEST_AUTHORIZER_MAPPER)
+ .indexIO(indexIO)
+ .indexMergerV9(indexMerger)
+ .taskReportFileWriter(
+ new TaskReportFileWriter()
+ {
+ @Override
+ public void write(String taskId, Map<String, TaskReport> reports)
+ {
+
+ }
+
+ @Override
+ public void setObjectMapper(ObjectMapper objectMapper)
+ {
+
+ }
+ }
+ )
+ .build();
+ }
+
+ @Test
+ public void testFetchSnapshot()
+ {
+ WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
+ Assert.assertEquals(
+ ClusterByStatisticsSnapshot.empty(),
+ chatHandler.httpFetchKeyStatistics(TEST_STAGE.getQueryId(), TEST_STAGE.getStageNumber(), req)
+ .getEntity()
+ );
+ }
+
+ @Test
+ public void testFetchSnapshot400()
+ {
+ WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
+ Assert.assertEquals(
+ Response.Status.BAD_REQUEST.getStatusCode(),
+ chatHandler.httpFetchKeyStatistics("123", 2, req)
+ .getStatus()
+ );
+ }
+
+ @Test
+ public void testFetchSnapshotWithTimeChunk()
+ {
+ WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
+ Assert.assertEquals(
+ ClusterByStatisticsSnapshot.empty(),
+ chatHandler.httpFetchKeyStatisticsWithSnapshot(TEST_STAGE.getQueryId(), TEST_STAGE.getStageNumber(), 1, req)
+ .getEntity()
+ );
+ }
+
+ @Test
+ public void testFetchSnapshotWithTimeChunk400()
+ {
+ WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
+ Assert.assertEquals(
+ Response.Status.BAD_REQUEST.getStatusCode(),
+ chatHandler.httpFetchKeyStatisticsWithSnapshot("123", 2, 1, req)
+ .getStatus()
+ );
+ }
+
+
+ private static class TestWorker implements Worker
+ {
+
+ @Override
+ public String id()
+ {
+ return TEST_STAGE.getQueryId() + "task";
+ }
+
+ @Override
+ public MSQWorkerTask task()
+ {
+ return new MSQWorkerTask("controller", "ds", 1, new HashMap<>());
+ }
+
+ @Override
+ public TaskStatus run()
+ {
+ return null;
+ }
+
+ @Override
+ public void stopGracefully()
+ {
+
+ }
+
+ @Override
+ public void controllerFailed()
+ {
+
+ }
+
+ @Override
+ public void postWorkOrder(WorkOrder workOrder)
+ {
+
+ }
+
+ @Override
+ public ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId)
+ {
+ if (TEST_STAGE.equals(stageId)) {
+ return ClusterByStatisticsSnapshot.empty();
+ } else {
+ throw new ISE("stage not found %s", stageId);
+ }
+ }
+
+ @Override
+ public ClusterByStatisticsSnapshot fetchStatisticsSnapshotForTimeChunk(StageId stageId, long timeChunk)
+ {
+ if (TEST_STAGE.equals(stageId)) {
+ return ClusterByStatisticsSnapshot.empty();
+ } else {
+ throw new ISE("stage not found %s", stageId);
+ }
+ }
+
+ @Override
+ public boolean postResultPartitionBoundaries(
+ ClusterByPartitions stagePartitionBoundaries,
+ String queryId,
+ int stageNumber
+ )
+ {
+ return false;
+ }
+
+ @Nullable
+ @Override
+ public InputStream readChannel(String queryId, int stageNumber, int partitionNumber, long offset)
+ {
+ return null;
+ }
+
+ @Override
+ public CounterSnapshotsTree getCounters()
+ {
+ return null;
+ }
+
+ @Override
+ public void postCleanupStage(StageId stageId)
+ {
+
+ }
+
+ @Override
+ public void postFinish()
+ {
+
+ }
+ }
+
+ @After
+ public void tearDown()
+ {
+ try {
+ mocks.close();
+ }
+ catch (Exception ignored) {
+ // ignore tear down exceptions
+ }
+ }
+}