[Backport] Dart and security backports (#17249) (#17278) (#17281) (#17282) (#17283) (#17277) (#17285)

* MSQ: Allow for worker gaps. (#17277)
* DartSqlResource: Sort queries by start time. (#17282)
* DartSqlResource: Add controllerHost to GetQueriesResponse. (#17283)
* DartWorkerModule: Replace en dash with regular dash. (#17281)
* DartSqlResource: Return HTTP 202 on cancellation even if no such query. (#17278)
* Upgraded Protobuf to 3.25.5 (#17249)
---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
(cherry picked from commit 7d9e6d36fd)
---------
Co-authored-by: Gian Merlino <gianmerlino@gmail.com>
Co-authored-by: Shivam Garg <shigarg@visa.com>
Karan Kumar, 2024-10-08 19:46:40 +05:30, committed via GitHub
parent f43964a808
commit ccb7c2edd9
20 changed files with 499 additions and 75 deletions

ControllerHolder.java

@@ -59,6 +59,7 @@ public class ControllerHolder
   private final ControllerContext controllerContext;
   private final String sqlQueryId;
   private final String sql;
+  private final String controllerHost;
   private final AuthenticationResult authenticationResult;
   private final DateTime startTime;
   private final AtomicReference<State> state = new AtomicReference<>(State.ACCEPTED);
@@ -68,6 +69,7 @@ public class ControllerHolder
       final ControllerContext controllerContext,
       final String sqlQueryId,
       final String sql,
+      final String controllerHost,
       final AuthenticationResult authenticationResult,
       final DateTime startTime
   )
@@ -76,6 +78,7 @@ public class ControllerHolder
     this.controllerContext = controllerContext;
     this.sqlQueryId = Preconditions.checkNotNull(sqlQueryId, "sqlQueryId");
     this.sql = sql;
+    this.controllerHost = controllerHost;
     this.authenticationResult = authenticationResult;
     this.startTime = Preconditions.checkNotNull(startTime, "startTime");
   }
@@ -95,6 +98,11 @@ public class ControllerHolder
     return sql;
   }

+  public String getControllerHost()
+  {
+    return controllerHost;
+  }
+
   public AuthenticationResult getAuthenticationResult()
   {
     return authenticationResult;

DartQueryInfo.java

@@ -26,6 +26,7 @@ import com.google.common.base.Preconditions;
 import org.apache.druid.msq.dart.controller.ControllerHolder;
 import org.apache.druid.msq.util.MSQTaskQueryMakerUtils;
 import org.apache.druid.query.QueryContexts;
+import org.apache.druid.server.DruidNode;
 import org.joda.time.DateTime;

 import java.util.Objects;
@@ -38,6 +39,7 @@ public class DartQueryInfo
   private final String sqlQueryId;
   private final String dartQueryId;
   private final String sql;
+  private final String controllerHost;
   private final String authenticator;
   private final String identity;
   private final DateTime startTime;
@@ -48,6 +50,7 @@ public class DartQueryInfo
       @JsonProperty("sqlQueryId") final String sqlQueryId,
       @JsonProperty("dartQueryId") final String dartQueryId,
       @JsonProperty("sql") final String sql,
+      @JsonProperty("controllerHost") final String controllerHost,
       @JsonProperty("authenticator") final String authenticator,
       @JsonProperty("identity") final String identity,
       @JsonProperty("startTime") final DateTime startTime,
@@ -57,6 +60,7 @@ public class DartQueryInfo
     this.sqlQueryId = Preconditions.checkNotNull(sqlQueryId, "sqlQueryId");
     this.dartQueryId = Preconditions.checkNotNull(dartQueryId, "dartQueryId");
     this.sql = sql;
+    this.controllerHost = controllerHost;
     this.authenticator = authenticator;
     this.identity = identity;
     this.startTime = startTime;
@@ -69,6 +73,7 @@ public class DartQueryInfo
         holder.getSqlQueryId(),
         holder.getController().queryId(),
         MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(holder.getSql()),
+        holder.getControllerHost(),
         holder.getAuthenticationResult().getAuthenticatedBy(),
         holder.getAuthenticationResult().getIdentity(),
         holder.getStartTime(),
@@ -104,6 +109,16 @@ public class DartQueryInfo
     return sql;
   }

+  /**
+   * Controller host:port, from {@link DruidNode#getHostAndPortToUse()}, that is executing this query.
+   */
+  @JsonProperty
+  @JsonInclude(JsonInclude.Include.NON_NULL)
+  public String getControllerHost()
+  {
+    return controllerHost;
+  }
+
   /**
    * Authenticator that authenticated the identity from {@link #getIdentity()}.
    */
@@ -145,7 +160,7 @@ public class DartQueryInfo
    */
   public DartQueryInfo withoutAuthenticationResult()
   {
-    return new DartQueryInfo(sqlQueryId, dartQueryId, sql, null, null, startTime, state);
+    return new DartQueryInfo(sqlQueryId, dartQueryId, sql, controllerHost, null, null, startTime, state);
   }

   @Override
@@ -161,6 +176,7 @@ public class DartQueryInfo
     return Objects.equals(sqlQueryId, that.sqlQueryId)
            && Objects.equals(dartQueryId, that.dartQueryId)
            && Objects.equals(sql, that.sql)
+           && Objects.equals(controllerHost, that.controllerHost)
            && Objects.equals(authenticator, that.authenticator)
            && Objects.equals(identity, that.identity)
            && Objects.equals(startTime, that.startTime)
@@ -170,7 +186,7 @@ public class DartQueryInfo
   @Override
   public int hashCode()
   {
-    return Objects.hash(sqlQueryId, dartQueryId, sql, authenticator, identity, startTime, state);
+    return Objects.hash(sqlQueryId, dartQueryId, sql, controllerHost, authenticator, identity, startTime, state);
   }

   @Override
@@ -180,10 +196,11 @@ public class DartQueryInfo
            "sqlQueryId='" + sqlQueryId + '\'' +
            ", dartQueryId='" + dartQueryId + '\'' +
            ", sql='" + sql + '\'' +
+           ", controllerHost='" + controllerHost + '\'' +
            ", authenticator='" + authenticator + '\'' +
            ", identity='" + identity + '\'' +
            ", startTime=" + startTime +
-           ", state=" + state +
+           ", state='" + state + '\'' +
            '}';
   }
 }
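Because getControllerHost() is annotated @JsonInclude(NON_NULL), the new field degrades gracefully: responses from Brokers that predate this change simply omit the key. A minimal construction sketch, with illustrative values and the argument order from the constructor above:

    final DartQueryInfo info = new DartQueryInfo(
        "sid",                                     // sqlQueryId
        "did1",                                    // dartQueryId (illustrative)
        "SELECT 1",                                // sql
        "localhost:1001",                          // controllerHost, new fourth argument
        "auth",                                    // authenticator (illustrative)
        "anon",                                    // identity (illustrative)
        DateTimes.of("2000"),                      // startTime
        ControllerHolder.State.RUNNING.toString()  // state
    );
    // If controllerHost were null, serialized JSON would omit the key entirely.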

DartSqlResource.java

@@ -154,7 +154,6 @@ public class DartSqlResource extends SqlResource
         controllerRegistry.getAllHolders()
                           .stream()
                           .map(DartQueryInfo::fromControllerHolder)
-                          .sorted(Comparator.comparing(DartQueryInfo::getStartTime))
                           .collect(Collectors.toList());

     // Add queries from all other servers, if "selfOnly" is not set.
@@ -172,6 +171,9 @@ public class DartSqlResource extends SqlResource
       }
     }

+    // Sort queries by start time, breaking ties by query ID, so the list comes back in a consistent and nice order.
+    queries.sort(Comparator.comparing(DartQueryInfo::getStartTime).thenComparing(DartQueryInfo::getDartQueryId));
+
     final GetQueriesResponse response;
     if (stateReadAccess.isAllowed()) {
       // User can READ STATE, so they can see all running queries, as well as authentication details.
@@ -237,7 +239,10 @@ public class DartSqlResource extends SqlResource
     List<SqlLifecycleManager.Cancelable> cancelables = sqlLifecycleManager.getAll(sqlQueryId);
     if (cancelables.isEmpty()) {
-      return Response.status(Response.Status.NOT_FOUND).build();
+      // Return ACCEPTED even if the query wasn't found. When the Router broadcasts cancellation requests to all
+      // Brokers, this ensures the user sees a successful request.
+      AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(req);
+      return Response.status(Response.Status.ACCEPTED).build();
     }

     final Access access = authorizeCancellation(req, cancelables);
@@ -247,14 +252,12 @@ public class DartSqlResource extends SqlResource
       // Don't call cancel() on the cancelables. That just cancels native queries, which is useless here. Instead,
       // get the controller and stop it.
-      boolean found = false;

       for (SqlLifecycleManager.Cancelable cancelable : cancelables) {
         final HttpStatement stmt = (HttpStatement) cancelable;
         final Object dartQueryId = stmt.context().get(DartSqlEngine.CTX_DART_QUERY_ID);
         if (dartQueryId instanceof String) {
           final ControllerHolder holder = controllerRegistry.get((String) dartQueryId);
           if (holder != null) {
-            found = true;
             holder.cancel();
           }
         } else {
@@ -267,7 +270,9 @@ public class DartSqlResource extends SqlResource
         }
       }

-      return Response.status(found ? Response.Status.ACCEPTED : Response.Status.NOT_FOUND).build();
+      // Return ACCEPTED even if the query wasn't found. When the Router broadcasts cancellation requests to all
+      // Brokers, this ensures the user sees a successful request.
+      return Response.status(Response.Status.ACCEPTED).build();
     } else {
       return Response.status(Response.Status.FORBIDDEN).build();
     }
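The upshot of the two ACCEPTED changes above is that cancellation becomes idempotent from the caller's point of view. That matters because the Router broadcasts a cancellation to every Broker, and at most one of them hosts the controller; the others must not turn the overall request into a 404. A test-style sketch, mirroring the updated unit test later in this commit (sqlResource and httpServletRequest as set up there):

    // An unknown sqlQueryId no longer yields NOT_FOUND for an authorized caller.
    final Response cancellationResponse = sqlResource.cancelQuery("nonexistent", httpServletRequest);
    Assertions.assertEquals(Response.Status.ACCEPTED.getStatusCode(), cancellationResponse.getStatus());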

DartQueryMaker.java

@@ -154,6 +154,7 @@ public class DartQueryMaker implements QueryMaker
         controllerContext,
         plannerContext.getSqlQueryId(),
         plannerContext.getSql(),
+        controllerContext.selfNode().getHostAndPortToUse(),
         plannerContext.getAuthenticationResult(),
         DateTimes.nowUtc()
     );

DartWorkerModule.java

@@ -113,7 +113,7 @@ public class DartWorkerModule implements DruidModule
       final AuthorizerMapper authorizerMapper
   )
   {
-    final ExecutorService exec = Execs.multiThreaded(memoryIntrospector.numTasksInJvm(), "dart–worker-%s");
+    final ExecutorService exec = Execs.multiThreaded(memoryIntrospector.numTasksInJvm(), "dart-worker-%s");
     final File baseTempDir =
         new File(processingConfig.getTmpDir(), StringUtils.format("dart_%s", selfNode.getPortToUse()));
     return new DartWorkerRunner(

ReadablePartition.java

@@ -59,6 +59,18 @@ public class ReadablePartition
     return new ReadablePartition(stageNumber, workerNumbers, partitionNumber);
   }

+  /**
+   * Returns an output partition that is striped across a set of {@code workerNumbers}.
+   */
+  public static ReadablePartition striped(
+      final int stageNumber,
+      final IntSortedSet workerNumbers,
+      final int partitionNumber
+  )
+  {
+    return new ReadablePartition(stageNumber, workerNumbers, partitionNumber);
+  }
+
   /**
    * Returns an output partition that has been collected onto a single worker.
    */

ReadablePartitions.java

@@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo;
 import it.unimi.dsi.fastutil.ints.Int2IntAVLTreeMap;
 import it.unimi.dsi.fastutil.ints.Int2IntSortedMap;
 import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
+import it.unimi.dsi.fastutil.ints.IntSortedSet;

 import java.util.Collections;
 import java.util.List;
@@ -39,6 +40,7 @@ import java.util.Map;
 @JsonSubTypes(value = {
     @JsonSubTypes.Type(name = "collected", value = CollectedReadablePartitions.class),
     @JsonSubTypes.Type(name = "striped", value = StripedReadablePartitions.class),
+    @JsonSubTypes.Type(name = "sparseStriped", value = SparseStripedReadablePartitions.class),
     @JsonSubTypes.Type(name = "combined", value = CombinedReadablePartitions.class)
 })
 public interface ReadablePartitions extends Iterable<ReadablePartition>
@@ -59,7 +61,7 @@ public interface ReadablePartitions extends Iterable<ReadablePartition>
   /**
    * Combines various sets of partitions into a single set.
    */
-  static CombinedReadablePartitions combine(List<ReadablePartitions> readablePartitions)
+  static ReadablePartitions combine(List<ReadablePartitions> readablePartitions)
   {
     return new CombinedReadablePartitions(readablePartitions);
   }
@@ -68,7 +70,7 @@ public interface ReadablePartitions extends Iterable<ReadablePartition>
    * Returns a set of {@code numPartitions} partitions striped across {@code numWorkers} workers: each worker contains
    * a "stripe" of each partition.
    */
-  static StripedReadablePartitions striped(
+  static ReadablePartitions striped(
       final int stageNumber,
       final int numWorkers,
       final int numPartitions
@@ -82,11 +84,36 @@ public interface ReadablePartitions extends Iterable<ReadablePartition>
     return new StripedReadablePartitions(stageNumber, numWorkers, partitionNumbers);
   }

+  /**
+   * Returns a set of {@code numPartitions} partitions striped across {@code workers}: each worker contains
+   * a "stripe" of each partition.
+   */
+  static ReadablePartitions striped(
+      final int stageNumber,
+      final IntSortedSet workers,
+      final int numPartitions
+  )
+  {
+    final IntAVLTreeSet partitionNumbers = new IntAVLTreeSet();
+    for (int i = 0; i < numPartitions; i++) {
+      partitionNumbers.add(i);
+    }
+
+    if (workers.lastInt() == workers.size() - 1) {
+      // Dense worker set. Use StripedReadablePartitions for compactness (send a single number rather than the
+      // entire worker set) and for backwards compatibility (older workers cannot understand
+      // SparseStripedReadablePartitions).
+      return new StripedReadablePartitions(stageNumber, workers.size(), partitionNumbers);
+    } else {
+      return new SparseStripedReadablePartitions(stageNumber, workers, partitionNumbers);
+    }
+  }
+
   /**
    * Returns a set of partitions that have been collected onto specific workers: each partition is on exactly
    * one worker.
    */
-  static CollectedReadablePartitions collected(
+  static ReadablePartitions collected(
       final int stageNumber,
       final Map<Integer, Integer> partitionToWorkerMap
   )
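A quick sketch of what the new striped(int, IntSortedSet, int) overload above produces (worker sets illustrative): the set is dense exactly when its largest member equals size() - 1, and only then does the compact, backwards-compatible encoding apply.

    final IntSortedSet dense = new IntAVLTreeSet(new int[]{0, 1});  // lastInt() == size() - 1
    final IntSortedSet gappy = new IntAVLTreeSet(new int[]{1, 3});  // workers 0 and 2 missing

    // Dense branch: StripedReadablePartitions (serialized as a worker count).
    final ReadablePartitions a = ReadablePartitions.striped(0, dense, 3);
    // Sparse branch: SparseStripedReadablePartitions (serialized as the worker set).
    final ReadablePartitions b = ReadablePartitions.striped(0, gappy, 3);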

SparseStripedReadablePartitions.java (new file)

@@ -0,0 +1,142 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.druid.msq.input.stage;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.Iterators;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;
import org.apache.druid.msq.input.SlicerUtils;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;

/**
 * Set of partitions striped across a sparse set of {@code workers}. Each worker contains a "stripe" of each partition.
 *
 * @see StripedReadablePartitions dense version, where workers from [0..N) are all used.
 */
public class SparseStripedReadablePartitions implements ReadablePartitions
{
  private final int stageNumber;
  private final IntSortedSet workers;
  private final IntSortedSet partitionNumbers;

  /**
   * Constructor. Most callers should use {@link ReadablePartitions#striped(int, IntSortedSet, int)} instead, which
   * takes a partition count rather than a set of partition numbers.
   */
  public SparseStripedReadablePartitions(
      final int stageNumber,
      final IntSortedSet workers,
      final IntSortedSet partitionNumbers
  )
  {
    this.stageNumber = stageNumber;
    this.workers = workers;
    this.partitionNumbers = partitionNumbers;
  }

  @JsonCreator
  private SparseStripedReadablePartitions(
      @JsonProperty("stageNumber") final int stageNumber,
      @JsonProperty("workers") final Set<Integer> workers,
      @JsonProperty("partitionNumbers") final Set<Integer> partitionNumbers
  )
  {
    this(stageNumber, new IntAVLTreeSet(workers), new IntAVLTreeSet(partitionNumbers));
  }

  @Override
  public Iterator<ReadablePartition> iterator()
  {
    return Iterators.transform(
        partitionNumbers.iterator(),
        partitionNumber -> ReadablePartition.striped(stageNumber, workers, partitionNumber)
    );
  }

  @Override
  public List<ReadablePartitions> split(final int maxNumSplits)
  {
    final List<ReadablePartitions> retVal = new ArrayList<>();

    for (List<Integer> entries : SlicerUtils.makeSlicesStatic(partitionNumbers.iterator(), maxNumSplits)) {
      if (!entries.isEmpty()) {
        retVal.add(new SparseStripedReadablePartitions(stageNumber, workers, new IntAVLTreeSet(entries)));
      }
    }

    return retVal;
  }

  @JsonProperty
  int getStageNumber()
  {
    return stageNumber;
  }

  @JsonProperty
  IntSortedSet getWorkers()
  {
    return workers;
  }

  @JsonProperty
  IntSortedSet getPartitionNumbers()
  {
    return partitionNumbers;
  }

  @Override
  public boolean equals(Object o)
  {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    SparseStripedReadablePartitions that = (SparseStripedReadablePartitions) o;
    return stageNumber == that.stageNumber
           && Objects.equals(workers, that.workers)
           && Objects.equals(partitionNumbers, that.partitionNumbers);
  }

  @Override
  public int hashCode()
  {
    return Objects.hash(stageNumber, workers, partitionNumbers);
  }

  @Override
  public String toString()
  {
    return "SparseStripedReadablePartitions{" +
           "stageNumber=" + stageNumber +
           ", workers=" + workers +
           ", partitionNumbers=" + partitionNumbers +
           '}';
  }
}
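For a feel of split() above (this mirrors the unit test added later in this commit): partitions are divided round-robin by SlicerUtils.makeSlicesStatic, while the worker set rides along unchanged in every split.

    final IntAVLTreeSet workers = new IntAVLTreeSet(new int[]{1, 3});
    final ReadablePartitions partitions = ReadablePartitions.striped(1, workers, 3);
    // Two splits of partitions {0, 1, 2}: {0, 2} and {1}, each still striped
    // across workers {1, 3}.
    final List<ReadablePartitions> splits = partitions.split(2);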

ControllerStageTracker.java

@@ -403,7 +403,7 @@ class ControllerStageTracker
       throw new ISE("Stage does not gather result key statistics");
     }

-    if (workerNumber < 0 || workerNumber >= workerCount) {
+    if (!workerInputs.workers().contains(workerNumber)) {
       throw new IAE("Invalid workerNumber [%s]", workerNumber);
     }
@@ -522,7 +522,7 @@ class ControllerStageTracker
       throw new ISE("Stage does not gather result key statistics");
     }

-    if (workerNumber < 0 || workerNumber >= workerCount) {
+    if (!workerInputs.workers().contains(workerNumber)) {
       throw new IAE("Invalid workerNumber [%s]", workerNumber);
     }
@@ -656,7 +656,7 @@ class ControllerStageTracker
       throw new ISE("Stage does not gather result key statistics");
     }

-    if (workerNumber < 0 || workerNumber >= workerCount) {
+    if (!workerInputs.workers().contains(workerNumber)) {
       throw new IAE("Invalid workerNumber [%s]", workerNumber);
     }
@@ -763,7 +763,7 @@ class ControllerStageTracker
     this.resultPartitionBoundaries = clusterByPartitions;
     this.resultPartitions = ReadablePartitions.striped(
         stageDef.getStageNumber(),
-        workerCount,
+        workerInputs.workers(),
         clusterByPartitions.size()
     );
@@ -788,7 +788,7 @@ class ControllerStageTracker
       throw DruidException.defensive("Cannot setDoneReadingInput for stage[%s], it is not sorting", stageDef.getId());
     }

-    if (workerNumber < 0 || workerNumber >= workerCount) {
+    if (!workerInputs.workers().contains(workerNumber)) {
       throw new IAE("Invalid workerNumber[%s] for stage[%s]", workerNumber, stageDef.getId());
     }
@@ -830,7 +830,7 @@ class ControllerStageTracker
   @SuppressWarnings("unchecked")
   boolean setResultsCompleteForWorker(final int workerNumber, final Object resultObject)
   {
-    if (workerNumber < 0 || workerNumber >= workerCount) {
+    if (!workerInputs.workers().contains(workerNumber)) {
       throw new IAE("Invalid workerNumber [%s]", workerNumber);
     }
@@ -947,14 +947,18 @@ class ControllerStageTracker
         resultPartitionBoundaries = maybeResultPartitionBoundaries.valueOrThrow();
         resultPartitions = ReadablePartitions.striped(
             stageNumber,
-            workerCount,
+            workerInputs.workers(),
             resultPartitionBoundaries.size()
         );
-      } else if (shuffleSpec.kind() == ShuffleKind.MIX) {
-        resultPartitionBoundaries = ClusterByPartitions.oneUniversalPartition();
-        resultPartitions = ReadablePartitions.striped(stageNumber, workerCount, shuffleSpec.partitionCount());
       } else {
-        resultPartitions = ReadablePartitions.striped(stageNumber, workerCount, shuffleSpec.partitionCount());
+        if (shuffleSpec.kind() == ShuffleKind.MIX) {
+          resultPartitionBoundaries = ClusterByPartitions.oneUniversalPartition();
+        }
+        resultPartitions = ReadablePartitions.striped(
+            stageNumber,
+            workerInputs.workers(),
+            shuffleSpec.partitionCount()
+        );
       }
     } else {
       // No reshuffling: retain partitioning from nonbroadcast inputs.
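The recurring edit above is the heart of "allow for worker gaps": a worker number is valid if it is a member of the stage's actual worker set, not merely inside the dense range [0, workerCount). A minimal sketch, with a hypothetical worker set:

    final IntSortedSet workers = new IntAVLTreeSet(new int[]{1, 3, 4, 5});
    final int workerNumber = 5;
    final boolean oldCheck = workerNumber >= 0 && workerNumber < 4;  // false: wrongly rejects worker 5
    final boolean newCheck = workers.contains(workerNumber);         // true: 5 is a real worker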

WorkerInputs.java

@@ -24,7 +24,9 @@ import com.google.common.collect.Iterables;
 import it.unimi.dsi.fastutil.ints.Int2IntMap;
 import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
 import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
-import it.unimi.dsi.fastutil.ints.IntSet;
+import it.unimi.dsi.fastutil.ints.Int2ObjectSortedMap;
+import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
+import it.unimi.dsi.fastutil.ints.IntSortedSet;
 import it.unimi.dsi.fastutil.objects.ObjectIterator;
 import org.apache.druid.msq.input.InputSlice;
 import org.apache.druid.msq.input.InputSpec;
@@ -45,9 +47,9 @@ import java.util.stream.IntStream;
 public class WorkerInputs
 {
   // Worker number -> input number -> input slice.
-  private final Int2ObjectMap<List<InputSlice>> assignmentsMap;
+  private final Int2ObjectSortedMap<List<InputSlice>> assignmentsMap;

-  private WorkerInputs(final Int2ObjectMap<List<InputSlice>> assignmentsMap)
+  private WorkerInputs(final Int2ObjectSortedMap<List<InputSlice>> assignmentsMap)
   {
     this.assignmentsMap = assignmentsMap;
   }
@@ -64,7 +66,7 @@ public class WorkerInputs
   )
   {
     // Split each inputSpec and assign to workers. This list maps worker number -> input number -> input slice.
-    final Int2ObjectMap<List<InputSlice>> assignmentsMap = new Int2ObjectAVLTreeMap<>();
+    final Int2ObjectSortedMap<List<InputSlice>> assignmentsMap = new Int2ObjectAVLTreeMap<>();
     final int numInputs = stageDef.getInputSpecs().size();

     if (numInputs == 0) {
@@ -117,8 +119,8 @@ public class WorkerInputs
     final ObjectIterator<Int2ObjectMap.Entry<List<InputSlice>>> assignmentsIterator =
         assignmentsMap.int2ObjectEntrySet().iterator();

-    boolean first = true;
+    final IntSortedSet nilWorkers = new IntAVLTreeSet();
     while (assignmentsIterator.hasNext()) {
       final Int2ObjectMap.Entry<List<InputSlice>> entry = assignmentsIterator.next();
       final List<InputSlice> slices = entry.getValue();
@@ -130,20 +132,29 @@ public class WorkerInputs
         }
       }

-      // Eliminate workers that have no non-nil, non-broadcast inputs. (Except the first one, because if all input
-      // is nil, *some* worker has to do *something*.)
-      final boolean hasNonNilNonBroadcastInput =
+      // Identify nil workers (workers with no non-broadcast inputs).
+      final boolean isNilWorker =
           IntStream.range(0, numInputs)
-                   .anyMatch(i ->
-                       !slices.get(i).equals(NilInputSlice.INSTANCE) // Non-nil
-                       && !stageDef.getBroadcastInputNumbers().contains(i) // Non-broadcast
+                   .allMatch(i ->
+                       slices.get(i).equals(NilInputSlice.INSTANCE) // Nil regular input
+                       || stageDef.getBroadcastInputNumbers().contains(i) // Broadcast
                    );

-      if (!first && !hasNonNilNonBroadcastInput) {
-        assignmentsIterator.remove();
+      if (isNilWorker) {
+        nilWorkers.add(entry.getIntKey());
       }
+    }

-      first = false;
+    if (nilWorkers.size() == assignmentsMap.size()) {
+      // All workers have nil regular inputs. Remove all workers except the first (*some* worker has to do *something*).
+      final List<InputSlice> firstSlices = assignmentsMap.get(nilWorkers.firstInt());
+      assignmentsMap.clear();
+      assignmentsMap.put(nilWorkers.firstInt(), firstSlices);
+    } else {
+      // Remove all nil workers.
+      for (final int nilWorker : nilWorkers) {
+        assignmentsMap.remove(nilWorker);
+      }
     }

     return new WorkerInputs(assignmentsMap);
@@ -154,7 +165,7 @@ public class WorkerInputs
     return Preconditions.checkNotNull(assignmentsMap.get(workerNumber), "worker [%s]", workerNumber);
   }

-  public IntSet workers()
+  public IntSortedSet workers()
   {
     return assignmentsMap.keySet();
   }
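A standalone sketch (hypothetical helper, not part of the diff) of the nil-worker rule introduced above: a worker whose every input is nil or broadcast-only is dropped, unless all workers are nil, in which case exactly one survivor keeps the stage runnable.

    static IntSortedSet retainWorkers(final IntSortedSet all, final IntSortedSet nil)
    {
      final IntSortedSet retained = new IntAVLTreeSet(all);
      if (nil.size() == all.size()) {
        // All workers are nil; *some* worker has to do *something*. Keep the first.
        retained.clear();
        retained.add(nil.firstInt());
      } else {
        // Drop only the nil workers.
        retained.removeAll(nil);
      }
      return retained;
    }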

DartSqlResourceTest.java

@@ -331,9 +331,10 @@ public class DartSqlResourceTest extends MSQTestBase
         "sid",
         "did2",
         "SELECT 2",
+        "localhost:1002",
         AUTHENTICATOR_NAME,
         DIFFERENT_REGULAR_USER_NAME,
-        DateTimes.of("2000"),
+        DateTimes.of("2001"),
         ControllerHolder.State.RUNNING.toString()
     );
     Mockito.when(dartSqlClient.getRunningQueries(true))
@@ -398,6 +399,7 @@ public class DartSqlResourceTest extends MSQTestBase
         "sid",
         "did2",
         "SELECT 2",
+        "localhost:1002",
         AUTHENTICATOR_NAME,
         DIFFERENT_REGULAR_USER_NAME,
         DateTimes.of("2000"),
@@ -434,6 +436,7 @@ public class DartSqlResourceTest extends MSQTestBase
         "sid",
         "did2",
         "SELECT 2",
+        "localhost:1002",
         AUTHENTICATOR_NAME,
         DIFFERENT_REGULAR_USER_NAME,
         DateTimes.of("2000"),
@@ -724,7 +727,7 @@ public class DartSqlResourceTest extends MSQTestBase
         .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME));

     final Response cancellationResponse = sqlResource.cancelQuery("nonexistent", httpServletRequest);
-    Assertions.assertEquals(Response.Status.NOT_FOUND.getStatusCode(), cancellationResponse.getStatus());
+    Assertions.assertEquals(Response.Status.ACCEPTED.getStatusCode(), cancellationResponse.getStatus());
   }

   /**
@@ -739,8 +742,15 @@ public class DartSqlResourceTest extends MSQTestBase
     Mockito.when(controller.queryId()).thenReturn("did_" + identity);
     final AuthenticationResult authenticationResult = makeAuthenticationResult(identity);

-    final ControllerHolder holder =
-        new ControllerHolder(controller, null, "sid", "SELECT 1", authenticationResult, DateTimes.of("2000"));
+    final ControllerHolder holder = new ControllerHolder(
+        controller,
+        null,
+        "sid",
+        "SELECT 1",
+        "localhost:1001",
+        authenticationResult,
+        DateTimes.of("2000")
+    );

     controllerRegistry.register(holder);
     return holder;

GetQueriesResponseTest.java

@@ -41,6 +41,7 @@ public class GetQueriesResponseTest
         "xyz",
         "abc",
         "SELECT 1",
+        "localhost:1001",
         "auth",
         "anon",
         DateTimes.of("2000"),

DartSqlClientImplTest.java

@@ -69,6 +69,7 @@ public class DartSqlClientImplTest
         "sid",
         "did",
         "SELECT 1",
+        "localhost:1001",
         "",
         "",
         DateTimes.of("2000"),
@@ -97,6 +98,7 @@ public class DartSqlClientImplTest
         "sid",
         "did",
         "SELECT 1",
+        "localhost:1001",
         "",
         "",
         DateTimes.of("2000"),

CollectedReadablePartitionsTest.java

@@ -33,21 +33,24 @@ public class CollectedReadablePartitionsTest
   @Test
   public void testPartitionToWorkerMap()
   {
-    final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
+    final CollectedReadablePartitions partitions =
+        (CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
     Assert.assertEquals(ImmutableMap.of(0, 1, 1, 2, 2, 1), partitions.getPartitionToWorkerMap());
   }

   @Test
   public void testStageNumber()
   {
-    final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
+    final CollectedReadablePartitions partitions =
+        (CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
     Assert.assertEquals(1, partitions.getStageNumber());
   }

   @Test
   public void testSplit()
   {
-    final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
+    final CollectedReadablePartitions partitions =
+        (CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));

     Assert.assertEquals(
         ImmutableList.of(
@@ -64,7 +67,8 @@ public class CollectedReadablePartitionsTest
     final ObjectMapper mapper = TestHelper.makeJsonMapper()
                                           .registerModules(new MSQIndexingModule().getJacksonModules());

-    final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
+    final CollectedReadablePartitions partitions =
+        (CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));

     Assert.assertEquals(
         partitions,

CombinedReadablePartitionsTest.java

@@ -31,7 +31,7 @@ import org.junit.Test;

 public class CombinedReadablePartitionsTest
 {
-  private static final CombinedReadablePartitions PARTITIONS = ReadablePartitions.combine(
+  private static final ReadablePartitions PARTITIONS = ReadablePartitions.combine(
       ImmutableList.of(
           ReadablePartitions.striped(0, 2, 2),
           ReadablePartitions.striped(1, 2, 4)

SparseStripedReadablePartitionsTest.java (new file)

@@ -0,0 +1,98 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.druid.msq.input.stage;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.msq.guice.MSQIndexingModule;
import org.apache.druid.segment.TestHelper;
import org.junit.Assert;
import org.junit.Test;

public class SparseStripedReadablePartitionsTest
{
  @Test
  public void testPartitionNumbers()
  {
    final SparseStripedReadablePartitions partitions =
        (SparseStripedReadablePartitions) ReadablePartitions.striped(1, new IntAVLTreeSet(new int[]{1, 3}), 3);
    Assert.assertEquals(ImmutableSet.of(0, 1, 2), partitions.getPartitionNumbers());
  }

  @Test
  public void testWorkers()
  {
    final SparseStripedReadablePartitions partitions =
        (SparseStripedReadablePartitions) ReadablePartitions.striped(1, new IntAVLTreeSet(new int[]{1, 3}), 3);
    Assert.assertEquals(IntSet.of(1, 3), partitions.getWorkers());
  }

  @Test
  public void testStageNumber()
  {
    final SparseStripedReadablePartitions partitions =
        (SparseStripedReadablePartitions) ReadablePartitions.striped(1, new IntAVLTreeSet(new int[]{1, 3}), 3);
    Assert.assertEquals(1, partitions.getStageNumber());
  }

  @Test
  public void testSplit()
  {
    final IntAVLTreeSet workers = new IntAVLTreeSet(new int[]{1, 3});
    final SparseStripedReadablePartitions partitions =
        (SparseStripedReadablePartitions) ReadablePartitions.striped(1, workers, 3);

    Assert.assertEquals(
        ImmutableList.of(
            new SparseStripedReadablePartitions(1, workers, new IntAVLTreeSet(new int[]{0, 2})),
            new SparseStripedReadablePartitions(1, workers, new IntAVLTreeSet(new int[]{1}))
        ),
        partitions.split(2)
    );
  }

  @Test
  public void testSerde() throws Exception
  {
    final ObjectMapper mapper = TestHelper.makeJsonMapper()
                                          .registerModules(new MSQIndexingModule().getJacksonModules());

    final IntAVLTreeSet workers = new IntAVLTreeSet(new int[]{1, 3});
    final ReadablePartitions partitions = ReadablePartitions.striped(1, workers, 3);

    Assert.assertEquals(
        partitions,
        mapper.readValue(
            mapper.writeValueAsString(partitions),
            ReadablePartitions.class
        )
    );
  }

  @Test
  public void testEquals()
  {
    EqualsVerifier.forClass(SparseStripedReadablePartitions.class).usingGetClass().verify();
  }
}

StripedReadablePartitionsTest.java

@@ -26,36 +26,60 @@ import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
 import nl.jqno.equalsverifier.EqualsVerifier;
 import org.apache.druid.msq.guice.MSQIndexingModule;
 import org.apache.druid.segment.TestHelper;
+import org.hamcrest.CoreMatchers;
+import org.hamcrest.MatcherAssert;
 import org.junit.Assert;
 import org.junit.Test;

 public class StripedReadablePartitionsTest
 {
+  @Test
+  public void testFromDenseSet()
+  {
+    // Tests that when ReadablePartitions.striped is called with a dense set, we get StripedReadablePartitions.
+    final IntAVLTreeSet workers = new IntAVLTreeSet();
+    workers.add(0);
+    workers.add(1);
+
+    final ReadablePartitions readablePartitionsFromSet = ReadablePartitions.striped(1, workers, 3);
+    MatcherAssert.assertThat(
+        readablePartitionsFromSet,
+        CoreMatchers.instanceOf(StripedReadablePartitions.class)
+    );
+    Assert.assertEquals(
+        ReadablePartitions.striped(1, 2, 3),
+        readablePartitionsFromSet
+    );
+  }
+
   @Test
   public void testPartitionNumbers()
   {
-    final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
+    final StripedReadablePartitions partitions = (StripedReadablePartitions) ReadablePartitions.striped(1, 2, 3);
     Assert.assertEquals(ImmutableSet.of(0, 1, 2), partitions.getPartitionNumbers());
   }

   @Test
   public void testNumWorkers()
   {
-    final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
+    final StripedReadablePartitions partitions = (StripedReadablePartitions) ReadablePartitions.striped(1, 2, 3);
     Assert.assertEquals(2, partitions.getNumWorkers());
   }

   @Test
   public void testStageNumber()
   {
-    final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
+    final StripedReadablePartitions partitions = (StripedReadablePartitions) ReadablePartitions.striped(1, 2, 3);
     Assert.assertEquals(1, partitions.getStageNumber());
   }

   @Test
   public void testSplit()
   {
-    final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
+    final ReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);

     Assert.assertEquals(
         ImmutableList.of(
@@ -72,7 +96,7 @@ public class StripedReadablePartitionsTest
     final ObjectMapper mapper = TestHelper.makeJsonMapper()
                                           .registerModules(new MSQIndexingModule().getJacksonModules());

-    final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
+    final ReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);

     Assert.assertEquals(
         partitions,

WorkerInputsTest.java

@@ -25,9 +25,11 @@ import it.unimi.dsi.fastutil.ints.Int2IntMaps;
 import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
 import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
 import it.unimi.dsi.fastutil.ints.IntSet;
+import it.unimi.dsi.fastutil.ints.IntSortedSet;
 import it.unimi.dsi.fastutil.longs.LongArrayList;
 import it.unimi.dsi.fastutil.longs.LongList;
 import nl.jqno.equalsverifier.EqualsVerifier;
+import org.apache.druid.error.DruidException;
 import org.apache.druid.msq.exec.Limits;
 import org.apache.druid.msq.exec.OutputChannelMode;
 import org.apache.druid.msq.input.InputSlice;
@@ -75,7 +77,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(4), true),
         WorkerAssignmentStrategy.MAX,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -91,6 +93,35 @@ public class WorkerInputsTest
     );
   }

+  @Test
+  public void test_max_threeInputs_fourWorkers_withGaps()
+  {
+    final StageDefinition stageDef =
+        StageDefinition.builder(0)
+                       .inputs(new TestInputSpec(1, 2, 3))
+                       .maxWorkerCount(4)
+                       .processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L))
+                       .build(QUERY_ID);
+
+    final WorkerInputs inputs = WorkerInputs.create(
+        stageDef,
+        Int2IntMaps.EMPTY_MAP,
+        new TestInputSpecSlicer(new IntAVLTreeSet(new int[]{1, 3, 4, 5}), true),
+        WorkerAssignmentStrategy.MAX,
+        Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
+    );
+
+    Assert.assertEquals(
+        ImmutableMap.<Integer, List<InputSlice>>builder()
+                    .put(1, Collections.singletonList(new TestInputSlice(1)))
+                    .put(3, Collections.singletonList(new TestInputSlice(2)))
+                    .put(4, Collections.singletonList(new TestInputSlice(3)))
+                    .put(5, Collections.singletonList(new TestInputSlice()))
+                    .build(),
+        inputs.assignmentsMap()
+    );
+  }
+
   @Test
   public void test_max_zeroInputs_fourWorkers()
   {
@@ -104,7 +135,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(4), true),
         WorkerAssignmentStrategy.MAX,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -133,7 +164,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(4), true),
         WorkerAssignmentStrategy.AUTO,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -159,7 +190,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(4), true),
         WorkerAssignmentStrategy.AUTO,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -186,7 +217,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(4), true),
         WorkerAssignmentStrategy.AUTO,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -212,7 +243,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(4), true),
         WorkerAssignmentStrategy.AUTO,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -324,7 +355,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(4), true),
         WorkerAssignmentStrategy.AUTO,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -351,7 +382,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(2), true),
         WorkerAssignmentStrategy.AUTO,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -384,7 +415,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(1), true),
         WorkerAssignmentStrategy.AUTO,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -411,7 +442,7 @@ public class WorkerInputsTest
                        .processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L))
                        .build(QUERY_ID);

-    TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(true));
+    TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(denseWorkers(3), true));

     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
@@ -455,7 +486,7 @@ public class WorkerInputsTest
                        .processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L))
                        .build(QUERY_ID);

-    TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(true));
+    TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(denseWorkers(3), true));

     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
@@ -498,7 +529,7 @@ public class WorkerInputsTest
                        .processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L))
                        .build(QUERY_ID);

-    TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(true));
+    TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(denseWorkers(3), true));

     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
@@ -585,11 +616,23 @@ public class WorkerInputsTest
   private static class TestInputSpecSlicer implements InputSpecSlicer
   {
+    private final IntSortedSet workers;
     private final boolean canSliceDynamic;

-    public TestInputSpecSlicer(boolean canSliceDynamic)
+    /**
+     * Create a test slicer.
+     *
+     * @param workers          Set of workers to consider assigning work to.
+     * @param canSliceDynamic  Whether this slicer can slice dynamically.
+     */
+    public TestInputSpecSlicer(final IntSortedSet workers, final boolean canSliceDynamic)
     {
+      this.workers = workers;
       this.canSliceDynamic = canSliceDynamic;
+
+      if (workers.isEmpty()) {
+        throw DruidException.defensive("Need more than one worker in workers[%s]", workers);
+      }
     }

     @Override
@@ -606,9 +649,9 @@ public class WorkerInputsTest
           SlicerUtils.makeSlicesStatic(
               testInputSpec.values.iterator(),
               i -> i,
-              maxNumSlices
+              Math.min(maxNumSlices, workers.size())
           );
-      return makeSlices(assignments);
+      return makeSlices(workers, assignments);
     }

     @Override
@@ -624,24 +667,39 @@ public class WorkerInputsTest
           SlicerUtils.makeSlicesDynamic(
               testInputSpec.values.iterator(),
               i -> i,
-              maxNumSlices,
+              Math.min(maxNumSlices, workers.size()),
               maxFilesPerSlice,
               maxBytesPerSlice
           );
-      return makeSlices(assignments);
+      return makeSlices(workers, assignments);
     }

     private static List<InputSlice> makeSlices(
+        final IntSortedSet workers,
         final List<List<Long>> assignments
     )
     {
       final List<InputSlice> retVal = new ArrayList<>(assignments.size());

-      for (final List<Long> assignment : assignments) {
-        retVal.add(new TestInputSlice(new LongArrayList(assignment)));
+      for (int assignment = 0, workerNumber = 0;
+           workerNumber <= workers.lastInt() && assignment < assignments.size();
+           workerNumber++) {
+        if (workers.contains(workerNumber)) {
+          retVal.add(new TestInputSlice(new LongArrayList(assignments.get(assignment++))));
+        } else {
+          retVal.add(NilInputSlice.INSTANCE);
+        }
       }

       return retVal;
     }
   }
+
+  private static IntSortedSet denseWorkers(final int numWorkers)
+  {
+    final IntAVLTreeSet workers = new IntAVLTreeSet();
+    for (int i = 0; i < numWorkers; i++) {
+      workers.add(i);
+    }
+    return workers;
+  }
 }

licenses.yaml

@@ -3327,7 +3327,7 @@ name: Protocol Buffers
 license_category: binary
 module: java-core
 license_name: BSD-3-Clause License
-version: 3.24.0
+version: 3.25.5
 copyright: Google, Inc.
 license_file_path:
   - licenses/bin/protobuf-java.BSD3
@@ -3493,7 +3493,7 @@ name: Protocol Buffers
 license_category: binary
 module: extensions/druid-protobuf-extensions
 license_name: BSD-3-Clause License
-version: 3.24.0
+version: 3.25.5
 copyright: Google, Inc.
 license_file_path: licenses/bin/protobuf-java.BSD3
 libraries:

pom.xml

@@ -108,7 +108,7 @@
 <netty3.version>3.10.6.Final</netty3.version>
 <netty4.version>4.1.108.Final</netty4.version>
 <postgresql.version>42.7.2</postgresql.version>
-<protobuf.version>3.24.0</protobuf.version>
+<protobuf.version>3.25.5</protobuf.version>
 <resilience4j.version>1.3.1</resilience4j.version>
 <slf4j.version>1.7.36</slf4j.version>
 <jna.version>5.13.0</jna.version>