From ccb7c2edd959a26d7f96528a2779b80051815ac1 Mon Sep 17 00:00:00 2001 From: Karan Kumar Date: Tue, 8 Oct 2024 19:46:40 +0530 Subject: [PATCH] [Backport] Dart and security backports (#17249) (#17278) (#17281) (#17282) (#17283) (#17277) (#17285) * MSQ: Allow for worker gaps. (#17277) * DartSqlResource: Sort queries by start time. (#17282) * DartSqlResource: Add controllerHost to GetQueriesResponse. (#17283) * DartWorkerModule: Replace en dash with regular dash. (#17281) * DartSqlResource: Return HTTP 202 on cancellation even if no such query. (#17278) * Upgraded Protobuf to 3.25.5 (#17249) --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> (cherry picked from commit 7d9e6d36fddd7893825d1fa2f5da2e20f67c5de8) --------- Co-authored-by: Gian Merlino Co-authored-by: Shivam Garg --- .../msq/dart/controller/ControllerHolder.java | 8 + .../dart/controller/http/DartQueryInfo.java | 23 ++- .../dart/controller/http/DartSqlResource.java | 15 +- .../dart/controller/sql/DartQueryMaker.java | 1 + .../msq/dart/guice/DartWorkerModule.java | 2 +- .../msq/input/stage/ReadablePartition.java | 12 ++ .../msq/input/stage/ReadablePartitions.java | 33 +++- .../SparseStripedReadablePartitions.java | 142 ++++++++++++++++++ .../controller/ControllerStageTracker.java | 26 ++-- .../msq/kernel/controller/WorkerInputs.java | 41 +++-- .../controller/http/DartSqlResourceTest.java | 18 ++- .../http/GetQueriesResponseTest.java | 1 + .../controller/sql/DartSqlClientImplTest.java | 2 + .../CollectedReadablePartitionsTest.java | 12 +- .../stage/CombinedReadablePartitionsTest.java | 2 +- .../SparseStripedReadablePartitionsTest.java | 98 ++++++++++++ .../stage/StripedReadablePartitionsTest.java | 34 ++++- .../kernel/controller/WorkerInputsTest.java | 98 +++++++++--- licenses.yaml | 4 +- pom.xml | 2 +- 20 files changed, 499 insertions(+), 75 deletions(-) create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitions.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitionsTest.java diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerHolder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerHolder.java index 9644444dad2..ca60ee3cbc1 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerHolder.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerHolder.java @@ -59,6 +59,7 @@ public class ControllerHolder private final ControllerContext controllerContext; private final String sqlQueryId; private final String sql; + private final String controllerHost; private final AuthenticationResult authenticationResult; private final DateTime startTime; private final AtomicReference state = new AtomicReference<>(State.ACCEPTED); @@ -68,6 +69,7 @@ public class ControllerHolder final ControllerContext controllerContext, final String sqlQueryId, final String sql, + final String controllerHost, final AuthenticationResult authenticationResult, final DateTime startTime ) @@ -76,6 +78,7 @@ public class ControllerHolder this.controllerContext = controllerContext; this.sqlQueryId = Preconditions.checkNotNull(sqlQueryId, "sqlQueryId"); this.sql = sql; + this.controllerHost = controllerHost; this.authenticationResult = authenticationResult; this.startTime = Preconditions.checkNotNull(startTime, "startTime"); } @@ -95,6 +98,11 @@ public class ControllerHolder return sql; } + public String getControllerHost() + { + return controllerHost; + } + public AuthenticationResult getAuthenticationResult() { return authenticationResult; diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartQueryInfo.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartQueryInfo.java index e5f3abb894e..2bc5d08704d 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartQueryInfo.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartQueryInfo.java @@ -26,6 +26,7 @@ import com.google.common.base.Preconditions; import org.apache.druid.msq.dart.controller.ControllerHolder; import org.apache.druid.msq.util.MSQTaskQueryMakerUtils; import org.apache.druid.query.QueryContexts; +import org.apache.druid.server.DruidNode; import org.joda.time.DateTime; import java.util.Objects; @@ -38,6 +39,7 @@ public class DartQueryInfo private final String sqlQueryId; private final String dartQueryId; private final String sql; + private final String controllerHost; private final String authenticator; private final String identity; private final DateTime startTime; @@ -48,6 +50,7 @@ public class DartQueryInfo @JsonProperty("sqlQueryId") final String sqlQueryId, @JsonProperty("dartQueryId") final String dartQueryId, @JsonProperty("sql") final String sql, + @JsonProperty("controllerHost") final String controllerHost, @JsonProperty("authenticator") final String authenticator, @JsonProperty("identity") final String identity, @JsonProperty("startTime") final DateTime startTime, @@ -57,6 +60,7 @@ public class DartQueryInfo this.sqlQueryId = Preconditions.checkNotNull(sqlQueryId, "sqlQueryId"); this.dartQueryId = Preconditions.checkNotNull(dartQueryId, "dartQueryId"); this.sql = sql; + this.controllerHost = controllerHost; this.authenticator = authenticator; this.identity = identity; this.startTime = startTime; @@ -69,6 +73,7 @@ public class DartQueryInfo holder.getSqlQueryId(), holder.getController().queryId(), MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(holder.getSql()), + holder.getControllerHost(), holder.getAuthenticationResult().getAuthenticatedBy(), holder.getAuthenticationResult().getIdentity(), holder.getStartTime(), @@ -104,6 +109,16 @@ public class DartQueryInfo return sql; } + /** + * Controller host:port, from {@link DruidNode#getHostAndPortToUse()}, that is executing this query. + */ + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_NULL) + public String getControllerHost() + { + return controllerHost; + } + /** * Authenticator that authenticated the identity from {@link #getIdentity()}. */ @@ -145,7 +160,7 @@ public class DartQueryInfo */ public DartQueryInfo withoutAuthenticationResult() { - return new DartQueryInfo(sqlQueryId, dartQueryId, sql, null, null, startTime, state); + return new DartQueryInfo(sqlQueryId, dartQueryId, sql, controllerHost, null, null, startTime, state); } @Override @@ -161,6 +176,7 @@ public class DartQueryInfo return Objects.equals(sqlQueryId, that.sqlQueryId) && Objects.equals(dartQueryId, that.dartQueryId) && Objects.equals(sql, that.sql) + && Objects.equals(controllerHost, that.controllerHost) && Objects.equals(authenticator, that.authenticator) && Objects.equals(identity, that.identity) && Objects.equals(startTime, that.startTime) @@ -170,7 +186,7 @@ public class DartQueryInfo @Override public int hashCode() { - return Objects.hash(sqlQueryId, dartQueryId, sql, authenticator, identity, startTime, state); + return Objects.hash(sqlQueryId, dartQueryId, sql, controllerHost, authenticator, identity, startTime, state); } @Override @@ -180,10 +196,11 @@ public class DartQueryInfo "sqlQueryId='" + sqlQueryId + '\'' + ", dartQueryId='" + dartQueryId + '\'' + ", sql='" + sql + '\'' + + ", controllerHost='" + controllerHost + '\'' + ", authenticator='" + authenticator + '\'' + ", identity='" + identity + '\'' + ", startTime=" + startTime + - ", state=" + state + + ", state='" + state + '\'' + '}'; } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartSqlResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartSqlResource.java index 37e9f105131..65d770a29c5 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartSqlResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartSqlResource.java @@ -154,7 +154,6 @@ public class DartSqlResource extends SqlResource controllerRegistry.getAllHolders() .stream() .map(DartQueryInfo::fromControllerHolder) - .sorted(Comparator.comparing(DartQueryInfo::getStartTime)) .collect(Collectors.toList()); // Add queries from all other servers, if "selfOnly" is not set. @@ -172,6 +171,9 @@ public class DartSqlResource extends SqlResource } } + // Sort queries by start time, breaking ties by query ID, so the list comes back in a consistent and nice order. + queries.sort(Comparator.comparing(DartQueryInfo::getStartTime).thenComparing(DartQueryInfo::getDartQueryId)); + final GetQueriesResponse response; if (stateReadAccess.isAllowed()) { // User can READ STATE, so they can see all running queries, as well as authentication details. @@ -237,7 +239,10 @@ public class DartSqlResource extends SqlResource List cancelables = sqlLifecycleManager.getAll(sqlQueryId); if (cancelables.isEmpty()) { - return Response.status(Response.Status.NOT_FOUND).build(); + // Return ACCEPTED even if the query wasn't found. When the Router broadcasts cancellation requests to all + // Brokers, this ensures the user sees a successful request. + AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(req); + return Response.status(Response.Status.ACCEPTED).build(); } final Access access = authorizeCancellation(req, cancelables); @@ -247,14 +252,12 @@ public class DartSqlResource extends SqlResource // Don't call cancel() on the cancelables. That just cancels native queries, which is useless here. Instead, // get the controller and stop it. - boolean found = false; for (SqlLifecycleManager.Cancelable cancelable : cancelables) { final HttpStatement stmt = (HttpStatement) cancelable; final Object dartQueryId = stmt.context().get(DartSqlEngine.CTX_DART_QUERY_ID); if (dartQueryId instanceof String) { final ControllerHolder holder = controllerRegistry.get((String) dartQueryId); if (holder != null) { - found = true; holder.cancel(); } } else { @@ -267,7 +270,9 @@ public class DartSqlResource extends SqlResource } } - return Response.status(found ? Response.Status.ACCEPTED : Response.Status.NOT_FOUND).build(); + // Return ACCEPTED even if the query wasn't found. When the Router broadcasts cancellation requests to all + // Brokers, this ensures the user sees a successful request. + return Response.status(Response.Status.ACCEPTED).build(); } else { return Response.status(Response.Status.FORBIDDEN).build(); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartQueryMaker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartQueryMaker.java index 37ed936a117..66686f7640f 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartQueryMaker.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartQueryMaker.java @@ -154,6 +154,7 @@ public class DartQueryMaker implements QueryMaker controllerContext, plannerContext.getSqlQueryId(), plannerContext.getSql(), + controllerContext.selfNode().getHostAndPortToUse(), plannerContext.getAuthenticationResult(), DateTimes.nowUtc() ); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerModule.java index 15bc0e65299..e9bd59f53d8 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerModule.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerModule.java @@ -113,7 +113,7 @@ public class DartWorkerModule implements DruidModule final AuthorizerMapper authorizerMapper ) { - final ExecutorService exec = Execs.multiThreaded(memoryIntrospector.numTasksInJvm(), "dart–worker-%s"); + final ExecutorService exec = Execs.multiThreaded(memoryIntrospector.numTasksInJvm(), "dart-worker-%s"); final File baseTempDir = new File(processingConfig.getTmpDir(), StringUtils.format("dart_%s", selfNode.getPortToUse())); return new DartWorkerRunner( diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartition.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartition.java index 99098d1d4cb..5f366c60009 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartition.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartition.java @@ -59,6 +59,18 @@ public class ReadablePartition return new ReadablePartition(stageNumber, workerNumbers, partitionNumber); } + /** + * Returns an output partition that is striped across a set of {@code workerNumbers}. + */ + public static ReadablePartition striped( + final int stageNumber, + final IntSortedSet workerNumbers, + final int partitionNumber + ) + { + return new ReadablePartition(stageNumber, workerNumbers, partitionNumber); + } + /** * Returns an output partition that has been collected onto a single worker. */ diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartitions.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartitions.java index a71535fbcfc..dcf0042f68b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartitions.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartitions.java @@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo; import it.unimi.dsi.fastutil.ints.Int2IntAVLTreeMap; import it.unimi.dsi.fastutil.ints.Int2IntSortedMap; import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; +import it.unimi.dsi.fastutil.ints.IntSortedSet; import java.util.Collections; import java.util.List; @@ -39,6 +40,7 @@ import java.util.Map; @JsonSubTypes(value = { @JsonSubTypes.Type(name = "collected", value = CollectedReadablePartitions.class), @JsonSubTypes.Type(name = "striped", value = StripedReadablePartitions.class), + @JsonSubTypes.Type(name = "sparseStriped", value = SparseStripedReadablePartitions.class), @JsonSubTypes.Type(name = "combined", value = CombinedReadablePartitions.class) }) public interface ReadablePartitions extends Iterable @@ -59,7 +61,7 @@ public interface ReadablePartitions extends Iterable /** * Combines various sets of partitions into a single set. */ - static CombinedReadablePartitions combine(List readablePartitions) + static ReadablePartitions combine(List readablePartitions) { return new CombinedReadablePartitions(readablePartitions); } @@ -68,7 +70,7 @@ public interface ReadablePartitions extends Iterable * Returns a set of {@code numPartitions} partitions striped across {@code numWorkers} workers: each worker contains * a "stripe" of each partition. */ - static StripedReadablePartitions striped( + static ReadablePartitions striped( final int stageNumber, final int numWorkers, final int numPartitions @@ -82,11 +84,36 @@ public interface ReadablePartitions extends Iterable return new StripedReadablePartitions(stageNumber, numWorkers, partitionNumbers); } + /** + * Returns a set of {@code numPartitions} partitions striped across {@code workers}: each worker contains + * a "stripe" of each partition. + */ + static ReadablePartitions striped( + final int stageNumber, + final IntSortedSet workers, + final int numPartitions + ) + { + final IntAVLTreeSet partitionNumbers = new IntAVLTreeSet(); + for (int i = 0; i < numPartitions; i++) { + partitionNumbers.add(i); + } + + if (workers.lastInt() == workers.size() - 1) { + // Dense worker set. Use StripedReadablePartitions for compactness (send a single number rather than the + // entire worker set) and for backwards compatibility (older workers cannot understand + // SparseStripedReadablePartitions). + return new StripedReadablePartitions(stageNumber, workers.size(), partitionNumbers); + } else { + return new SparseStripedReadablePartitions(stageNumber, workers, partitionNumbers); + } + } + /** * Returns a set of partitions that have been collected onto specific workers: each partition is on exactly * one worker. */ - static CollectedReadablePartitions collected( + static ReadablePartitions collected( final int stageNumber, final Map partitionToWorkerMap ) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitions.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitions.java new file mode 100644 index 00000000000..e9a02a7d488 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitions.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.input.stage; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.Iterators; +import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; +import it.unimi.dsi.fastutil.ints.IntSortedSet; +import org.apache.druid.msq.input.SlicerUtils; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; +import java.util.Set; + +/** + * Set of partitions striped across a sparse set of {@code workers}. Each worker contains a "stripe" of each partition. + * + * @see StripedReadablePartitions dense version, where workers from [0..N) are all used. + */ +public class SparseStripedReadablePartitions implements ReadablePartitions +{ + private final int stageNumber; + private final IntSortedSet workers; + private final IntSortedSet partitionNumbers; + + /** + * Constructor. Most callers should use {@link ReadablePartitions#striped(int, int, int)} instead, which takes + * a partition count rather than a set of partition numbers. + */ + public SparseStripedReadablePartitions( + final int stageNumber, + final IntSortedSet workers, + final IntSortedSet partitionNumbers + ) + { + this.stageNumber = stageNumber; + this.workers = workers; + this.partitionNumbers = partitionNumbers; + } + + @JsonCreator + private SparseStripedReadablePartitions( + @JsonProperty("stageNumber") final int stageNumber, + @JsonProperty("workers") final Set workers, + @JsonProperty("partitionNumbers") final Set partitionNumbers + ) + { + this(stageNumber, new IntAVLTreeSet(workers), new IntAVLTreeSet(partitionNumbers)); + } + + @Override + public Iterator iterator() + { + return Iterators.transform( + partitionNumbers.iterator(), + partitionNumber -> ReadablePartition.striped(stageNumber, workers, partitionNumber) + ); + } + + @Override + public List split(final int maxNumSplits) + { + final List retVal = new ArrayList<>(); + + for (List entries : SlicerUtils.makeSlicesStatic(partitionNumbers.iterator(), maxNumSplits)) { + if (!entries.isEmpty()) { + retVal.add(new SparseStripedReadablePartitions(stageNumber, workers, new IntAVLTreeSet(entries))); + } + } + + return retVal; + } + + @JsonProperty + int getStageNumber() + { + return stageNumber; + } + + @JsonProperty + IntSortedSet getWorkers() + { + return workers; + } + + @JsonProperty + IntSortedSet getPartitionNumbers() + { + return partitionNumbers; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SparseStripedReadablePartitions that = (SparseStripedReadablePartitions) o; + return stageNumber == that.stageNumber + && Objects.equals(workers, that.workers) + && Objects.equals(partitionNumbers, that.partitionNumbers); + } + + @Override + public int hashCode() + { + return Objects.hash(stageNumber, workers, partitionNumbers); + } + + @Override + public String toString() + { + return "StripedReadablePartitions{" + + "stageNumber=" + stageNumber + + ", workers=" + workers + + ", partitionNumbers=" + partitionNumbers + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java index 338a35e0d24..533cb57b97f 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java @@ -403,7 +403,7 @@ class ControllerStageTracker throw new ISE("Stage does not gather result key statistics"); } - if (workerNumber < 0 || workerNumber >= workerCount) { + if (!workerInputs.workers().contains(workerNumber)) { throw new IAE("Invalid workerNumber [%s]", workerNumber); } @@ -522,7 +522,7 @@ class ControllerStageTracker throw new ISE("Stage does not gather result key statistics"); } - if (workerNumber < 0 || workerNumber >= workerCount) { + if (!workerInputs.workers().contains(workerNumber)) { throw new IAE("Invalid workerNumber [%s]", workerNumber); } @@ -656,7 +656,7 @@ class ControllerStageTracker throw new ISE("Stage does not gather result key statistics"); } - if (workerNumber < 0 || workerNumber >= workerCount) { + if (!workerInputs.workers().contains(workerNumber)) { throw new IAE("Invalid workerNumber [%s]", workerNumber); } @@ -763,7 +763,7 @@ class ControllerStageTracker this.resultPartitionBoundaries = clusterByPartitions; this.resultPartitions = ReadablePartitions.striped( stageDef.getStageNumber(), - workerCount, + workerInputs.workers(), clusterByPartitions.size() ); @@ -788,7 +788,7 @@ class ControllerStageTracker throw DruidException.defensive("Cannot setDoneReadingInput for stage[%s], it is not sorting", stageDef.getId()); } - if (workerNumber < 0 || workerNumber >= workerCount) { + if (!workerInputs.workers().contains(workerNumber)) { throw new IAE("Invalid workerNumber[%s] for stage[%s]", workerNumber, stageDef.getId()); } @@ -830,7 +830,7 @@ class ControllerStageTracker @SuppressWarnings("unchecked") boolean setResultsCompleteForWorker(final int workerNumber, final Object resultObject) { - if (workerNumber < 0 || workerNumber >= workerCount) { + if (!workerInputs.workers().contains(workerNumber)) { throw new IAE("Invalid workerNumber [%s]", workerNumber); } @@ -947,14 +947,18 @@ class ControllerStageTracker resultPartitionBoundaries = maybeResultPartitionBoundaries.valueOrThrow(); resultPartitions = ReadablePartitions.striped( stageNumber, - workerCount, + workerInputs.workers(), resultPartitionBoundaries.size() ); - } else if (shuffleSpec.kind() == ShuffleKind.MIX) { - resultPartitionBoundaries = ClusterByPartitions.oneUniversalPartition(); - resultPartitions = ReadablePartitions.striped(stageNumber, workerCount, shuffleSpec.partitionCount()); } else { - resultPartitions = ReadablePartitions.striped(stageNumber, workerCount, shuffleSpec.partitionCount()); + if (shuffleSpec.kind() == ShuffleKind.MIX) { + resultPartitionBoundaries = ClusterByPartitions.oneUniversalPartition(); + } + resultPartitions = ReadablePartitions.striped( + stageNumber, + workerInputs.workers(), + shuffleSpec.partitionCount() + ); } } else { // No reshuffling: retain partitioning from nonbroadcast inputs. diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/WorkerInputs.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/WorkerInputs.java index 83d7a602bc1..8dcaee9c213 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/WorkerInputs.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/WorkerInputs.java @@ -24,7 +24,9 @@ import com.google.common.collect.Iterables; import it.unimi.dsi.fastutil.ints.Int2IntMap; import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap; import it.unimi.dsi.fastutil.ints.Int2ObjectMap; -import it.unimi.dsi.fastutil.ints.IntSet; +import it.unimi.dsi.fastutil.ints.Int2ObjectSortedMap; +import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; +import it.unimi.dsi.fastutil.ints.IntSortedSet; import it.unimi.dsi.fastutil.objects.ObjectIterator; import org.apache.druid.msq.input.InputSlice; import org.apache.druid.msq.input.InputSpec; @@ -45,9 +47,9 @@ import java.util.stream.IntStream; public class WorkerInputs { // Worker number -> input number -> input slice. - private final Int2ObjectMap> assignmentsMap; + private final Int2ObjectSortedMap> assignmentsMap; - private WorkerInputs(final Int2ObjectMap> assignmentsMap) + private WorkerInputs(final Int2ObjectSortedMap> assignmentsMap) { this.assignmentsMap = assignmentsMap; } @@ -64,7 +66,7 @@ public class WorkerInputs ) { // Split each inputSpec and assign to workers. This list maps worker number -> input number -> input slice. - final Int2ObjectMap> assignmentsMap = new Int2ObjectAVLTreeMap<>(); + final Int2ObjectSortedMap> assignmentsMap = new Int2ObjectAVLTreeMap<>(); final int numInputs = stageDef.getInputSpecs().size(); if (numInputs == 0) { @@ -117,8 +119,8 @@ public class WorkerInputs final ObjectIterator>> assignmentsIterator = assignmentsMap.int2ObjectEntrySet().iterator(); + final IntSortedSet nilWorkers = new IntAVLTreeSet(); - boolean first = true; while (assignmentsIterator.hasNext()) { final Int2ObjectMap.Entry> entry = assignmentsIterator.next(); final List slices = entry.getValue(); @@ -130,20 +132,29 @@ public class WorkerInputs } } - // Eliminate workers that have no non-nil, non-broadcast inputs. (Except the first one, because if all input - // is nil, *some* worker has to do *something*.) - final boolean hasNonNilNonBroadcastInput = + // Identify nil workers (workers with no non-broadcast inputs). + final boolean isNilWorker = IntStream.range(0, numInputs) - .anyMatch(i -> - !slices.get(i).equals(NilInputSlice.INSTANCE) // Non-nil - && !stageDef.getBroadcastInputNumbers().contains(i) // Non-broadcast + .allMatch(i -> + slices.get(i).equals(NilInputSlice.INSTANCE) // Nil regular input + || stageDef.getBroadcastInputNumbers().contains(i) // Broadcast ); - if (!first && !hasNonNilNonBroadcastInput) { - assignmentsIterator.remove(); + if (isNilWorker) { + nilWorkers.add(entry.getIntKey()); } + } - first = false; + if (nilWorkers.size() == assignmentsMap.size()) { + // All workers have nil regular inputs. Remove all workers exept the first (*some* worker has to do *something*). + final List firstSlices = assignmentsMap.get(nilWorkers.firstInt()); + assignmentsMap.clear(); + assignmentsMap.put(nilWorkers.firstInt(), firstSlices); + } else { + // Remove all nil workers. + for (final int nilWorker : nilWorkers) { + assignmentsMap.remove(nilWorker); + } } return new WorkerInputs(assignmentsMap); @@ -154,7 +165,7 @@ public class WorkerInputs return Preconditions.checkNotNull(assignmentsMap.get(workerNumber), "worker [%s]", workerNumber); } - public IntSet workers() + public IntSortedSet workers() { return assignmentsMap.keySet(); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartSqlResourceTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartSqlResourceTest.java index db347917872..51e17235203 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartSqlResourceTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartSqlResourceTest.java @@ -331,9 +331,10 @@ public class DartSqlResourceTest extends MSQTestBase "sid", "did2", "SELECT 2", + "localhost:1002", AUTHENTICATOR_NAME, DIFFERENT_REGULAR_USER_NAME, - DateTimes.of("2000"), + DateTimes.of("2001"), ControllerHolder.State.RUNNING.toString() ); Mockito.when(dartSqlClient.getRunningQueries(true)) @@ -398,6 +399,7 @@ public class DartSqlResourceTest extends MSQTestBase "sid", "did2", "SELECT 2", + "localhost:1002", AUTHENTICATOR_NAME, DIFFERENT_REGULAR_USER_NAME, DateTimes.of("2000"), @@ -434,6 +436,7 @@ public class DartSqlResourceTest extends MSQTestBase "sid", "did2", "SELECT 2", + "localhost:1002", AUTHENTICATOR_NAME, DIFFERENT_REGULAR_USER_NAME, DateTimes.of("2000"), @@ -724,7 +727,7 @@ public class DartSqlResourceTest extends MSQTestBase .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); final Response cancellationResponse = sqlResource.cancelQuery("nonexistent", httpServletRequest); - Assertions.assertEquals(Response.Status.NOT_FOUND.getStatusCode(), cancellationResponse.getStatus()); + Assertions.assertEquals(Response.Status.ACCEPTED.getStatusCode(), cancellationResponse.getStatus()); } /** @@ -739,8 +742,15 @@ public class DartSqlResourceTest extends MSQTestBase Mockito.when(controller.queryId()).thenReturn("did_" + identity); final AuthenticationResult authenticationResult = makeAuthenticationResult(identity); - final ControllerHolder holder = - new ControllerHolder(controller, null, "sid", "SELECT 1", authenticationResult, DateTimes.of("2000")); + final ControllerHolder holder = new ControllerHolder( + controller, + null, + "sid", + "SELECT 1", + "localhost:1001", + authenticationResult, + DateTimes.of("2000") + ); controllerRegistry.register(holder); return holder; diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponseTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponseTest.java index 7b43c863c9d..bffaace5745 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponseTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponseTest.java @@ -41,6 +41,7 @@ public class GetQueriesResponseTest "xyz", "abc", "SELECT 1", + "localhost:1001", "auth", "anon", DateTimes.of("2000"), diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImplTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImplTest.java index 19a4eaf0b15..114ea9c7207 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImplTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImplTest.java @@ -69,6 +69,7 @@ public class DartSqlClientImplTest "sid", "did", "SELECT 1", + "localhost:1001", "", "", DateTimes.of("2000"), @@ -97,6 +98,7 @@ public class DartSqlClientImplTest "sid", "did", "SELECT 1", + "localhost:1001", "", "", DateTimes.of("2000"), diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CollectedReadablePartitionsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CollectedReadablePartitionsTest.java index 6ed7d2d43d4..d4db7a0a7c5 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CollectedReadablePartitionsTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CollectedReadablePartitionsTest.java @@ -33,21 +33,24 @@ public class CollectedReadablePartitionsTest @Test public void testPartitionToWorkerMap() { - final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); + final CollectedReadablePartitions partitions = + (CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); Assert.assertEquals(ImmutableMap.of(0, 1, 1, 2, 2, 1), partitions.getPartitionToWorkerMap()); } @Test public void testStageNumber() { - final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); + final CollectedReadablePartitions partitions = + (CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); Assert.assertEquals(1, partitions.getStageNumber()); } @Test public void testSplit() { - final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); + final CollectedReadablePartitions partitions = + (CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); Assert.assertEquals( ImmutableList.of( @@ -64,7 +67,8 @@ public class CollectedReadablePartitionsTest final ObjectMapper mapper = TestHelper.makeJsonMapper() .registerModules(new MSQIndexingModule().getJacksonModules()); - final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); + final CollectedReadablePartitions partitions = + (CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); Assert.assertEquals( partitions, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CombinedReadablePartitionsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CombinedReadablePartitionsTest.java index 685f4ff7a8a..16bd047b624 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CombinedReadablePartitionsTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CombinedReadablePartitionsTest.java @@ -31,7 +31,7 @@ import org.junit.Test; public class CombinedReadablePartitionsTest { - private static final CombinedReadablePartitions PARTITIONS = ReadablePartitions.combine( + private static final ReadablePartitions PARTITIONS = ReadablePartitions.combine( ImmutableList.of( ReadablePartitions.striped(0, 2, 2), ReadablePartitions.striped(1, 2, 4) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitionsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitionsTest.java new file mode 100644 index 00000000000..5268fd60180 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitionsTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.input.stage; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; +import it.unimi.dsi.fastutil.ints.IntSet; +import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.druid.msq.guice.MSQIndexingModule; +import org.apache.druid.segment.TestHelper; +import org.junit.Assert; +import org.junit.Test; + +public class SparseStripedReadablePartitionsTest +{ + @Test + public void testPartitionNumbers() + { + final SparseStripedReadablePartitions partitions = + (SparseStripedReadablePartitions) ReadablePartitions.striped(1, new IntAVLTreeSet(new int[]{1, 3}), 3); + Assert.assertEquals(ImmutableSet.of(0, 1, 2), partitions.getPartitionNumbers()); + } + + @Test + public void testWorkers() + { + final SparseStripedReadablePartitions partitions = + (SparseStripedReadablePartitions) ReadablePartitions.striped(1, new IntAVLTreeSet(new int[]{1, 3}), 3); + Assert.assertEquals(IntSet.of(1, 3), partitions.getWorkers()); + } + + @Test + public void testStageNumber() + { + final SparseStripedReadablePartitions partitions = + (SparseStripedReadablePartitions) ReadablePartitions.striped(1, new IntAVLTreeSet(new int[]{1, 3}), 3); + Assert.assertEquals(1, partitions.getStageNumber()); + } + + @Test + public void testSplit() + { + final IntAVLTreeSet workers = new IntAVLTreeSet(new int[]{1, 3}); + final SparseStripedReadablePartitions partitions = + (SparseStripedReadablePartitions) ReadablePartitions.striped(1, workers, 3); + + Assert.assertEquals( + ImmutableList.of( + new SparseStripedReadablePartitions(1, workers, new IntAVLTreeSet(new int[]{0, 2})), + new SparseStripedReadablePartitions(1, workers, new IntAVLTreeSet(new int[]{1})) + ), + partitions.split(2) + ); + } + + @Test + public void testSerde() throws Exception + { + final ObjectMapper mapper = TestHelper.makeJsonMapper() + .registerModules(new MSQIndexingModule().getJacksonModules()); + + final IntAVLTreeSet workers = new IntAVLTreeSet(new int[]{1, 3}); + final ReadablePartitions partitions = ReadablePartitions.striped(1, workers, 3); + + Assert.assertEquals( + partitions, + mapper.readValue( + mapper.writeValueAsString(partitions), + ReadablePartitions.class + ) + ); + } + + @Test + public void testEquals() + { + EqualsVerifier.forClass(SparseStripedReadablePartitions.class).usingGetClass().verify(); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StripedReadablePartitionsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StripedReadablePartitionsTest.java index 38e0707f5d0..05b42b33250 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StripedReadablePartitionsTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StripedReadablePartitionsTest.java @@ -26,36 +26,60 @@ import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; import nl.jqno.equalsverifier.EqualsVerifier; import org.apache.druid.msq.guice.MSQIndexingModule; import org.apache.druid.segment.TestHelper; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; import org.junit.Assert; import org.junit.Test; public class StripedReadablePartitionsTest { + @Test + public void testFromDenseSet() + { + // Tests that when ReadablePartitions.striped is called with a dense set, we get StripedReadablePartitions. + + final IntAVLTreeSet workers = new IntAVLTreeSet(); + workers.add(0); + workers.add(1); + + final ReadablePartitions readablePartitionsFromSet = ReadablePartitions.striped(1, workers, 3); + + MatcherAssert.assertThat( + readablePartitionsFromSet, + CoreMatchers.instanceOf(StripedReadablePartitions.class) + ); + + Assert.assertEquals( + ReadablePartitions.striped(1, 2, 3), + readablePartitionsFromSet + ); + } + @Test public void testPartitionNumbers() { - final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3); + final StripedReadablePartitions partitions = (StripedReadablePartitions) ReadablePartitions.striped(1, 2, 3); Assert.assertEquals(ImmutableSet.of(0, 1, 2), partitions.getPartitionNumbers()); } @Test public void testNumWorkers() { - final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3); + final StripedReadablePartitions partitions = (StripedReadablePartitions) ReadablePartitions.striped(1, 2, 3); Assert.assertEquals(2, partitions.getNumWorkers()); } @Test public void testStageNumber() { - final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3); + final StripedReadablePartitions partitions = (StripedReadablePartitions) ReadablePartitions.striped(1, 2, 3); Assert.assertEquals(1, partitions.getStageNumber()); } @Test public void testSplit() { - final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3); + final ReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3); Assert.assertEquals( ImmutableList.of( @@ -72,7 +96,7 @@ public class StripedReadablePartitionsTest final ObjectMapper mapper = TestHelper.makeJsonMapper() .registerModules(new MSQIndexingModule().getJacksonModules()); - final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3); + final ReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3); Assert.assertEquals( partitions, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java index 605e0bf2de7..e74125b0830 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java @@ -25,9 +25,11 @@ import it.unimi.dsi.fastutil.ints.Int2IntMaps; import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap; import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; import it.unimi.dsi.fastutil.ints.IntSet; +import it.unimi.dsi.fastutil.ints.IntSortedSet; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongList; import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.druid.error.DruidException; import org.apache.druid.msq.exec.Limits; import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.input.InputSlice; @@ -75,7 +77,7 @@ public class WorkerInputsTest final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(4), true), WorkerAssignmentStrategy.MAX, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -91,6 +93,35 @@ public class WorkerInputsTest ); } + @Test + public void test_max_threeInputs_fourWorkers_withGaps() + { + final StageDefinition stageDef = + StageDefinition.builder(0) + .inputs(new TestInputSpec(1, 2, 3)) + .maxWorkerCount(4) + .processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L)) + .build(QUERY_ID); + + final WorkerInputs inputs = WorkerInputs.create( + stageDef, + Int2IntMaps.EMPTY_MAP, + new TestInputSpecSlicer(new IntAVLTreeSet(new int[]{1, 3, 4, 5}), true), + WorkerAssignmentStrategy.MAX, + Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER + ); + + Assert.assertEquals( + ImmutableMap.>builder() + .put(1, Collections.singletonList(new TestInputSlice(1))) + .put(3, Collections.singletonList(new TestInputSlice(2))) + .put(4, Collections.singletonList(new TestInputSlice(3))) + .put(5, Collections.singletonList(new TestInputSlice())) + .build(), + inputs.assignmentsMap() + ); + } + @Test public void test_max_zeroInputs_fourWorkers() { @@ -104,7 +135,7 @@ public class WorkerInputsTest final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(4), true), WorkerAssignmentStrategy.MAX, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -133,7 +164,7 @@ public class WorkerInputsTest final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(4), true), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -159,7 +190,7 @@ public class WorkerInputsTest final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(4), true), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -186,7 +217,7 @@ public class WorkerInputsTest final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(4), true), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -212,7 +243,7 @@ public class WorkerInputsTest final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(4), true), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -324,7 +355,7 @@ public class WorkerInputsTest final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(4), true), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -351,7 +382,7 @@ public class WorkerInputsTest final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(2), true), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -384,7 +415,7 @@ public class WorkerInputsTest final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(1), true), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -411,7 +442,7 @@ public class WorkerInputsTest .processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L)) .build(QUERY_ID); - TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(true)); + TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(denseWorkers(3), true)); final WorkerInputs inputs = WorkerInputs.create( stageDef, @@ -455,7 +486,7 @@ public class WorkerInputsTest .processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L)) .build(QUERY_ID); - TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(true)); + TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(denseWorkers(3), true)); final WorkerInputs inputs = WorkerInputs.create( stageDef, @@ -498,7 +529,7 @@ public class WorkerInputsTest .processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L)) .build(QUERY_ID); - TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(true)); + TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(denseWorkers(3), true)); final WorkerInputs inputs = WorkerInputs.create( stageDef, @@ -585,11 +616,23 @@ public class WorkerInputsTest private static class TestInputSpecSlicer implements InputSpecSlicer { + private final IntSortedSet workers; private final boolean canSliceDynamic; - public TestInputSpecSlicer(boolean canSliceDynamic) + /** + * Create a test slicer. + * + * @param workers Set of workers to consider assigning work to. + * @param canSliceDynamic Whether this slicer can slice dynamically. + */ + public TestInputSpecSlicer(final IntSortedSet workers, final boolean canSliceDynamic) { + this.workers = workers; this.canSliceDynamic = canSliceDynamic; + + if (workers.isEmpty()) { + throw DruidException.defensive("Need more than one worker in workers[%s]", workers); + } } @Override @@ -606,9 +649,9 @@ public class WorkerInputsTest SlicerUtils.makeSlicesStatic( testInputSpec.values.iterator(), i -> i, - maxNumSlices + Math.min(maxNumSlices, workers.size()) ); - return makeSlices(assignments); + return makeSlices(workers, assignments); } @Override @@ -624,24 +667,39 @@ public class WorkerInputsTest SlicerUtils.makeSlicesDynamic( testInputSpec.values.iterator(), i -> i, - maxNumSlices, + Math.min(maxNumSlices, workers.size()), maxFilesPerSlice, maxBytesPerSlice ); - return makeSlices(assignments); + return makeSlices(workers, assignments); } private static List makeSlices( + final IntSortedSet workers, final List> assignments ) { final List retVal = new ArrayList<>(assignments.size()); - - for (final List assignment : assignments) { - retVal.add(new TestInputSlice(new LongArrayList(assignment))); + for (int assignment = 0, workerNumber = 0; + workerNumber <= workers.lastInt() && assignment < assignments.size(); + workerNumber++) { + if (workers.contains(workerNumber)) { + retVal.add(new TestInputSlice(new LongArrayList(assignments.get(assignment++)))); + } else { + retVal.add(NilInputSlice.INSTANCE); + } } return retVal; } } + + private static IntSortedSet denseWorkers(final int numWorkers) + { + final IntAVLTreeSet workers = new IntAVLTreeSet(); + for (int i = 0; i < numWorkers; i++) { + workers.add(i); + } + return workers; + } } diff --git a/licenses.yaml b/licenses.yaml index a04391ce902..12c2d031c37 100644 --- a/licenses.yaml +++ b/licenses.yaml @@ -3327,7 +3327,7 @@ name: Protocol Buffers license_category: binary module: java-core license_name: BSD-3-Clause License -version: 3.24.0 +version: 3.25.5 copyright: Google, Inc. license_file_path: - licenses/bin/protobuf-java.BSD3 @@ -3493,7 +3493,7 @@ name: Protocol Buffers license_category: binary module: extensions/druid-protobuf-extensions license_name: BSD-3-Clause License -version: 3.24.0 +version: 3.25.5 copyright: Google, Inc. license_file_path: licenses/bin/protobuf-java.BSD3 libraries: diff --git a/pom.xml b/pom.xml index b2a2dbf22de..3568d1a045b 100644 --- a/pom.xml +++ b/pom.xml @@ -108,7 +108,7 @@ 3.10.6.Final 4.1.108.Final 42.7.2 - 3.24.0 + 3.25.5 1.3.1 1.7.36 5.13.0