[Backport] Dart and security backports (#17249) (#17278) (#17281) (#17282) (#17283) (#17277) (#17285)

* MSQ: Allow for worker gaps. (#17277)
* DartSqlResource: Sort queries by start time. (#17282)
* DartSqlResource: Add controllerHost to GetQueriesResponse. (#17283)
* DartWorkerModule: Replace en dash with regular dash. (#17281)
* DartSqlResource: Return HTTP 202 on cancellation even if no such query. (#17278)
* Upgraded Protobuf to 3.25.5 (#17249)
---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
(cherry picked from commit 7d9e6d36fd)
---------
Co-authored-by: Gian Merlino <gianmerlino@gmail.com>
Co-authored-by: Shivam Garg <shigarg@visa.com>
Karan Kumar 2024-10-08 19:46:40 +05:30 committed by GitHub
parent f43964a808
commit ccb7c2edd9
20 changed files with 499 additions and 75 deletions

View File: ControllerHolder.java

@ -59,6 +59,7 @@ public class ControllerHolder
private final ControllerContext controllerContext;
private final String sqlQueryId;
private final String sql;
private final String controllerHost;
private final AuthenticationResult authenticationResult;
private final DateTime startTime;
private final AtomicReference<State> state = new AtomicReference<>(State.ACCEPTED);
@ -68,6 +69,7 @@ public class ControllerHolder
final ControllerContext controllerContext,
final String sqlQueryId,
final String sql,
final String controllerHost,
final AuthenticationResult authenticationResult,
final DateTime startTime
)
@ -76,6 +78,7 @@ public class ControllerHolder
this.controllerContext = controllerContext;
this.sqlQueryId = Preconditions.checkNotNull(sqlQueryId, "sqlQueryId");
this.sql = sql;
this.controllerHost = controllerHost;
this.authenticationResult = authenticationResult;
this.startTime = Preconditions.checkNotNull(startTime, "startTime");
}
@ -95,6 +98,11 @@ public class ControllerHolder
return sql;
}
public String getControllerHost()
{
return controllerHost;
}
public AuthenticationResult getAuthenticationResult()
{
return authenticationResult;

View File: DartQueryInfo.java

@ -26,6 +26,7 @@ import com.google.common.base.Preconditions;
import org.apache.druid.msq.dart.controller.ControllerHolder;
import org.apache.druid.msq.util.MSQTaskQueryMakerUtils;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.server.DruidNode;
import org.joda.time.DateTime;
import java.util.Objects;
@ -38,6 +39,7 @@ public class DartQueryInfo
private final String sqlQueryId;
private final String dartQueryId;
private final String sql;
private final String controllerHost;
private final String authenticator;
private final String identity;
private final DateTime startTime;
@ -48,6 +50,7 @@ public class DartQueryInfo
@JsonProperty("sqlQueryId") final String sqlQueryId,
@JsonProperty("dartQueryId") final String dartQueryId,
@JsonProperty("sql") final String sql,
@JsonProperty("controllerHost") final String controllerHost,
@JsonProperty("authenticator") final String authenticator,
@JsonProperty("identity") final String identity,
@JsonProperty("startTime") final DateTime startTime,
@ -57,6 +60,7 @@ public class DartQueryInfo
this.sqlQueryId = Preconditions.checkNotNull(sqlQueryId, "sqlQueryId");
this.dartQueryId = Preconditions.checkNotNull(dartQueryId, "dartQueryId");
this.sql = sql;
this.controllerHost = controllerHost;
this.authenticator = authenticator;
this.identity = identity;
this.startTime = startTime;
@ -69,6 +73,7 @@ public class DartQueryInfo
holder.getSqlQueryId(),
holder.getController().queryId(),
MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(holder.getSql()),
holder.getControllerHost(),
holder.getAuthenticationResult().getAuthenticatedBy(),
holder.getAuthenticationResult().getIdentity(),
holder.getStartTime(),
@ -104,6 +109,16 @@ public class DartQueryInfo
return sql;
}
/**
* Controller host:port, from {@link DruidNode#getHostAndPortToUse()}, that is executing this query.
*/
@JsonProperty
@JsonInclude(JsonInclude.Include.NON_NULL)
public String getControllerHost()
{
return controllerHost;
}
/**
* Authenticator that authenticated the identity from {@link #getIdentity()}.
*/
@ -145,7 +160,7 @@ public class DartQueryInfo
*/
public DartQueryInfo withoutAuthenticationResult()
{
return new DartQueryInfo(sqlQueryId, dartQueryId, sql, null, null, startTime, state);
return new DartQueryInfo(sqlQueryId, dartQueryId, sql, controllerHost, null, null, startTime, state);
}
@Override
@ -161,6 +176,7 @@ public class DartQueryInfo
return Objects.equals(sqlQueryId, that.sqlQueryId)
&& Objects.equals(dartQueryId, that.dartQueryId)
&& Objects.equals(sql, that.sql)
&& Objects.equals(controllerHost, that.controllerHost)
&& Objects.equals(authenticator, that.authenticator)
&& Objects.equals(identity, that.identity)
&& Objects.equals(startTime, that.startTime)
@ -170,7 +186,7 @@ public class DartQueryInfo
@Override
public int hashCode()
{
return Objects.hash(sqlQueryId, dartQueryId, sql, authenticator, identity, startTime, state);
return Objects.hash(sqlQueryId, dartQueryId, sql, controllerHost, authenticator, identity, startTime, state);
}
@Override
@ -180,10 +196,11 @@ public class DartQueryInfo
"sqlQueryId='" + sqlQueryId + '\'' +
", dartQueryId='" + dartQueryId + '\'' +
", sql='" + sql + '\'' +
", controllerHost='" + controllerHost + '\'' +
", authenticator='" + authenticator + '\'' +
", identity='" + identity + '\'' +
", startTime=" + startTime +
", state=" + state +
", state='" + state + '\'' +
'}';
}
}

View File: DartSqlResource.java

@ -154,7 +154,6 @@ public class DartSqlResource extends SqlResource
controllerRegistry.getAllHolders()
.stream()
.map(DartQueryInfo::fromControllerHolder)
.sorted(Comparator.comparing(DartQueryInfo::getStartTime))
.collect(Collectors.toList());
// Add queries from all other servers, if "selfOnly" is not set.
@ -172,6 +171,9 @@ public class DartSqlResource extends SqlResource
}
}
// Sort queries by start time, breaking ties by query ID, so the list comes back in a consistent and nice order.
queries.sort(Comparator.comparing(DartQueryInfo::getStartTime).thenComparing(DartQueryInfo::getDartQueryId));
final GetQueriesResponse response;
if (stateReadAccess.isAllowed()) {
// User can READ STATE, so they can see all running queries, as well as authentication details.
@ -237,7 +239,10 @@ public class DartSqlResource extends SqlResource
List<SqlLifecycleManager.Cancelable> cancelables = sqlLifecycleManager.getAll(sqlQueryId);
if (cancelables.isEmpty()) {
return Response.status(Response.Status.NOT_FOUND).build();
// Return ACCEPTED even if the query wasn't found. When the Router broadcasts cancellation requests to all
// Brokers, this ensures the user sees a successful request.
AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(req);
return Response.status(Response.Status.ACCEPTED).build();
}
final Access access = authorizeCancellation(req, cancelables);
@ -247,14 +252,12 @@ public class DartSqlResource extends SqlResource
// Don't call cancel() on the cancelables. That just cancels native queries, which is useless here. Instead,
// get the controller and stop it.
boolean found = false;
for (SqlLifecycleManager.Cancelable cancelable : cancelables) {
final HttpStatement stmt = (HttpStatement) cancelable;
final Object dartQueryId = stmt.context().get(DartSqlEngine.CTX_DART_QUERY_ID);
if (dartQueryId instanceof String) {
final ControllerHolder holder = controllerRegistry.get((String) dartQueryId);
if (holder != null) {
found = true;
holder.cancel();
}
} else {
@ -267,7 +270,9 @@ public class DartSqlResource extends SqlResource
}
}
return Response.status(found ? Response.Status.ACCEPTED : Response.Status.NOT_FOUND).build();
// Return ACCEPTED even if the query wasn't found. When the Router broadcasts cancellation requests to all
// Brokers, this ensures the user sees a successful request.
return Response.status(Response.Status.ACCEPTED).build();
} else {
return Response.status(Response.Status.FORBIDDEN).build();
}

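To make the new cancellation contract concrete, here is a hypothetical client-side sketch (not part of this commit; the endpoint path is an assumption, as is the fan-out helper itself). With this change, a Router-style broadcast can treat cancellation as idempotent: every Broker answers 202 whether or not it is running the query, so the overall request succeeds unless some Broker returns FORBIDDEN.

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.List;

public class CancelFanoutSketch
{
  // Hypothetical: broadcast DELETE to every broker. After this patch, a broker
  // that never saw sqlQueryId returns 202 instead of 404, so the fan-out
  // succeeds unless some broker returns 403 FORBIDDEN.
  static boolean cancelEverywhere(HttpClient client, List<URI> brokers, String sqlQueryId)
      throws Exception
  {
    boolean accepted = true;
    for (URI broker : brokers) {
      HttpRequest request = HttpRequest
          .newBuilder(broker.resolve("/druid/v2/sql/dart/" + sqlQueryId)) // assumed path
          .DELETE()
          .build();
      HttpResponse<Void> response = client.send(request, HttpResponse.BodyHandlers.discarding());
      accepted &= response.statusCode() == 202;
    }
    return accepted;
  }
}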
View File: DartQueryMaker.java

@ -154,6 +154,7 @@ public class DartQueryMaker implements QueryMaker
controllerContext,
plannerContext.getSqlQueryId(),
plannerContext.getSql(),
controllerContext.selfNode().getHostAndPortToUse(),
plannerContext.getAuthenticationResult(),
DateTimes.nowUtc()
);

View File: DartWorkerModule.java

@ -113,7 +113,7 @@ public class DartWorkerModule implements DruidModule
final AuthorizerMapper authorizerMapper
)
{
final ExecutorService exec = Execs.multiThreaded(memoryIntrospector.numTasksInJvm(), "dartworker-%s");
final ExecutorService exec = Execs.multiThreaded(memoryIntrospector.numTasksInJvm(), "dart-worker-%s");
final File baseTempDir =
new File(processingConfig.getTmpDir(), StringUtils.format("dart_%s", selfNode.getPortToUse()));
return new DartWorkerRunner(

View File: ReadablePartition.java

@ -59,6 +59,18 @@ public class ReadablePartition
return new ReadablePartition(stageNumber, workerNumbers, partitionNumber);
}
/**
* Returns an output partition that is striped across a set of {@code workerNumbers}.
*/
public static ReadablePartition striped(
final int stageNumber,
final IntSortedSet workerNumbers,
final int partitionNumber
)
{
return new ReadablePartition(stageNumber, workerNumbers, partitionNumber);
}
/**
* Returns an output partition that has been collected onto a single worker.
*/

View File: ReadablePartitions.java

@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo;
import it.unimi.dsi.fastutil.ints.Int2IntAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2IntSortedMap;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;
import java.util.Collections;
import java.util.List;
@ -39,6 +40,7 @@ import java.util.Map;
@JsonSubTypes(value = {
@JsonSubTypes.Type(name = "collected", value = CollectedReadablePartitions.class),
@JsonSubTypes.Type(name = "striped", value = StripedReadablePartitions.class),
@JsonSubTypes.Type(name = "sparseStriped", value = SparseStripedReadablePartitions.class),
@JsonSubTypes.Type(name = "combined", value = CombinedReadablePartitions.class)
})
public interface ReadablePartitions extends Iterable<ReadablePartition>
@ -59,7 +61,7 @@ public interface ReadablePartitions extends Iterable<ReadablePartition>
/**
* Combines various sets of partitions into a single set.
*/
static CombinedReadablePartitions combine(List<ReadablePartitions> readablePartitions)
static ReadablePartitions combine(List<ReadablePartitions> readablePartitions)
{
return new CombinedReadablePartitions(readablePartitions);
}
@ -68,7 +70,7 @@ public interface ReadablePartitions extends Iterable<ReadablePartition>
* Returns a set of {@code numPartitions} partitions striped across {@code numWorkers} workers: each worker contains
* a "stripe" of each partition.
*/
static StripedReadablePartitions striped(
static ReadablePartitions striped(
final int stageNumber,
final int numWorkers,
final int numPartitions
@ -82,11 +84,36 @@ public interface ReadablePartitions extends Iterable<ReadablePartition>
return new StripedReadablePartitions(stageNumber, numWorkers, partitionNumbers);
}
/**
* Returns a set of {@code numPartitions} partitions striped across {@code workers}: each worker contains
* a "stripe" of each partition.
*/
static ReadablePartitions striped(
final int stageNumber,
final IntSortedSet workers,
final int numPartitions
)
{
final IntAVLTreeSet partitionNumbers = new IntAVLTreeSet();
for (int i = 0; i < numPartitions; i++) {
partitionNumbers.add(i);
}
if (workers.lastInt() == workers.size() - 1) {
// Dense worker set. Use StripedReadablePartitions for compactness (send a single number rather than the
// entire worker set) and for backwards compatibility (older workers cannot understand
// SparseStripedReadablePartitions).
return new StripedReadablePartitions(stageNumber, workers.size(), partitionNumbers);
} else {
return new SparseStripedReadablePartitions(stageNumber, workers, partitionNumbers);
}
}
/**
* Returns a set of partitions that have been collected onto specific workers: each partition is on exactly
* one worker.
*/
static CollectedReadablePartitions collected(
static ReadablePartitions collected(
final int stageNumber,
final Map<Integer, Integer> partitionToWorkerMap
)

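As a rough sketch of the dispatch in the new striped overload above (the stage and partition numbers are arbitrary; the types are the ones used in this file): a gap-free worker set still yields the compact, backwards-compatible StripedReadablePartitions, while a gapped set yields the new sparse form.

import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;
import org.apache.druid.msq.input.stage.ReadablePartitions;

public class StripedSelectionSketch
{
  public static void main(String[] args)
  {
    // {0, 1, 2}: lastInt() == size() - 1, so the dense branch is taken and
    // ReadablePartitions.striped returns a StripedReadablePartitions.
    final IntSortedSet dense = new IntAVLTreeSet(new int[]{0, 1, 2});
    System.out.println(ReadablePartitions.striped(0, dense, 4).getClass().getSimpleName());

    // {1, 3}: workers 0 and 2 are missing, so lastInt() (3) != size() - 1 (1),
    // and a SparseStripedReadablePartitions is returned instead.
    final IntSortedSet sparse = new IntAVLTreeSet(new int[]{1, 3});
    System.out.println(ReadablePartitions.striped(0, sparse, 4).getClass().getSimpleName());
  }
}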
View File: SparseStripedReadablePartitions.java (new file)

@ -0,0 +1,142 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.input.stage;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.Iterators;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;
import org.apache.druid.msq.input.SlicerUtils;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
/**
* Set of partitions striped across a sparse set of {@code workers}. Each worker contains a "stripe" of each partition.
*
* @see StripedReadablePartitions dense version, where workers from [0..N) are all used.
*/
public class SparseStripedReadablePartitions implements ReadablePartitions
{
private final int stageNumber;
private final IntSortedSet workers;
private final IntSortedSet partitionNumbers;
/**
* Constructor. Most callers should use {@link ReadablePartitions#striped(int, IntSortedSet, int)} instead, which takes
* a partition count rather than a set of partition numbers.
*/
public SparseStripedReadablePartitions(
final int stageNumber,
final IntSortedSet workers,
final IntSortedSet partitionNumbers
)
{
this.stageNumber = stageNumber;
this.workers = workers;
this.partitionNumbers = partitionNumbers;
}
@JsonCreator
private SparseStripedReadablePartitions(
@JsonProperty("stageNumber") final int stageNumber,
@JsonProperty("workers") final Set<Integer> workers,
@JsonProperty("partitionNumbers") final Set<Integer> partitionNumbers
)
{
this(stageNumber, new IntAVLTreeSet(workers), new IntAVLTreeSet(partitionNumbers));
}
@Override
public Iterator<ReadablePartition> iterator()
{
return Iterators.transform(
partitionNumbers.iterator(),
partitionNumber -> ReadablePartition.striped(stageNumber, workers, partitionNumber)
);
}
@Override
public List<ReadablePartitions> split(final int maxNumSplits)
{
final List<ReadablePartitions> retVal = new ArrayList<>();
for (List<Integer> entries : SlicerUtils.makeSlicesStatic(partitionNumbers.iterator(), maxNumSplits)) {
if (!entries.isEmpty()) {
retVal.add(new SparseStripedReadablePartitions(stageNumber, workers, new IntAVLTreeSet(entries)));
}
}
return retVal;
}
@JsonProperty
int getStageNumber()
{
return stageNumber;
}
@JsonProperty
IntSortedSet getWorkers()
{
return workers;
}
@JsonProperty
IntSortedSet getPartitionNumbers()
{
return partitionNumbers;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SparseStripedReadablePartitions that = (SparseStripedReadablePartitions) o;
return stageNumber == that.stageNumber
&& Objects.equals(workers, that.workers)
&& Objects.equals(partitionNumbers, that.partitionNumbers);
}
@Override
public int hashCode()
{
return Objects.hash(stageNumber, workers, partitionNumbers);
}
@Override
public String toString()
{
return "StripedReadablePartitions{" +
"stageNumber=" + stageNumber +
", workers=" + workers +
", partitionNumbers=" + partitionNumbers +
'}';
}
}

View File: ControllerStageTracker.java

@ -403,7 +403,7 @@ class ControllerStageTracker
throw new ISE("Stage does not gather result key statistics");
}
if (workerNumber < 0 || workerNumber >= workerCount) {
if (!workerInputs.workers().contains(workerNumber)) {
throw new IAE("Invalid workerNumber [%s]", workerNumber);
}
@ -522,7 +522,7 @@ class ControllerStageTracker
throw new ISE("Stage does not gather result key statistics");
}
if (workerNumber < 0 || workerNumber >= workerCount) {
if (!workerInputs.workers().contains(workerNumber)) {
throw new IAE("Invalid workerNumber [%s]", workerNumber);
}
@ -656,7 +656,7 @@ class ControllerStageTracker
throw new ISE("Stage does not gather result key statistics");
}
if (workerNumber < 0 || workerNumber >= workerCount) {
if (!workerInputs.workers().contains(workerNumber)) {
throw new IAE("Invalid workerNumber [%s]", workerNumber);
}
@ -763,7 +763,7 @@ class ControllerStageTracker
this.resultPartitionBoundaries = clusterByPartitions;
this.resultPartitions = ReadablePartitions.striped(
stageDef.getStageNumber(),
workerCount,
workerInputs.workers(),
clusterByPartitions.size()
);
@ -788,7 +788,7 @@ class ControllerStageTracker
throw DruidException.defensive("Cannot setDoneReadingInput for stage[%s], it is not sorting", stageDef.getId());
}
if (workerNumber < 0 || workerNumber >= workerCount) {
if (!workerInputs.workers().contains(workerNumber)) {
throw new IAE("Invalid workerNumber[%s] for stage[%s]", workerNumber, stageDef.getId());
}
@ -830,7 +830,7 @@ class ControllerStageTracker
@SuppressWarnings("unchecked")
boolean setResultsCompleteForWorker(final int workerNumber, final Object resultObject)
{
if (workerNumber < 0 || workerNumber >= workerCount) {
if (!workerInputs.workers().contains(workerNumber)) {
throw new IAE("Invalid workerNumber [%s]", workerNumber);
}
@ -947,14 +947,18 @@ class ControllerStageTracker
resultPartitionBoundaries = maybeResultPartitionBoundaries.valueOrThrow();
resultPartitions = ReadablePartitions.striped(
stageNumber,
workerCount,
workerInputs.workers(),
resultPartitionBoundaries.size()
);
} else if (shuffleSpec.kind() == ShuffleKind.MIX) {
resultPartitionBoundaries = ClusterByPartitions.oneUniversalPartition();
resultPartitions = ReadablePartitions.striped(stageNumber, workerCount, shuffleSpec.partitionCount());
} else {
resultPartitions = ReadablePartitions.striped(stageNumber, workerCount, shuffleSpec.partitionCount());
if (shuffleSpec.kind() == ShuffleKind.MIX) {
resultPartitionBoundaries = ClusterByPartitions.oneUniversalPartition();
}
resultPartitions = ReadablePartitions.striped(
stageNumber,
workerInputs.workers(),
shuffleSpec.partitionCount()
);
}
} else {
// No reshuffling: retain partitioning from nonbroadcast inputs.

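A hedged sketch of why the membership check replaces the range check throughout this file (the values are hypothetical; workers() is the accessor this commit adds to WorkerInputs): once worker numbers may have gaps, `workerNumber < 0 || workerNumber >= workerCount` both admits missing workers and rejects valid high-numbered ones.

import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;

public class WorkerCheckSketch
{
  public static void main(String[] args)
  {
    final IntSortedSet workers = new IntAVLTreeSet(new int[]{1, 3, 4, 5});
    final int workerCount = workers.size(); // 4

    // Old range check: wrongly accepts the nonexistent worker 2 and wrongly
    // rejects the real worker 5 (since 5 >= workerCount).
    System.out.println(!(2 < 0 || 2 >= workerCount)); // true  (bad)
    System.out.println(!(5 < 0 || 5 >= workerCount)); // false (bad)

    // New membership check: exact.
    System.out.println(workers.contains(2)); // false (correct)
    System.out.println(workers.contains(5)); // true  (correct)
  }
}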
View File: WorkerInputs.java

@ -24,7 +24,9 @@ import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.ints.Int2ObjectSortedMap;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import org.apache.druid.msq.input.InputSlice;
import org.apache.druid.msq.input.InputSpec;
@ -45,9 +47,9 @@ import java.util.stream.IntStream;
public class WorkerInputs
{
// Worker number -> input number -> input slice.
private final Int2ObjectMap<List<InputSlice>> assignmentsMap;
private final Int2ObjectSortedMap<List<InputSlice>> assignmentsMap;
private WorkerInputs(final Int2ObjectMap<List<InputSlice>> assignmentsMap)
private WorkerInputs(final Int2ObjectSortedMap<List<InputSlice>> assignmentsMap)
{
this.assignmentsMap = assignmentsMap;
}
@ -64,7 +66,7 @@ public class WorkerInputs
)
{
// Split each inputSpec and assign to workers. This list maps worker number -> input number -> input slice.
final Int2ObjectMap<List<InputSlice>> assignmentsMap = new Int2ObjectAVLTreeMap<>();
final Int2ObjectSortedMap<List<InputSlice>> assignmentsMap = new Int2ObjectAVLTreeMap<>();
final int numInputs = stageDef.getInputSpecs().size();
if (numInputs == 0) {
@ -117,8 +119,8 @@ public class WorkerInputs
final ObjectIterator<Int2ObjectMap.Entry<List<InputSlice>>> assignmentsIterator =
assignmentsMap.int2ObjectEntrySet().iterator();
final IntSortedSet nilWorkers = new IntAVLTreeSet();
boolean first = true;
while (assignmentsIterator.hasNext()) {
final Int2ObjectMap.Entry<List<InputSlice>> entry = assignmentsIterator.next();
final List<InputSlice> slices = entry.getValue();
@ -130,20 +132,29 @@ public class WorkerInputs
}
}
// Eliminate workers that have no non-nil, non-broadcast inputs. (Except the first one, because if all input
// is nil, *some* worker has to do *something*.)
final boolean hasNonNilNonBroadcastInput =
// Identify nil workers (workers with no non-broadcast inputs).
final boolean isNilWorker =
IntStream.range(0, numInputs)
.anyMatch(i ->
!slices.get(i).equals(NilInputSlice.INSTANCE) // Non-nil
&& !stageDef.getBroadcastInputNumbers().contains(i) // Non-broadcast
.allMatch(i ->
slices.get(i).equals(NilInputSlice.INSTANCE) // Nil regular input
|| stageDef.getBroadcastInputNumbers().contains(i) // Broadcast
);
if (!first && !hasNonNilNonBroadcastInput) {
assignmentsIterator.remove();
if (isNilWorker) {
nilWorkers.add(entry.getIntKey());
}
}
first = false;
if (nilWorkers.size() == assignmentsMap.size()) {
// All workers have nil regular inputs. Remove all workers except the first (*some* worker has to do *something*).
final List<InputSlice> firstSlices = assignmentsMap.get(nilWorkers.firstInt());
assignmentsMap.clear();
assignmentsMap.put(nilWorkers.firstInt(), firstSlices);
} else {
// Remove all nil workers.
for (final int nilWorker : nilWorkers) {
assignmentsMap.remove(nilWorker);
}
}
return new WorkerInputs(assignmentsMap);
@ -154,7 +165,7 @@ public class WorkerInputs
return Preconditions.checkNotNull(assignmentsMap.get(workerNumber), "worker [%s]", workerNumber);
}
public IntSet workers()
public IntSortedSet workers()
{
return assignmentsMap.keySet();
}

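A standalone, simplified sketch of the nil-worker elimination introduced above (hypothetical: strings stand in for slice lists, and an empty string stands in for "all inputs nil or broadcast"): nil workers are collected first, then either all of them are dropped, or, if every worker is nil, all but the first.

import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectSortedMap;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;

public class NilWorkerSketch
{
  public static void main(String[] args)
  {
    // Worker -> slices; worker 5 has no real input ("" stands in for all-nil).
    final Int2ObjectSortedMap<String> assignments = new Int2ObjectAVLTreeMap<>();
    assignments.put(1, "slice-a");
    assignments.put(3, "slice-b");
    assignments.put(5, "");

    final IntSortedSet nilWorkers = new IntAVLTreeSet();
    assignments.forEach((worker, slices) -> {
      if (slices.isEmpty()) {
        nilWorkers.add(worker);
      }
    });

    if (nilWorkers.size() == assignments.size()) {
      // All workers are nil: keep only the first, since some worker must do something.
      final String firstSlices = assignments.get(nilWorkers.firstInt());
      assignments.clear();
      assignments.put(nilWorkers.firstInt(), firstSlices);
    } else {
      // Otherwise drop every nil worker.
      for (final int nilWorker : nilWorkers) {
        assignments.remove(nilWorker);
      }
    }
    System.out.println(assignments); // {1=>slice-a, 3=>slice-b}
  }
}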
View File: DartSqlResourceTest.java

@ -331,9 +331,10 @@ public class DartSqlResourceTest extends MSQTestBase
"sid",
"did2",
"SELECT 2",
"localhost:1002",
AUTHENTICATOR_NAME,
DIFFERENT_REGULAR_USER_NAME,
DateTimes.of("2000"),
DateTimes.of("2001"),
ControllerHolder.State.RUNNING.toString()
);
Mockito.when(dartSqlClient.getRunningQueries(true))
@ -398,6 +399,7 @@ public class DartSqlResourceTest extends MSQTestBase
"sid",
"did2",
"SELECT 2",
"localhost:1002",
AUTHENTICATOR_NAME,
DIFFERENT_REGULAR_USER_NAME,
DateTimes.of("2000"),
@ -434,6 +436,7 @@ public class DartSqlResourceTest extends MSQTestBase
"sid",
"did2",
"SELECT 2",
"localhost:1002",
AUTHENTICATOR_NAME,
DIFFERENT_REGULAR_USER_NAME,
DateTimes.of("2000"),
@ -724,7 +727,7 @@ public class DartSqlResourceTest extends MSQTestBase
.thenReturn(makeAuthenticationResult(REGULAR_USER_NAME));
final Response cancellationResponse = sqlResource.cancelQuery("nonexistent", httpServletRequest);
Assertions.assertEquals(Response.Status.NOT_FOUND.getStatusCode(), cancellationResponse.getStatus());
Assertions.assertEquals(Response.Status.ACCEPTED.getStatusCode(), cancellationResponse.getStatus());
}
/**
@ -739,8 +742,15 @@ public class DartSqlResourceTest extends MSQTestBase
Mockito.when(controller.queryId()).thenReturn("did_" + identity);
final AuthenticationResult authenticationResult = makeAuthenticationResult(identity);
final ControllerHolder holder =
new ControllerHolder(controller, null, "sid", "SELECT 1", authenticationResult, DateTimes.of("2000"));
final ControllerHolder holder = new ControllerHolder(
controller,
null,
"sid",
"SELECT 1",
"localhost:1001",
authenticationResult,
DateTimes.of("2000")
);
controllerRegistry.register(holder);
return holder;

View File: GetQueriesResponseTest.java

@ -41,6 +41,7 @@ public class GetQueriesResponseTest
"xyz",
"abc",
"SELECT 1",
"localhost:1001",
"auth",
"anon",
DateTimes.of("2000"),

View File: DartSqlClientImplTest.java

@ -69,6 +69,7 @@ public class DartSqlClientImplTest
"sid",
"did",
"SELECT 1",
"localhost:1001",
"",
"",
DateTimes.of("2000"),
@ -97,6 +98,7 @@ public class DartSqlClientImplTest
"sid",
"did",
"SELECT 1",
"localhost:1001",
"",
"",
DateTimes.of("2000"),

View File: CollectedReadablePartitionsTest.java

@ -33,21 +33,24 @@ public class CollectedReadablePartitionsTest
@Test
public void testPartitionToWorkerMap()
{
final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
final CollectedReadablePartitions partitions =
(CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
Assert.assertEquals(ImmutableMap.of(0, 1, 1, 2, 2, 1), partitions.getPartitionToWorkerMap());
}
@Test
public void testStageNumber()
{
final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
final CollectedReadablePartitions partitions =
(CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
Assert.assertEquals(1, partitions.getStageNumber());
}
@Test
public void testSplit()
{
final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
final CollectedReadablePartitions partitions =
(CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
Assert.assertEquals(
ImmutableList.of(
@ -64,7 +67,8 @@ public class CollectedReadablePartitionsTest
final ObjectMapper mapper = TestHelper.makeJsonMapper()
.registerModules(new MSQIndexingModule().getJacksonModules());
final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
final CollectedReadablePartitions partitions =
(CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
Assert.assertEquals(
partitions,

View File: CombinedReadablePartitionsTest.java

@ -31,7 +31,7 @@ import org.junit.Test;
public class CombinedReadablePartitionsTest
{
private static final CombinedReadablePartitions PARTITIONS = ReadablePartitions.combine(
private static final ReadablePartitions PARTITIONS = ReadablePartitions.combine(
ImmutableList.of(
ReadablePartitions.striped(0, 2, 2),
ReadablePartitions.striped(1, 2, 4)

View File: SparseStripedReadablePartitionsTest.java (new file)

@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.input.stage;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.msq.guice.MSQIndexingModule;
import org.apache.druid.segment.TestHelper;
import org.junit.Assert;
import org.junit.Test;
public class SparseStripedReadablePartitionsTest
{
@Test
public void testPartitionNumbers()
{
final SparseStripedReadablePartitions partitions =
(SparseStripedReadablePartitions) ReadablePartitions.striped(1, new IntAVLTreeSet(new int[]{1, 3}), 3);
Assert.assertEquals(ImmutableSet.of(0, 1, 2), partitions.getPartitionNumbers());
}
@Test
public void testWorkers()
{
final SparseStripedReadablePartitions partitions =
(SparseStripedReadablePartitions) ReadablePartitions.striped(1, new IntAVLTreeSet(new int[]{1, 3}), 3);
Assert.assertEquals(IntSet.of(1, 3), partitions.getWorkers());
}
@Test
public void testStageNumber()
{
final SparseStripedReadablePartitions partitions =
(SparseStripedReadablePartitions) ReadablePartitions.striped(1, new IntAVLTreeSet(new int[]{1, 3}), 3);
Assert.assertEquals(1, partitions.getStageNumber());
}
@Test
public void testSplit()
{
final IntAVLTreeSet workers = new IntAVLTreeSet(new int[]{1, 3});
final SparseStripedReadablePartitions partitions =
(SparseStripedReadablePartitions) ReadablePartitions.striped(1, workers, 3);
Assert.assertEquals(
ImmutableList.of(
new SparseStripedReadablePartitions(1, workers, new IntAVLTreeSet(new int[]{0, 2})),
new SparseStripedReadablePartitions(1, workers, new IntAVLTreeSet(new int[]{1}))
),
partitions.split(2)
);
}
@Test
public void testSerde() throws Exception
{
final ObjectMapper mapper = TestHelper.makeJsonMapper()
.registerModules(new MSQIndexingModule().getJacksonModules());
final IntAVLTreeSet workers = new IntAVLTreeSet(new int[]{1, 3});
final ReadablePartitions partitions = ReadablePartitions.striped(1, workers, 3);
Assert.assertEquals(
partitions,
mapper.readValue(
mapper.writeValueAsString(partitions),
ReadablePartitions.class
)
);
}
@Test
public void testEquals()
{
EqualsVerifier.forClass(SparseStripedReadablePartitions.class).usingGetClass().verify();
}
}

View File: StripedReadablePartitionsTest.java

@ -26,36 +26,60 @@ import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.msq.guice.MSQIndexingModule;
import org.apache.druid.segment.TestHelper;
import org.hamcrest.CoreMatchers;
import org.hamcrest.MatcherAssert;
import org.junit.Assert;
import org.junit.Test;
public class StripedReadablePartitionsTest
{
@Test
public void testFromDenseSet()
{
// Tests that when ReadablePartitions.striped is called with a dense set, we get StripedReadablePartitions.
final IntAVLTreeSet workers = new IntAVLTreeSet();
workers.add(0);
workers.add(1);
final ReadablePartitions readablePartitionsFromSet = ReadablePartitions.striped(1, workers, 3);
MatcherAssert.assertThat(
readablePartitionsFromSet,
CoreMatchers.instanceOf(StripedReadablePartitions.class)
);
Assert.assertEquals(
ReadablePartitions.striped(1, 2, 3),
readablePartitionsFromSet
);
}
@Test
public void testPartitionNumbers()
{
final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
final StripedReadablePartitions partitions = (StripedReadablePartitions) ReadablePartitions.striped(1, 2, 3);
Assert.assertEquals(ImmutableSet.of(0, 1, 2), partitions.getPartitionNumbers());
}
@Test
public void testNumWorkers()
{
final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
final StripedReadablePartitions partitions = (StripedReadablePartitions) ReadablePartitions.striped(1, 2, 3);
Assert.assertEquals(2, partitions.getNumWorkers());
}
@Test
public void testStageNumber()
{
final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
final StripedReadablePartitions partitions = (StripedReadablePartitions) ReadablePartitions.striped(1, 2, 3);
Assert.assertEquals(1, partitions.getStageNumber());
}
@Test
public void testSplit()
{
final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
final ReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
Assert.assertEquals(
ImmutableList.of(
@ -72,7 +96,7 @@ public class StripedReadablePartitionsTest
final ObjectMapper mapper = TestHelper.makeJsonMapper()
.registerModules(new MSQIndexingModule().getJacksonModules());
final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
final ReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
Assert.assertEquals(
partitions,

View File: WorkerInputsTest.java

@ -25,9 +25,11 @@ import it.unimi.dsi.fastutil.ints.Int2IntMaps;
import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongList;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.error.DruidException;
import org.apache.druid.msq.exec.Limits;
import org.apache.druid.msq.exec.OutputChannelMode;
import org.apache.druid.msq.input.InputSlice;
@ -75,7 +77,7 @@ public class WorkerInputsTest
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(4), true),
WorkerAssignmentStrategy.MAX,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@ -91,6 +93,35 @@ public class WorkerInputsTest
);
}
@Test
public void test_max_threeInputs_fourWorkers_withGaps()
{
final StageDefinition stageDef =
StageDefinition.builder(0)
.inputs(new TestInputSpec(1, 2, 3))
.maxWorkerCount(4)
.processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L))
.build(QUERY_ID);
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(new IntAVLTreeSet(new int[]{1, 3, 4, 5}), true),
WorkerAssignmentStrategy.MAX,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
Assert.assertEquals(
ImmutableMap.<Integer, List<InputSlice>>builder()
.put(1, Collections.singletonList(new TestInputSlice(1)))
.put(3, Collections.singletonList(new TestInputSlice(2)))
.put(4, Collections.singletonList(new TestInputSlice(3)))
.put(5, Collections.singletonList(new TestInputSlice()))
.build(),
inputs.assignmentsMap()
);
}
@Test
public void test_max_zeroInputs_fourWorkers()
{
@ -104,7 +135,7 @@ public class WorkerInputsTest
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(4), true),
WorkerAssignmentStrategy.MAX,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@ -133,7 +164,7 @@ public class WorkerInputsTest
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(4), true),
WorkerAssignmentStrategy.AUTO,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@ -159,7 +190,7 @@ public class WorkerInputsTest
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(4), true),
WorkerAssignmentStrategy.AUTO,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@ -186,7 +217,7 @@ public class WorkerInputsTest
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(4), true),
WorkerAssignmentStrategy.AUTO,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@ -212,7 +243,7 @@ public class WorkerInputsTest
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(4), true),
WorkerAssignmentStrategy.AUTO,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@ -324,7 +355,7 @@ public class WorkerInputsTest
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(4), true),
WorkerAssignmentStrategy.AUTO,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@ -351,7 +382,7 @@ public class WorkerInputsTest
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(2), true),
WorkerAssignmentStrategy.AUTO,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@ -384,7 +415,7 @@ public class WorkerInputsTest
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
Int2IntMaps.EMPTY_MAP,
new TestInputSpecSlicer(true),
new TestInputSpecSlicer(denseWorkers(1), true),
WorkerAssignmentStrategy.AUTO,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
);
@ -411,7 +442,7 @@ public class WorkerInputsTest
.processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L))
.build(QUERY_ID);
TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(true));
TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(denseWorkers(3), true));
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
@ -455,7 +486,7 @@ public class WorkerInputsTest
.processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L))
.build(QUERY_ID);
TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(true));
TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(denseWorkers(3), true));
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
@ -498,7 +529,7 @@ public class WorkerInputsTest
.processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L))
.build(QUERY_ID);
TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(true));
TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(denseWorkers(3), true));
final WorkerInputs inputs = WorkerInputs.create(
stageDef,
@ -585,11 +616,23 @@ public class WorkerInputsTest
private static class TestInputSpecSlicer implements InputSpecSlicer
{
private final IntSortedSet workers;
private final boolean canSliceDynamic;
public TestInputSpecSlicer(boolean canSliceDynamic)
/**
* Create a test slicer.
*
* @param workers Set of workers to consider assigning work to.
* @param canSliceDynamic Whether this slicer can slice dynamically.
*/
public TestInputSpecSlicer(final IntSortedSet workers, final boolean canSliceDynamic)
{
this.workers = workers;
this.canSliceDynamic = canSliceDynamic;
if (workers.isEmpty()) {
throw DruidException.defensive("Need more than one worker in workers[%s]", workers);
}
}
@Override
@ -606,9 +649,9 @@ public class WorkerInputsTest
SlicerUtils.makeSlicesStatic(
testInputSpec.values.iterator(),
i -> i,
maxNumSlices
Math.min(maxNumSlices, workers.size())
);
return makeSlices(assignments);
return makeSlices(workers, assignments);
}
@Override
@ -624,24 +667,39 @@ public class WorkerInputsTest
SlicerUtils.makeSlicesDynamic(
testInputSpec.values.iterator(),
i -> i,
maxNumSlices,
Math.min(maxNumSlices, workers.size()),
maxFilesPerSlice,
maxBytesPerSlice
);
return makeSlices(assignments);
return makeSlices(workers, assignments);
}
private static List<InputSlice> makeSlices(
final IntSortedSet workers,
final List<List<Long>> assignments
)
{
final List<InputSlice> retVal = new ArrayList<>(assignments.size());
for (final List<Long> assignment : assignments) {
retVal.add(new TestInputSlice(new LongArrayList(assignment)));
for (int assignment = 0, workerNumber = 0;
workerNumber <= workers.lastInt() && assignment < assignments.size();
workerNumber++) {
if (workers.contains(workerNumber)) {
retVal.add(new TestInputSlice(new LongArrayList(assignments.get(assignment++))));
} else {
retVal.add(NilInputSlice.INSTANCE);
}
}
return retVal;
}
}
private static IntSortedSet denseWorkers(final int numWorkers)
{
final IntAVLTreeSet workers = new IntAVLTreeSet();
for (int i = 0; i < numWorkers; i++) {
workers.add(i);
}
return workers;
}
}

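To make the gap handling in the rewritten makeSlices loop concrete, a minimal standalone rendering (hypothetical: strings stand in for input slices, and "nil" stands in for NilInputSlice.INSTANCE): workers absent from the set receive nil slices, while present workers consume assignments in order.

import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;
import java.util.ArrayList;
import java.util.List;

public class MakeSlicesSketch
{
  public static List<String> makeSlices(IntSortedSet workers, List<String> assignments)
  {
    final List<String> retVal = new ArrayList<>();
    for (int assignment = 0, workerNumber = 0;
         workerNumber <= workers.lastInt() && assignment < assignments.size();
         workerNumber++) {
      if (workers.contains(workerNumber)) {
        retVal.add(assignments.get(assignment++));
      } else {
        retVal.add("nil"); // stand-in for NilInputSlice.INSTANCE
      }
    }
    return retVal;
  }

  public static void main(String[] args)
  {
    final IntSortedSet workers = new IntAVLTreeSet(new int[]{1, 3});
    // Workers 0 and 2 are gaps: prints [nil, a, nil, b]
    System.out.println(makeSlices(workers, List.of("a", "b")));
  }
}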
View File: licenses.yaml

@ -3327,7 +3327,7 @@ name: Protocol Buffers
license_category: binary
module: java-core
license_name: BSD-3-Clause License
version: 3.24.0
version: 3.25.5
copyright: Google, Inc.
license_file_path:
- licenses/bin/protobuf-java.BSD3
@ -3493,7 +3493,7 @@ name: Protocol Buffers
license_category: binary
module: extensions/druid-protobuf-extensions
license_name: BSD-3-Clause License
version: 3.24.0
version: 3.25.5
copyright: Google, Inc.
license_file_path: licenses/bin/protobuf-java.BSD3
libraries:

View File: pom.xml

@ -108,7 +108,7 @@
<netty3.version>3.10.6.Final</netty3.version>
<netty4.version>4.1.108.Final</netty4.version>
<postgresql.version>42.7.2</postgresql.version>
<protobuf.version>3.24.0</protobuf.version>
<protobuf.version>3.25.5</protobuf.version>
<resilience4j.version>1.3.1</resilience4j.version>
<slf4j.version>1.7.36</slf4j.version>
<jna.version>5.13.0</jna.version>