Redesign QueryContext class (#13071)

We introduce two new configuration keys that refine the query context security model controlled by druid.auth.authorizeQueryContextParams. When that value is set to true then two other configuration options become available:

druid.auth.unsecuredContextKeys: The set of query context keys that do not require a security check. Use this for the "white-list" of key to allow. All other keys go through the existing context key security checks.
druid.auth.securedContextKeys: The set of query context keys that do require a security check. Use this when you want to allow all but a specific set of keys: only these keys go through the existing context key security checks.
Both are set using JSON list format:

druid.auth.securedContextKeys=["secretKey1", "secretKey2"]
You generally set one or the other values. If both are set, unsecuredContextKeys acts as exceptions to securedContextKeys.

In addition, Druid defines two query context keys which always bypass checks because Druid uses them internally:

sqlQueryId
sqlStringifyArrays
This commit is contained in:
Paul Rogers 2022-10-14 22:32:11 -07:00 committed by GitHub
parent 6332c571bd
commit f4dcc52dac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
147 changed files with 2312 additions and 1710 deletions

View File

@ -29,7 +29,6 @@ import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.query.aggregation.datasketches.hll.sql.HllSketchApproxCountDistinctSqlAggregator;
@ -516,7 +515,7 @@ public class SqlBenchmark
QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize
);
final String sql = QUERIES.get(Integer.parseInt(query));
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, new QueryContext(context))) {
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, context)) {
final PlannerResult plannerResult = planner.plan();
final Sequence<Object[]> resultSequence = plannerResult.run().getResults();
final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in);
@ -534,7 +533,7 @@ public class SqlBenchmark
QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize
);
final String sql = QUERIES.get(Integer.parseInt(query));
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, new QueryContext(context))) {
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, context)) {
final PlannerResult plannerResult = planner.plan();
blackhole.consume(plannerResult);
}

View File

@ -29,7 +29,6 @@ import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.math.expr.ExpressionProcessing;
import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.segment.QueryableIndex;
@ -352,7 +351,7 @@ public class SqlExpressionBenchmark
QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize
);
final String sql = QUERIES.get(Integer.parseInt(query));
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, new QueryContext(context))) {
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, context)) {
final PlannerResult plannerResult = planner.plan();
final Sequence<Object[]> resultSequence = plannerResult.run().getResults();
final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in);

View File

@ -31,7 +31,6 @@ import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.math.expr.ExpressionProcessing;
import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.query.expression.TestExprMacroTable;
@ -318,7 +317,7 @@ public class SqlNestedDataBenchmark
QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize
);
final String sql = QUERIES.get(Integer.parseInt(query));
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, new QueryContext(context))) {
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, context)) {
final PlannerResult plannerResult = planner.plan();
final Sequence<Object[]> resultSequence = plannerResult.run().getResults();
final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in);

View File

@ -26,7 +26,6 @@ import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
@ -66,6 +65,7 @@ import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import java.util.Collections;
import java.util.concurrent.TimeUnit;
/**
@ -167,7 +167,7 @@ public class SqlVsNativeBenchmark
@OutputTimeUnit(TimeUnit.MILLISECONDS)
public void queryPlanner(Blackhole blackhole) throws Exception
{
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sqlQuery, new QueryContext())) {
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sqlQuery, Collections.emptyMap())) {
final PlannerResult plannerResult = planner.plan();
final Sequence<Object[]> resultSequence = plannerResult.run().getResults();
final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in);

View File

@ -29,9 +29,11 @@ import java.util.AbstractCollection;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.Spliterator;
import java.util.TreeSet;
import java.util.function.Function;
@ -148,6 +150,40 @@ public final class CollectionUtils
return list == null || list.isEmpty();
}
/**
* Subtract one set from another: {@code C = A - B}.
*/
public static <T> Set<T> subtract(Set<T> left, Set<T> right)
{
Set<T> result = new HashSet<>(left);
result.removeAll(right);
return result;
}
/**
* Intersection of two sets: {@code C = A B}.
*/
public static <T> Set<T> intersect(Set<T> left, Set<T> right)
{
Set<T> result = new HashSet<>();
for (T key : left) {
if (right.contains(key)) {
result.add(key);
}
}
return result;
}
/**
* Intersection of two sets: {@code C = A B}.
*/
public static <T> Set<T> union(Set<T> left, Set<T> right)
{
Set<T> result = new HashSet<>(left);
result.addAll(right);
return result;
}
private CollectionUtils()
{
}

View File

@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.utils;
import com.google.common.collect.ImmutableSet;
import org.junit.Test;
import java.util.Set;
import static org.junit.Assert.assertEquals;
public class CollectionUtilsTest
{
// When Java 9 is allowed, use Set.of().
Set<String> empty = ImmutableSet.of();
Set<String> abc = ImmutableSet.of("a", "b", "c");
Set<String> bcd = ImmutableSet.of("b", "c", "d");
Set<String> efg = ImmutableSet.of("e", "f", "g");
@Test
public void testSubtract()
{
assertEquals(empty, CollectionUtils.subtract(empty, empty));
assertEquals(abc, CollectionUtils.subtract(abc, empty));
assertEquals(empty, CollectionUtils.subtract(abc, abc));
assertEquals(abc, CollectionUtils.subtract(abc, efg));
assertEquals(ImmutableSet.of("a"), CollectionUtils.subtract(abc, bcd));
}
@Test
public void testIntersect()
{
assertEquals(empty, CollectionUtils.intersect(empty, empty));
assertEquals(abc, CollectionUtils.intersect(abc, abc));
assertEquals(empty, CollectionUtils.intersect(abc, efg));
assertEquals(ImmutableSet.of("b", "c"), CollectionUtils.intersect(abc, bcd));
}
@Test
public void testUnion()
{
assertEquals(empty, CollectionUtils.union(empty, empty));
assertEquals(abc, CollectionUtils.union(abc, abc));
assertEquals(ImmutableSet.of("a", "b", "c", "e", "f", "g"), CollectionUtils.union(abc, efg));
assertEquals(ImmutableSet.of("a", "b", "c", "d"), CollectionUtils.union(abc, bcd));
}
}

View File

@ -28,7 +28,6 @@ import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.BaseQuery;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QuerySegmentWalker;
import org.apache.druid.query.filter.DimFilter;
@ -41,6 +40,7 @@ import org.joda.time.Duration;
import org.joda.time.Interval;
import javax.annotation.Nullable;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -146,12 +146,6 @@ public class MaterializedViewQuery<T> implements Query<T>
return query.getContext();
}
@Override
public QueryContext getQueryContext()
{
return query.getQueryContext();
}
@Override
public boolean isDescending()
{

View File

@ -121,7 +121,6 @@ public class MaterializedViewQueryTest
.postAggregators(QueryRunnerTestHelper.ADD_ROWS_INDEX_CONSTANT)
.build();
MaterializedViewQuery query = new MaterializedViewQuery(topNQuery, optimizer);
Assert.assertEquals(20_000_000, query.getContextAsHumanReadableBytes("maxOnDiskStorage", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(20_000_000, query.context().getHumanReadableBytes("maxOnDiskStorage", HumanReadableBytes.ZERO).getBytes());
}
}

View File

@ -237,7 +237,7 @@ public class MovingAverageQuery extends BaseQuery<Row>
@JsonIgnore
public boolean getContextSortByDimsFirst()
{
return getContextBoolean(CTX_KEY_SORT_BY_DIMS_FIRST, false);
return context().getBoolean(CTX_KEY_SORT_BY_DIMS_FIRST, false);
}
@Override

View File

@ -30,7 +30,6 @@ import org.apache.druid.java.util.common.granularity.PeriodGranularity;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
@ -52,6 +51,7 @@ import org.joda.time.Interval;
import org.joda.time.Period;
import javax.annotation.Nullable;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -124,7 +124,7 @@ public class MovingAverageQueryRunner implements QueryRunner<Row>
ResponseContext gbqResponseContext = ResponseContext.createEmpty();
gbqResponseContext.merge(responseContext);
gbqResponseContext.putQueryFailDeadlineMs(
System.currentTimeMillis() + QueryContexts.getTimeout(gbq)
System.currentTimeMillis() + gbq.context().getTimeout()
);
Sequence<ResultRow> results = gbq.getRunner(walker).run(QueryPlus.wrap(gbq), gbqResponseContext);
@ -164,7 +164,7 @@ public class MovingAverageQueryRunner implements QueryRunner<Row>
ResponseContext tsqResponseContext = ResponseContext.createEmpty();
tsqResponseContext.merge(responseContext);
tsqResponseContext.putQueryFailDeadlineMs(
System.currentTimeMillis() + QueryContexts.getTimeout(tsq)
System.currentTimeMillis() + tsq.context().getTimeout()
);
Sequence<Result<TimeseriesResultValue>> results = tsq.getRunner(walker).run(QueryPlus.wrap(tsq), tsqResponseContext);

View File

@ -49,6 +49,7 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.List;
public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
@ -171,7 +172,7 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
histogramName,
input.getDirectColumn(),
k,
getMaxStreamLengthFromQueryContext(plannerContext.getQueryContext())
getMaxStreamLengthFromQueryContext(plannerContext.queryContext())
);
} else {
String virtualColumnName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(
@ -182,7 +183,7 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
histogramName,
virtualColumnName,
k,
getMaxStreamLengthFromQueryContext(plannerContext.getQueryContext())
getMaxStreamLengthFromQueryContext(plannerContext.queryContext())
);
}
@ -201,7 +202,7 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
static long getMaxStreamLengthFromQueryContext(QueryContext queryContext)
{
return queryContext.getAsLong(
return queryContext.getLong(
CTX_APPROX_QUANTILE_DS_MAX_STREAM_LENGTH,
DoublesSketchAggregatorFactory.DEFAULT_MAX_STREAM_LENGTH
);

View File

@ -46,6 +46,7 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.List;
public class DoublesSketchObjectSqlAggregator implements SqlAggregator
@ -113,7 +114,7 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
histogramName,
input.getDirectColumn(),
k,
DoublesSketchApproxQuantileSqlAggregator.getMaxStreamLengthFromQueryContext(plannerContext.getQueryContext())
DoublesSketchApproxQuantileSqlAggregator.getMaxStreamLengthFromQueryContext(plannerContext.queryContext())
);
} else {
String virtualColumnName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(
@ -124,7 +125,7 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
histogramName,
virtualColumnName,
k,
DoublesSketchApproxQuantileSqlAggregator.getMaxStreamLengthFromQueryContext(plannerContext.getQueryContext())
DoublesSketchApproxQuantileSqlAggregator.getMaxStreamLengthFromQueryContext(plannerContext.queryContext())
);
}
@ -136,7 +137,6 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
private static class DoublesSketchSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE1 = "'" + NAME + "(column)'\n";
private static final String SIGNATURE2 = "'" + NAME + "(column, k)'\n";
DoublesSketchSqlAggFunction()

View File

@ -27,6 +27,7 @@ import com.google.common.collect.Iterables;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.query.Druids;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
@ -53,7 +54,6 @@ import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFacto
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.planner.DruidOperatorTable;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
import org.apache.druid.timeline.DataSegment;
@ -324,7 +324,7 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ
new QuantilePostAggregator("a6", "a6:agg", 0.999f),
new QuantilePostAggregator("a7", "a5:agg", 0.999f)
)
.context(ImmutableMap.of(PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.context(ImmutableMap.of(QueryContexts.CTX_SQL_QUERY_ID, "dummy"))
.build()
),
ImmutableList.of(

View File

@ -518,7 +518,7 @@ public class ControllerImpl implements Controller
closer.register(netClient::close);
final boolean isDurableStorageEnabled =
MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().getContext());
MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().context());
final QueryDefinition queryDef = makeQueryDefinition(
id(),
@ -1191,7 +1191,7 @@ public class ControllerImpl implements Controller
final InputChannelFactory inputChannelFactory;
if (MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().getContext())) {
if (MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().context())) {
inputChannelFactory = DurableStorageInputChannelFactory.createStandardImplementation(
id(),
() -> taskIds,
@ -1294,7 +1294,7 @@ public class ControllerImpl implements Controller
*/
private void cleanUpDurableStorageIfNeeded()
{
if (MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().getContext())) {
if (MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().context())) {
final String controllerDirName = DurableStorageOutputChannelFactory.getControllerDirectory(task.getId());
try {
// Delete all temporary files as a failsafe
@ -1454,7 +1454,7 @@ public class ControllerImpl implements Controller
)
{
if (isRollupQuery) {
final String queryGranularity = query.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_GRANULARITY, "");
final String queryGranularity = query.context().getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_GRANULARITY, "");
if (timeIsGroupByDimension((GroupByQuery) query, columnMappings) && !queryGranularity.isEmpty()) {
return new ArbitraryGranularitySpec(
@ -1483,7 +1483,7 @@ public class ControllerImpl implements Controller
{
if (columnMappings.hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) {
final String queryTimeColumn = columnMappings.getQueryColumnForOutputColumn(ColumnHolder.TIME_COLUMN_NAME);
return queryTimeColumn.equals(groupByQuery.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD));
return queryTimeColumn.equals(groupByQuery.context().getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD));
} else {
return false;
}
@ -1505,8 +1505,8 @@ public class ControllerImpl implements Controller
private static boolean isRollupQuery(Query<?> query)
{
return query instanceof GroupByQuery
&& !MultiStageQueryContext.isFinalizeAggregations(query.getQueryContext())
&& !query.getContextBoolean(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true);
&& !MultiStageQueryContext.isFinalizeAggregations(query.context())
&& !query.context().getBoolean(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true);
}
private static boolean isInlineResults(final MSQSpec querySpec)

View File

@ -106,6 +106,7 @@ import org.apache.druid.msq.util.DecoratedExecutorService;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.PrioritizedCallable;
import org.apache.druid.query.PrioritizedRunnable;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryProcessingPool;
import org.apache.druid.server.DruidNode;
@ -177,7 +178,9 @@ public class WorkerImpl implements Worker
this.context = context;
this.selfDruidNode = context.selfNode();
this.processorBouncer = context.processorBouncer();
this.durableStageStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled(task.getContext());
this.durableStageStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled(
QueryContext.of(task.getContext())
);
}
@Override

View File

@ -106,7 +106,7 @@ public class MSQControllerTask extends AbstractTask
this.sqlQueryContext = sqlQueryContext;
this.sqlTypeNames = sqlTypeNames;
if (MultiStageQueryContext.isDurableStorageEnabled(querySpec.getQuery().getContext())) {
if (MultiStageQueryContext.isDurableStorageEnabled(querySpec.getQuery().context())) {
this.remoteFetchExecutorService =
Executors.newCachedThreadPool(Execs.makeThreadFactory(getId() + "-remote-fetcher-%d"));
} else {

View File

@ -191,7 +191,7 @@ public class QueryKitUtils
public static VirtualColumn makeSegmentGranularityVirtualColumn(final Query<?> query)
{
final Granularity segmentGranularity = QueryKitUtils.getSegmentGranularityFromContext(query.getContext());
final String timeColumnName = query.getQueryContext().getAsString(QueryKitUtils.CTX_TIME_COLUMN_NAME);
final String timeColumnName = query.context().getString(QueryKitUtils.CTX_TIME_COLUMN_NAME);
if (timeColumnName == null || Granularities.ALL.equals(segmentGranularity)) {
return null;

View File

@ -37,7 +37,6 @@ import org.apache.druid.msq.querykit.ShuffleSpecFactories;
import org.apache.druid.msq.querykit.ShuffleSpecFactory;
import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.having.AlwaysHavingSpec;
@ -205,7 +204,7 @@ public class GroupByQueryKit implements QueryKit<GroupByQuery>
*/
static boolean isFinalize(final GroupByQuery query)
{
return QueryContexts.isFinalize(query, true);
return query.context().isFinalize(true);
}
/**

View File

@ -57,7 +57,7 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
{
RowSignature scanSignature;
try {
final String s = scanQuery.getQueryContext().getAsString(DruidQuery.CTX_SCAN_SIGNATURE);
final String s = scanQuery.context().getString(DruidQuery.CTX_SCAN_SIGNATURE);
scanSignature = jsonMapper.readValue(s, RowSignature.class);
}
catch (JsonProcessingException e) {
@ -74,7 +74,7 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
* 2. This is an offset which means everything gets funneled into a single partition hence we use MaxCountShuffleSpec
*/
// No ordering, but there is a limit or an offset. These work by funneling everything through a single partition.
// So there is no point in forcing any particular partitioning. Since everything is funnelled into a single
// So there is no point in forcing any particular partitioning. Since everything is funneled into a single
// partition without a ClusterBy, we don't need to necessarily create it via the resultShuffleSpecFactory provided
@Override
public QueryDefinition makeQueryDefinition(

View File

@ -23,9 +23,10 @@ import com.google.common.collect.ImmutableMap;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.indexing.error.MSQWarnings;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
@ -64,7 +65,7 @@ public enum MSQMode
return value;
}
public static void populateDefaultQueryContext(final String modeStr, final QueryContext originalQueryContext)
public static void populateDefaultQueryContext(final String modeStr, final Map<String, Object> originalQueryContext)
{
MSQMode mode = MSQMode.fromString(modeStr);
if (mode == null) {
@ -74,8 +75,7 @@ public enum MSQMode
Arrays.stream(MSQMode.values()).map(m -> m.value).collect(Collectors.toList())
);
}
Map<String, Object> defaultQueryContext = mode.defaultQueryContext;
log.debug("Populating default query context with %s for the %s multi stage query mode", defaultQueryContext, mode);
originalQueryContext.addDefaultParams(defaultQueryContext);
log.debug("Populating default query context with %s for the %s multi stage query mode", mode.defaultQueryContext, mode);
QueryContexts.addDefaults(originalQueryContext, mode.defaultQueryContext);
}
}

View File

@ -42,6 +42,7 @@ import org.apache.druid.msq.indexing.MSQSpec;
import org.apache.druid.msq.indexing.MSQTuningConfig;
import org.apache.druid.msq.indexing.TaskReportMSQDestination;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.rpc.indexing.OverlordClient;
@ -59,6 +60,7 @@ import org.apache.druid.sql.calcite.table.RowSignatures;
import org.joda.time.Interval;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
@ -109,17 +111,18 @@ public class MSQTaskQueryMaker implements QueryMaker
{
String taskId = MSQTasks.controllerTaskId(plannerContext.getSqlQueryId());
String msqMode = MultiStageQueryContext.getMSQMode(plannerContext.getQueryContext());
QueryContext queryContext = plannerContext.queryContext();
String msqMode = MultiStageQueryContext.getMSQMode(queryContext);
if (msqMode != null) {
MSQMode.populateDefaultQueryContext(msqMode, plannerContext.getQueryContext());
MSQMode.populateDefaultQueryContext(msqMode, plannerContext.queryContextMap());
}
final String ctxDestination =
DimensionHandlerUtils.convertObjectToString(MultiStageQueryContext.getDestination(plannerContext.getQueryContext()));
DimensionHandlerUtils.convertObjectToString(MultiStageQueryContext.getDestination(queryContext));
Object segmentGranularity;
try {
segmentGranularity = Optional.ofNullable(plannerContext.getQueryContext()
segmentGranularity = Optional.ofNullable(plannerContext.queryContext()
.get(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY))
.orElse(jsonMapper.writeValueAsString(DEFAULT_SEGMENT_GRANULARITY));
}
@ -128,7 +131,7 @@ public class MSQTaskQueryMaker implements QueryMaker
+ "segment graularity");
}
final int maxNumTasks = MultiStageQueryContext.getMaxNumTasks(plannerContext.getQueryContext());
final int maxNumTasks = MultiStageQueryContext.getMaxNumTasks(queryContext);
if (maxNumTasks < 2) {
throw new IAE(MultiStageQueryContext.CTX_MAX_NUM_TASKS
@ -139,19 +142,19 @@ public class MSQTaskQueryMaker implements QueryMaker
final int maxNumWorkers = maxNumTasks - 1;
final int rowsPerSegment = MultiStageQueryContext.getRowsPerSegment(
plannerContext.getQueryContext(),
queryContext,
DEFAULT_ROWS_PER_SEGMENT
);
final int maxRowsInMemory = MultiStageQueryContext.getRowsInMemory(
plannerContext.getQueryContext(),
queryContext,
DEFAULT_ROWS_IN_MEMORY
);
final boolean finalizeAggregations = MultiStageQueryContext.isFinalizeAggregations(plannerContext.getQueryContext());
final boolean finalizeAggregations = MultiStageQueryContext.isFinalizeAggregations(queryContext);
final List<Interval> replaceTimeChunks =
Optional.ofNullable(plannerContext.getQueryContext().get(DruidSqlReplace.SQL_REPLACE_TIME_CHUNKS))
Optional.ofNullable(plannerContext.queryContext().get(DruidSqlReplace.SQL_REPLACE_TIME_CHUNKS))
.map(
s -> {
if (s instanceof String && "all".equals(StringUtils.toLowerCase((String) s))) {
@ -213,7 +216,7 @@ public class MSQTaskQueryMaker implements QueryMaker
}
final List<String> segmentSortOrder = MultiStageQueryContext.decodeSortOrder(
MultiStageQueryContext.getSortOrder(plannerContext.getQueryContext())
MultiStageQueryContext.getSortOrder(queryContext)
);
validateSegmentSortOrder(
@ -245,7 +248,7 @@ public class MSQTaskQueryMaker implements QueryMaker
.query(druidQuery.getQuery().withOverriddenContext(nativeQueryContextOverrides))
.columnMappings(new ColumnMappings(columnMappings))
.destination(destination)
.assignmentStrategy(MultiStageQueryContext.getAssignmentStrategy(plannerContext.getQueryContext()))
.assignmentStrategy(MultiStageQueryContext.getAssignmentStrategy(queryContext))
.tuningConfig(new MSQTuningConfig(maxNumWorkers, maxRowsInMemory, rowsPerSegment))
.build();
@ -253,7 +256,7 @@ public class MSQTaskQueryMaker implements QueryMaker
taskId,
querySpec,
plannerContext.getSql(),
plannerContext.getQueryContext().getMergedParams(),
plannerContext.queryContextMap(),
sqlTypeNames,
null
);

View File

@ -38,7 +38,6 @@ import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.msq.querykit.QueryKitUtils;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.QueryContext;
import org.apache.druid.rpc.indexing.OverlordClient;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.sql.calcite.parser.DruidSqlInsert;
@ -52,6 +51,7 @@ import org.apache.druid.sql.calcite.run.SqlEngines;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class MSQTaskSqlEngine implements SqlEngine
@ -86,7 +86,7 @@ public class MSQTaskSqlEngine implements SqlEngine
}
@Override
public void validateContext(QueryContext queryContext) throws ValidationException
public void validateContext(Map<String, Object> queryContext) throws ValidationException
{
SqlEngines.validateNoSpecialContextKeys(queryContext, SYSTEM_CONTEXT_PARAMETERS);
}
@ -166,7 +166,7 @@ public class MSQTaskSqlEngine implements SqlEngine
{
validateNoDuplicateAliases(fieldMappings);
if (plannerContext.getQueryContext().containsKey(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) {
if (plannerContext.queryContext().containsKey(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) {
throw new ValidationException(
StringUtils.format("Cannot use \"%s\" without INSERT", DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)
);
@ -207,14 +207,14 @@ public class MSQTaskSqlEngine implements SqlEngine
try {
segmentGranularity = QueryKitUtils.getSegmentGranularityFromContext(
plannerContext.getQueryContext().getMergedParams()
plannerContext.queryContextMap()
);
}
catch (Exception e) {
throw new ValidationException(
StringUtils.format(
"Invalid segmentGranularity: %s",
plannerContext.getQueryContext().get(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)
plannerContext.queryContext().get(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)
),
e
);

View File

@ -26,12 +26,12 @@ import com.google.common.annotations.VisibleForTesting;
import com.opencsv.RFC4180Parser;
import com.opencsv.RFC4180ParserBuilder;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Numbers;
import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
import org.apache.druid.msq.sql.MSQMode;
import org.apache.druid.query.QueryContext;
import javax.annotation.Nullable;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
@ -59,7 +59,7 @@ public class MultiStageQueryContext
private static final boolean DEFAULT_FINALIZE_AGGREGATIONS = true;
public static final String CTX_ENABLE_DURABLE_SHUFFLE_STORAGE = "durableShuffleStorage";
private static final String DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE = "false";
private static final boolean DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE = false;
public static final String CTX_DESTINATION = "destination";
private static final String DEFAULT_DESTINATION = null;
@ -77,48 +77,34 @@ public class MultiStageQueryContext
private static final Pattern LOOKS_LIKE_JSON_ARRAY = Pattern.compile("^\\s*\\[.*", Pattern.DOTALL);
public static String getMSQMode(QueryContext queryContext)
public static String getMSQMode(final QueryContext queryContext)
{
return (String) MultiStageQueryContext.getValueFromPropertyMap(
queryContext.getMergedParams(),
return queryContext.getString(
CTX_MSQ_MODE,
null,
DEFAULT_MSQ_MODE
);
}
public static boolean isDurableStorageEnabled(Map<String, Object> propertyMap)
public static boolean isDurableStorageEnabled(final QueryContext queryContext)
{
return Boolean.parseBoolean(
String.valueOf(
getValueFromPropertyMap(
propertyMap,
CTX_ENABLE_DURABLE_SHUFFLE_STORAGE,
null,
DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE
)
)
return queryContext.getBoolean(
CTX_ENABLE_DURABLE_SHUFFLE_STORAGE,
DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE
);
}
public static boolean isFinalizeAggregations(final QueryContext queryContext)
{
return Numbers.parseBoolean(
getValueFromPropertyMap(
queryContext.getMergedParams(),
CTX_FINALIZE_AGGREGATIONS,
null,
DEFAULT_FINALIZE_AGGREGATIONS
)
return queryContext.getBoolean(
CTX_FINALIZE_AGGREGATIONS,
DEFAULT_FINALIZE_AGGREGATIONS
);
}
public static WorkerAssignmentStrategy getAssignmentStrategy(final QueryContext queryContext)
{
String assignmentStrategyString = (String) getValueFromPropertyMap(
queryContext.getMergedParams(),
String assignmentStrategyString = queryContext.getString(
CTX_TASK_ASSIGNMENT_STRATEGY,
null,
DEFAULT_TASK_ASSIGNMENT_STRATEGY
);
@ -127,47 +113,33 @@ public class MultiStageQueryContext
public static int getMaxNumTasks(final QueryContext queryContext)
{
return Numbers.parseInt(
getValueFromPropertyMap(
queryContext.getMergedParams(),
CTX_MAX_NUM_TASKS,
null,
DEFAULT_MAX_NUM_TASKS
)
return queryContext.getInt(
CTX_MAX_NUM_TASKS,
DEFAULT_MAX_NUM_TASKS
);
}
public static Object getDestination(final QueryContext queryContext)
{
return getValueFromPropertyMap(
queryContext.getMergedParams(),
return queryContext.get(
CTX_DESTINATION,
null,
DEFAULT_DESTINATION
);
}
public static int getRowsPerSegment(final QueryContext queryContext, int defaultRowsPerSegment)
{
return Numbers.parseInt(
getValueFromPropertyMap(
queryContext.getMergedParams(),
CTX_ROWS_PER_SEGMENT,
null,
defaultRowsPerSegment
)
return queryContext.getInt(
CTX_ROWS_PER_SEGMENT,
defaultRowsPerSegment
);
}
public static int getRowsInMemory(final QueryContext queryContext, int defaultRowsInMemory)
{
return Numbers.parseInt(
getValueFromPropertyMap(
queryContext.getMergedParams(),
CTX_ROWS_IN_MEMORY,
null,
defaultRowsInMemory
)
return queryContext.getInt(
CTX_ROWS_IN_MEMORY,
defaultRowsInMemory
);
}
@ -196,10 +168,8 @@ public class MultiStageQueryContext
public static String getSortOrder(final QueryContext queryContext)
{
return (String) getValueFromPropertyMap(
queryContext.getMergedParams(),
return queryContext.getString(
CTX_SORT_ORDER,
null,
DEFAULT_SORT_ORDER
);
}

View File

@ -22,34 +22,36 @@ package org.apache.druid.msq.sql;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.indexing.error.MSQWarnings;
import org.apache.druid.query.QueryContext;
import org.junit.Assert;
import org.junit.Test;
import java.util.HashMap;
import java.util.Map;
public class MSQModeTest
{
@Test
public void testPopulateQueryContextWhenNoSupercedingValuePresent()
{
QueryContext originalQueryContext = new QueryContext();
Map<String, Object> originalQueryContext = new HashMap<>();
MSQMode.populateDefaultQueryContext("strict", originalQueryContext);
Assert.assertEquals(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 0), originalQueryContext.getMergedParams());
Assert.assertEquals(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 0), originalQueryContext);
}
@Test
public void testPopulateQueryContextWhenSupercedingValuePresent()
{
QueryContext originalQueryContext = new QueryContext(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 10));
Map<String, Object> originalQueryContext = new HashMap<>();
originalQueryContext.put(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 10);
MSQMode.populateDefaultQueryContext("strict", originalQueryContext);
Assert.assertEquals(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 10), originalQueryContext.getMergedParams());
Assert.assertEquals(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 10), originalQueryContext);
}
@Test
public void testPopulateQueryContextWhenInvalidMode()
{
QueryContext originalQueryContext = new QueryContext();
Map<String, Object> originalQueryContext = new HashMap<>();
Assert.assertThrows(ISE.class, () -> {
MSQMode.populateDefaultQueryContext("fake_mode", originalQueryContext);
});

View File

@ -89,7 +89,6 @@ import org.apache.druid.msq.sql.MSQTaskSqlEngine;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.ForwardingQueryProcessingPool;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryProcessingPool;
import org.apache.druid.query.aggregation.AggregatorFactory;
@ -132,7 +131,6 @@ import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.external.ExternalDataSource;
import org.apache.druid.sql.calcite.planner.CalciteRulesManager;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.planner.PlannerFactory;
import org.apache.druid.sql.calcite.rel.DruidQuery;
import org.apache.druid.sql.calcite.run.SqlEngine;
@ -162,6 +160,7 @@ import org.mockito.Mockito;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
@ -207,7 +206,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
public static final Map<String, Object> DEFAULT_MSQ_CONTEXT =
ImmutableMap.<String, Object>builder()
.put(MultiStageQueryContext.CTX_ENABLE_DURABLE_SHUFFLE_STORAGE, true)
.put(PlannerContext.CTX_SQL_QUERY_ID, "test-query")
.put(QueryContexts.CTX_SQL_QUERY_ID, "test-query")
.put(QueryContexts.FINALIZE_KEY, true)
.build();
@ -587,7 +586,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
final DirectStatement stmt = sqlStatementFactory.directStatement(
new SqlQueryPlus(
query,
new QueryContext(context),
context,
Collections.emptyList(),
CalciteTests.REGULAR_USER_AUTH_RESULT
)

View File

@ -27,6 +27,7 @@ import org.junit.Assert;
import org.junit.Test;
import javax.annotation.Nullable;
import java.util.List;
import java.util.Map;
@ -46,33 +47,33 @@ public class MultiStageQueryContextTest
@Test
public void isDurableStorageEnabled_noParameterSetReturnsDefaultValue()
{
Assert.assertFalse(MultiStageQueryContext.isDurableStorageEnabled(ImmutableMap.of()));
Assert.assertFalse(MultiStageQueryContext.isDurableStorageEnabled(QueryContext.empty()));
}
@Test
public void isDurableStorageEnabled_parameterSetReturnsCorrectValue()
{
Map<String, Object> propertyMap = ImmutableMap.of(CTX_ENABLE_DURABLE_SHUFFLE_STORAGE, "true");
Assert.assertTrue(MultiStageQueryContext.isDurableStorageEnabled(propertyMap));
Assert.assertTrue(MultiStageQueryContext.isDurableStorageEnabled(QueryContext.of(propertyMap)));
}
@Test
public void isFinalizeAggregations_noParameterSetReturnsDefaultValue()
{
Assert.assertTrue(MultiStageQueryContext.isFinalizeAggregations(new QueryContext()));
Assert.assertTrue(MultiStageQueryContext.isFinalizeAggregations(QueryContext.empty()));
}
@Test
public void isFinalizeAggregations_parameterSetReturnsCorrectValue()
{
Map<String, Object> propertyMap = ImmutableMap.of(CTX_FINALIZE_AGGREGATIONS, "false");
Assert.assertFalse(MultiStageQueryContext.isFinalizeAggregations(new QueryContext(propertyMap)));
Assert.assertFalse(MultiStageQueryContext.isFinalizeAggregations(QueryContext.of(propertyMap)));
}
@Test
public void getAssignmentStrategy_noParameterSetReturnsDefaultValue()
{
Assert.assertEquals(WorkerAssignmentStrategy.MAX, MultiStageQueryContext.getAssignmentStrategy(new QueryContext()));
Assert.assertEquals(WorkerAssignmentStrategy.MAX, MultiStageQueryContext.getAssignmentStrategy(QueryContext.empty()));
}
@Test
@ -81,67 +82,67 @@ public class MultiStageQueryContextTest
Map<String, Object> propertyMap = ImmutableMap.of(CTX_TASK_ASSIGNMENT_STRATEGY, "AUTO");
Assert.assertEquals(
WorkerAssignmentStrategy.AUTO,
MultiStageQueryContext.getAssignmentStrategy(new QueryContext(propertyMap))
MultiStageQueryContext.getAssignmentStrategy(QueryContext.of(propertyMap))
);
}
@Test
public void getMaxNumTasks_noParameterSetReturnsDefaultValue()
{
Assert.assertEquals(DEFAULT_MAX_NUM_TASKS, MultiStageQueryContext.getMaxNumTasks(new QueryContext()));
Assert.assertEquals(DEFAULT_MAX_NUM_TASKS, MultiStageQueryContext.getMaxNumTasks(QueryContext.empty()));
}
@Test
public void getMaxNumTasks_parameterSetReturnsCorrectValue()
{
Map<String, Object> propertyMap = ImmutableMap.of(CTX_MAX_NUM_TASKS, 101);
Assert.assertEquals(101, MultiStageQueryContext.getMaxNumTasks(new QueryContext(propertyMap)));
Assert.assertEquals(101, MultiStageQueryContext.getMaxNumTasks(QueryContext.of(propertyMap)));
}
@Test
public void getMaxNumTasks_legacyParameterSetReturnsCorrectValue()
{
Map<String, Object> propertyMap = ImmutableMap.of(CTX_MAX_NUM_TASKS, 101);
Assert.assertEquals(101, MultiStageQueryContext.getMaxNumTasks(new QueryContext(propertyMap)));
Assert.assertEquals(101, MultiStageQueryContext.getMaxNumTasks(QueryContext.of(propertyMap)));
}
@Test
public void getDestination_noParameterSetReturnsDefaultValue()
{
Assert.assertNull(MultiStageQueryContext.getDestination(new QueryContext()));
Assert.assertNull(MultiStageQueryContext.getDestination(QueryContext.empty()));
}
@Test
public void getDestination_parameterSetReturnsCorrectValue()
{
Map<String, Object> propertyMap = ImmutableMap.of(CTX_DESTINATION, "dataSource");
Assert.assertEquals("dataSource", MultiStageQueryContext.getDestination(new QueryContext(propertyMap)));
Assert.assertEquals("dataSource", MultiStageQueryContext.getDestination(QueryContext.of(propertyMap)));
}
@Test
public void getRowsPerSegment_noParameterSetReturnsDefaultValue()
{
Assert.assertEquals(1000, MultiStageQueryContext.getRowsPerSegment(new QueryContext(), 1000));
Assert.assertEquals(1000, MultiStageQueryContext.getRowsPerSegment(QueryContext.empty(), 1000));
}
@Test
public void getRowsPerSegment_parameterSetReturnsCorrectValue()
{
Map<String, Object> propertyMap = ImmutableMap.of(CTX_ROWS_PER_SEGMENT, 10);
Assert.assertEquals(10, MultiStageQueryContext.getRowsPerSegment(new QueryContext(propertyMap), 1000));
Assert.assertEquals(10, MultiStageQueryContext.getRowsPerSegment(QueryContext.of(propertyMap), 1000));
}
@Test
public void getRowsInMemory_noParameterSetReturnsDefaultValue()
{
Assert.assertEquals(1000, MultiStageQueryContext.getRowsInMemory(new QueryContext(), 1000));
Assert.assertEquals(1000, MultiStageQueryContext.getRowsInMemory(QueryContext.empty(), 1000));
}
@Test
public void getRowsInMemory_parameterSetReturnsCorrectValue()
{
Map<String, Object> propertyMap = ImmutableMap.of(CTX_ROWS_IN_MEMORY, 10);
Assert.assertEquals(10, MultiStageQueryContext.getRowsInMemory(new QueryContext(propertyMap), 1000));
Assert.assertEquals(10, MultiStageQueryContext.getRowsInMemory(QueryContext.of(propertyMap), 1000));
}
@Test
@ -161,27 +162,27 @@ public class MultiStageQueryContextTest
@Test
public void getSortOrderNoParameterSetReturnsDefaultValue()
{
Assert.assertNull(MultiStageQueryContext.getSortOrder(new QueryContext()));
Assert.assertNull(MultiStageQueryContext.getSortOrder(QueryContext.empty()));
}
@Test
public void getSortOrderParameterSetReturnsCorrectValue()
{
Map<String, Object> propertyMap = ImmutableMap.of(CTX_SORT_ORDER, "a, b,\"c,d\"");
Assert.assertEquals("a, b,\"c,d\"", MultiStageQueryContext.getSortOrder(new QueryContext(propertyMap)));
Assert.assertEquals("a, b,\"c,d\"", MultiStageQueryContext.getSortOrder(QueryContext.of(propertyMap)));
}
@Test
public void getMSQModeNoParameterSetReturnsDefaultValue()
{
Assert.assertEquals("strict", MultiStageQueryContext.getMSQMode(new QueryContext()));
Assert.assertEquals("strict", MultiStageQueryContext.getMSQMode(QueryContext.empty()));
}
@Test
public void getMSQModeParameterSetReturnsCorrectValue()
{
Map<String, Object> propertyMap = ImmutableMap.of(CTX_MSQ_MODE, "nonStrict");
Assert.assertEquals("nonStrict", MultiStageQueryContext.getMSQMode(new QueryContext(propertyMap)));
Assert.assertEquals("nonStrict", MultiStageQueryContext.getMSQMode(QueryContext.of(propertyMap)));
}
private static List<String> decodeSortOrder(@Nullable final String input)

View File

@ -34,6 +34,7 @@ import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryProcessingPool;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryRunnerFactory;
@ -127,7 +128,8 @@ public class ServerManagerForQueryErrorTest extends ServerManager
Optional<byte[]> cacheKeyPrefix
)
{
if (query.getContextBoolean(QUERY_RETRY_TEST_CONTEXT_KEY, false)) {
final QueryContext queryContext = query.context();
if (queryContext.getBoolean(QUERY_RETRY_TEST_CONTEXT_KEY, false)) {
final MutableBoolean isIgnoreSegment = new MutableBoolean(false);
queryToIgnoredSegments.compute(
query.getMostSpecificId(),
@ -147,7 +149,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
LOG.info("Pretending I don't have segment [%s]", descriptor);
return new ReportTimelineMissingSegmentQueryRunner<>(descriptor);
}
} else if (query.getContextBoolean(QUERY_TIMEOUT_TEST_CONTEXT_KEY, false)) {
} else if (queryContext.getBoolean(QUERY_TIMEOUT_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>()
{
@Override
@ -162,7 +164,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw new QueryTimeoutException("query timeout test");
}
};
} else if (query.getContextBoolean(QUERY_CAPACITY_EXCEEDED_TEST_CONTEXT_KEY, false)) {
} else if (queryContext.getBoolean(QUERY_CAPACITY_EXCEEDED_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>()
{
@Override
@ -177,7 +179,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw QueryCapacityExceededException.withErrorMessageAndResolvedHost("query capacity exceeded test");
}
};
} else if (query.getContextBoolean(QUERY_UNSUPPORTED_TEST_CONTEXT_KEY, false)) {
} else if (queryContext.getBoolean(QUERY_UNSUPPORTED_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>()
{
@Override
@ -192,7 +194,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw new QueryUnsupportedException("query unsupported test");
}
};
} else if (query.getContextBoolean(RESOURCE_LIMIT_EXCEEDED_TEST_CONTEXT_KEY, false)) {
} else if (queryContext.getBoolean(RESOURCE_LIMIT_EXCEEDED_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>()
{
@Override
@ -207,7 +209,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw new ResourceLimitExceededException("resource limit exceeded test");
}
};
} else if (query.getContextBoolean(QUERY_FAILURE_TEST_CONTEXT_KEY, false)) {
} else if (queryContext.getBoolean(QUERY_FAILURE_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>()
{
@Override

View File

@ -34,6 +34,7 @@ import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryProcessingPool;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryRunnerFactory;
@ -125,7 +126,8 @@ public class ServerManagerForQueryErrorTest extends ServerManager
Optional<byte[]> cacheKeyPrefix
)
{
if (query.getContextBoolean(QUERY_RETRY_TEST_CONTEXT_KEY, false)) {
final QueryContext queryContext = query.context();
if (queryContext.getBoolean(QUERY_RETRY_TEST_CONTEXT_KEY, false)) {
final MutableBoolean isIgnoreSegment = new MutableBoolean(false);
queryToIgnoredSegments.compute(
query.getMostSpecificId(),
@ -145,7 +147,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
LOG.info("Pretending I don't have segment[%s]", descriptor);
return new ReportTimelineMissingSegmentQueryRunner<>(descriptor);
}
} else if (query.getContextBoolean(QUERY_TIMEOUT_TEST_CONTEXT_KEY, false)) {
} else if (queryContext.getBoolean(QUERY_TIMEOUT_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>()
{
@Override
@ -160,7 +162,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw new QueryTimeoutException("query timeout test");
}
};
} else if (query.getContextBoolean(QUERY_CAPACITY_EXCEEDED_TEST_CONTEXT_KEY, false)) {
} else if (queryContext.getBoolean(QUERY_CAPACITY_EXCEEDED_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>()
{
@Override
@ -175,7 +177,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw QueryCapacityExceededException.withErrorMessageAndResolvedHost("query capacity exceeded test");
}
};
} else if (query.getContextBoolean(QUERY_UNSUPPORTED_TEST_CONTEXT_KEY, false)) {
} else if (queryContext.getBoolean(QUERY_UNSUPPORTED_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>()
{
@Override
@ -190,7 +192,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw new QueryUnsupportedException("query unsupported test");
}
};
} else if (query.getContextBoolean(RESOURCE_LIMIT_EXCEEDED_TEST_CONTEXT_KEY, false)) {
} else if (queryContext.getBoolean(RESOURCE_LIMIT_EXCEEDED_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>()
{
@Override
@ -205,7 +207,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw new ResourceLimitExceededException("resource limit exceeded test");
}
};
} else if (query.getContextBoolean(QUERY_FAILURE_TEST_CONTEXT_KEY, false)) {
} else if (queryContext.getBoolean(QUERY_FAILURE_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>()
{
@Override

View File

@ -547,7 +547,7 @@ public abstract class AbstractAuthConfigurationTest
public void test_sqlQueryWithContext_datasourceOnlyUser_fail() throws Exception
{
final String query = "select count(*) from auth_test";
StatusResponseHolder responseHolder = makeSQLQueryRequest(
makeSQLQueryRequest(
getHttpClient(User.DATASOURCE_ONLY_USER),
query,
ImmutableMap.of("auth_test_ctx", "should-be-denied"),
@ -559,7 +559,7 @@ public abstract class AbstractAuthConfigurationTest
public void test_sqlQueryWithContext_datasourceAndContextParamsUser_succeed() throws Exception
{
final String query = "select count(*) from auth_test";
StatusResponseHolder responseHolder = makeSQLQueryRequest(
makeSQLQueryRequest(
getHttpClient(User.DATASOURCE_AND_CONTEXT_PARAMS_USER),
query,
ImmutableMap.of("auth_test_ctx", "should-be-allowed"),
@ -844,11 +844,6 @@ public abstract class AbstractAuthConfigurationTest
protected void verifyInvalidAuthNameFails(String endpoint)
{
HttpClient adminClient = new CredentialedHttpClient(
new BasicCredentials("admin", "priest"),
httpClient
);
HttpUtil.makeRequestWithExpectedStatus(
getHttpClient(User.ADMIN),
HttpMethod.POST,

View File

@ -32,6 +32,11 @@ public class BadQueryContextException extends BadQueryException
this(ERROR_CODE, e.getMessage(), ERROR_CLASS);
}
public BadQueryContextException(String msg)
{
this(ERROR_CODE, msg, ERROR_CLASS);
}
@JsonCreator
private BadQueryContextException(
@JsonProperty("error") String errorCode,

View File

@ -27,7 +27,6 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Ordering;
import org.apache.druid.guice.annotations.ExtensionPoint;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.java.util.common.granularity.PeriodGranularity;
@ -38,7 +37,6 @@ import org.joda.time.Duration;
import org.joda.time.Interval;
import javax.annotation.Nullable;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -89,7 +87,7 @@ public abstract class BaseQuery<T> implements Query<T>
Preconditions.checkNotNull(granularity, "Must specify a granularity");
this.dataSource = dataSource;
this.context = new QueryContext(context);
this.context = QueryContext.of(context);
this.querySegmentSpec = querySegmentSpec;
this.descending = descending;
this.granularity = granularity;
@ -172,27 +170,15 @@ public abstract class BaseQuery<T> implements Query<T>
@JsonInclude(Include.NON_DEFAULT)
public Map<String, Object> getContext()
{
return context.getMergedParams();
return context.asMap();
}
@Override
public QueryContext getQueryContext()
public QueryContext context()
{
return context;
}
@Override
public boolean getContextBoolean(String key, boolean defaultValue)
{
return context.getAsBoolean(key, defaultValue);
}
@Override
public HumanReadableBytes getContextAsHumanReadableBytes(String key, HumanReadableBytes defaultValue)
{
return context.getAsHumanReadableBytes(key, defaultValue);
}
/**
* @deprecated use {@link #computeOverriddenContext(Map, Map) computeOverriddenContext(getContext(), overrides))}
* instead. This method may be removed in the next minor or major version of Druid.
@ -228,7 +214,7 @@ public abstract class BaseQuery<T> implements Query<T>
@Override
public String getId()
{
return context.getAsString(QUERY_ID);
return context().getString(QUERY_ID);
}
@Override
@ -241,7 +227,7 @@ public abstract class BaseQuery<T> implements Query<T>
@Override
public String getSubQueryId()
{
return context.getAsString(SUB_QUERY_ID);
return context().getString(SUB_QUERY_ID);
}
@Override

View File

@ -35,7 +35,7 @@ import java.util.List;
*
* Note that despite the type parameter "T", this runner may not actually return sequences with type T. They
* may really be of type {@code Result<BySegmentResultValue<T>>}, if "bySegment" is set. Downstream consumers
* of the returned sequence must be aware of this, and can use {@link QueryContexts#isBySegment(Query)} to
* of the returned sequence must be aware of this, and can use {@link QueryContext#isBySegment()} to
* know what to expect.
*/
public class BySegmentQueryRunner<T> implements QueryRunner<T>
@ -55,7 +55,7 @@ public class BySegmentQueryRunner<T> implements QueryRunner<T>
@SuppressWarnings("unchecked")
public Sequence<T> run(final QueryPlus<T> queryPlus, ResponseContext responseContext)
{
if (QueryContexts.isBySegment(queryPlus.getQuery())) {
if (queryPlus.getQuery().context().isBySegment()) {
final Sequence<T> baseSequence = base.run(queryPlus, responseContext);
final List<T> results = baseSequence.toList();
return Sequences.simple(

View File

@ -39,7 +39,7 @@ public abstract class BySegmentSkippingQueryRunner<T> implements QueryRunner<T>
@Override
public Sequence<T> run(QueryPlus<T> queryPlus, ResponseContext responseContext)
{
if (QueryContexts.isBySegment(queryPlus.getQuery())) {
if (queryPlus.getQuery().context().isBySegment()) {
return baseRunner.run(queryPlus, responseContext);
}

View File

@ -78,7 +78,7 @@ public class ChainedExecutionQueryRunner<T> implements QueryRunner<T>
public Sequence<T> run(final QueryPlus<T> queryPlus, final ResponseContext responseContext)
{
Query<T> query = queryPlus.getQuery();
final int priority = QueryContexts.getPriority(query);
final int priority = query.context().getPriority();
final Ordering ordering = query.getResultOrdering();
final QueryPlus<T> threadSafeQueryPlus = queryPlus.withoutThreadUnsafeState();
return new BaseSequence<T, Iterator<T>>(
@ -137,9 +137,10 @@ public class ChainedExecutionQueryRunner<T> implements QueryRunner<T>
queryWatcher.registerQueryFuture(query, future);
try {
final QueryContext context = query.context();
return new MergeIterable<>(
QueryContexts.hasTimeout(query) ?
future.get(QueryContexts.getTimeout(query), TimeUnit.MILLISECONDS) :
context.hasTimeout() ?
future.get(context.getTimeout(), TimeUnit.MILLISECONDS) :
future.get(),
ordering.nullsFirst()
).iterator();

View File

@ -56,8 +56,9 @@ public class FinalizeResultsQueryRunner<T> implements QueryRunner<T>
public Sequence<T> run(final QueryPlus<T> queryPlus, ResponseContext responseContext)
{
final Query<T> query = queryPlus.getQuery();
final boolean isBySegment = QueryContexts.isBySegment(query);
final boolean shouldFinalize = QueryContexts.isFinalize(query, true);
final QueryContext queryContext = query.context();
final boolean isBySegment = queryContext.isBySegment();
final boolean shouldFinalize = queryContext.isFinalize(true);
final Query<T> queryToRun;
final Function<T, ?> finalizerFn;

View File

@ -84,8 +84,9 @@ public class GroupByMergedQueryRunner<T> implements QueryRunner<T>
querySpecificConfig
);
final Pair<Queue, Accumulator<Queue, T>> bySegmentAccumulatorPair = GroupByQueryHelper.createBySegmentAccumulatorPair();
final boolean bySegment = QueryContexts.isBySegment(query);
final int priority = QueryContexts.getPriority(query);
final QueryContext queryContext = query.context();
final boolean bySegment = queryContext.isBySegment();
final int priority = queryContext.getPriority();
final QueryPlus<T> threadSafeQueryPlus = queryPlus.withoutThreadUnsafeState();
final List<ListenableFuture<Void>> futures =
Lists.newArrayList(
@ -173,8 +174,9 @@ public class GroupByMergedQueryRunner<T> implements QueryRunner<T>
ListenableFuture<List<Void>> future = Futures.allAsList(futures);
try {
queryWatcher.registerQueryFuture(query, future);
if (QueryContexts.hasTimeout(query)) {
future.get(QueryContexts.getTimeout(query), TimeUnit.MILLISECONDS);
final QueryContext context = query.context();
if (context.hasTimeout()) {
future.get(context.getTimeout(), TimeUnit.MILLISECONDS);
} else {
future.get();
}

View File

@ -20,6 +20,7 @@
package org.apache.druid.query;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.druid.guice.annotations.PublicApi;
@ -36,6 +37,7 @@ import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.column.ColumnHolder;
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
@ -293,4 +295,24 @@ public class Queries
return requiredColumns;
}
public static <T> Query<T> withMaxScatterGatherBytes(Query<T> query, long maxScatterGatherBytesLimit)
{
QueryContext context = query.context();
if (!context.containsKey(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY)) {
return query.withOverriddenContext(ImmutableMap.of(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY, maxScatterGatherBytesLimit));
}
context.verifyMaxScatterGatherBytes(maxScatterGatherBytesLimit);
return query;
}
public static <T> Query<T> withTimeout(Query<T> query, long timeout)
{
return query.withOverriddenContext(ImmutableMap.of(QueryContexts.TIMEOUT_KEY, timeout));
}
public static <T> Query<T> withDefaultTimeout(Query<T> query, long defaultTimeout)
{
return query.withOverriddenContext(ImmutableMap.of(QueryContexts.DEFAULT_TIMEOUT_KEY, defaultTimeout));
}
}

View File

@ -45,6 +45,7 @@ import org.joda.time.Duration;
import org.joda.time.Interval;
import javax.annotation.Nullable;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -96,64 +97,53 @@ public interface Query<T>
DateTimeZone getTimezone();
/**
* Use {@link #getQueryContext()} instead.
* Returns the context as an (immutable) map.
*/
@Deprecated
Map<String, Object> getContext();
/**
* Returns QueryContext for this query. This type distinguishes between user provided, system default, and system
* generated query context keys so that authorization may be employed directly against the user supplied context
* values.
*
* This method is marked @Nullable, but is only so for backwards compatibility with Druid versions older than 0.23.
* Callers should check if the result of this method is null, and if so, they are dealing with a legacy query
* implementation, and should fall back to using {@link #getContext()} and {@link #withOverriddenContext(Map)} to
* manipulate the query context.
*
* Note for query context serialization and deserialization.
* Currently, once a query is serialized, its queryContext can be different from the original queryContext
* after the query is deserialized back. If the queryContext has any {@link QueryContext#defaultParams} or
* {@link QueryContext#systemParams} in it, those will be found in {@link QueryContext#userParams}
* after it is deserialized. This is because {@link BaseQuery#getContext()} uses
* {@link QueryContext#getMergedParams()} for serialization, and queries accept a map for deserialization.
* Returns the query context as a {@link QueryContext}, which provides
* convenience methods for accessing typed context values. The returned
* instance is a view on top of the context provided by {@link #getContext()}.
* <p>
* The default implementation is for backward compatibility. Derived classes should
* store and return the {@link QueryContext} directly.
*/
@Nullable
default QueryContext getQueryContext()
default QueryContext context()
{
return null;
return QueryContext.of(getContext());
}
/**
* Get context value and cast to ContextType in an unsafe way.
*
* For safe conversion, it's recommended to use following methods instead
* For safe conversion, it's recommended to use following methods instead:
* <p>
* {@link QueryContext#getBoolean(String)} <br/>
* {@link QueryContext#getString(String)} <br/>
* {@link QueryContext#getInt(String)} <br/>
* {@link QueryContext#getLong(String)} <br/>
* {@link QueryContext#getFloat(String)} <br/>
* {@link QueryContext#getEnum(String, Class, Enum)} <br/>
* {@link QueryContext#getHumanReadableBytes(String, HumanReadableBytes)}
*
* {@link QueryContext#getAsBoolean(String)}
* {@link QueryContext#getAsString(String)}
* {@link QueryContext#getAsInt(String)}
* {@link QueryContext#getAsLong(String)}
* {@link QueryContext#getAsFloat(String, float)}
* {@link QueryContext#getAsEnum(String, Class, Enum)}
* {@link QueryContext#getAsHumanReadableBytes(String, HumanReadableBytes)}
* @deprecated use {@code queryContext().get<Type>()} instead
*/
@Deprecated
@SuppressWarnings("unchecked")
@Nullable
default <ContextType> ContextType getContextValue(String key)
{
if (getQueryContext() == null) {
return null;
} else {
return (ContextType) getQueryContext().get(key);
}
return (ContextType) context().get(key);
}
/**
* @deprecated use {@code queryContext().getBoolean()} instead.
*/
@Deprecated
default boolean getContextBoolean(String key, boolean defaultValue)
{
if (getQueryContext() == null) {
return defaultValue;
} else {
return getQueryContext().getAsBoolean(key, defaultValue);
}
return context().getBoolean(key, defaultValue);
}
/**
@ -164,14 +154,12 @@ public interface Query<T>
* @param key The context key value being looked up
* @param defaultValue The default to return if the key value doesn't exist or the context is null.
* @return {@link HumanReadableBytes}
* @deprecated use {@code queryContext().getContextHumanReadableBytes()} instead.
*/
default HumanReadableBytes getContextAsHumanReadableBytes(String key, HumanReadableBytes defaultValue)
@Deprecated
default HumanReadableBytes getContextHumanReadableBytes(String key, HumanReadableBytes defaultValue)
{
if (getQueryContext() == null) {
return defaultValue;
} else {
return getQueryContext().getAsHumanReadableBytes(key, defaultValue);
}
return context().getHumanReadableBytes(key, defaultValue);
}
boolean isDescending();
@ -230,7 +218,7 @@ public interface Query<T>
@Nullable
default String getSqlQueryId()
{
return getQueryContext().getAsString(BaseQuery.SQL_QUERY_ID);
return context().getString(BaseQuery.SQL_QUERY_ID);
}
/**

View File

@ -20,6 +20,10 @@
package org.apache.druid.query;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.QueryContexts.Vectorize;
import org.apache.druid.segment.QueryableIndexStorageAdapter;
import javax.annotation.Nullable;
@ -29,227 +33,547 @@ import java.util.Objects;
import java.util.TreeMap;
/**
* Holder for query context parameters. There are 3 ways to set context params today.
*
* - Default parameters. These are set mostly via {@link DefaultQueryConfig#context}.
* Auto-generated queryId or sqlQueryId are also set as default parameters. These default parameters can
* be overridden by user or system parameters.
* - User parameters. These are the params set by the user. User params override default parameters but
* are overridden by system parameters.
* - System parameters. These are the params set by the Druid query engine for internal use only.
*
* You can use {@code getX} methods or {@link #getMergedParams()} to compute the context params
* merging 3 types of params above.
*
* Currently, this class is mainly used for query context parameter authorization,
* such as HTTP query endpoints or JDBC endpoint. Its usage can be expanded in the future if we
* want to track user parameters and separate them from others during query processing.
* Immutable holder for query context parameters with typed access methods.
* Code builds up a map of context values from serialization or during
* planning. Once that map is handed to the {@code QueryContext}, that map
* is effectively immutable.
* <p>
* The implementation uses a {@link TreeMap} so that the serialized form of a query
* lists context values in a deterministic order. Jackson will call
* {@code getContext()} on the query, which will call {@link #asMap()} here,
* which returns the sorted {@code TreeMap}.
* <p>
* The {@code TreeMap} is a mutable class. We'd prefer an immutable class, but
* we can choose either ordering or immutability. Since the semantics of the context
* is that it is immutable once it is placed in a query. Code should NEVER get the
* context map from a query and modify it, even if the actual implementation
* allows it.
*/
public class QueryContext
{
private final Map<String, Object> defaultParams;
private final Map<String, Object> userParams;
private final Map<String, Object> systemParams;
private static final QueryContext EMPTY = new QueryContext(null);
/**
* Cache of params merged.
*/
@Nullable
private Map<String, Object> mergedParams;
private final Map<String, Object> context;
public QueryContext()
public QueryContext(Map<String, Object> context)
{
this(null);
// There is no semantic difference between an empty and a null context.
// Ensure that a context always exists to avoid the need to check for
// a null context. Jackson serialization will omit empty contexts.
this.context = context == null
? Collections.emptyMap()
: Collections.unmodifiableMap(new TreeMap<>(context));
}
public QueryContext(@Nullable Map<String, Object> userParams)
public static QueryContext empty()
{
this(
new TreeMap<>(),
userParams == null ? new TreeMap<>() : new TreeMap<>(userParams),
new TreeMap<>()
);
return EMPTY;
}
private QueryContext(
final Map<String, Object> defaultParams,
final Map<String, Object> userParams,
final Map<String, Object> systemParams
)
public static QueryContext of(Map<String, Object> context)
{
this.defaultParams = defaultParams;
this.userParams = userParams;
this.systemParams = systemParams;
this.mergedParams = null;
}
private void invalidateMergedParams()
{
this.mergedParams = null;
return new QueryContext(context);
}
public boolean isEmpty()
{
return defaultParams.isEmpty() && userParams.isEmpty() && systemParams.isEmpty();
return context.isEmpty();
}
public void addDefaultParam(String key, Object val)
public Map<String, Object> asMap()
{
invalidateMergedParams();
defaultParams.put(key, val);
}
public void addDefaultParams(Map<String, Object> defaultParams)
{
invalidateMergedParams();
this.defaultParams.putAll(defaultParams);
}
public void addSystemParam(String key, Object val)
{
invalidateMergedParams();
this.systemParams.put(key, val);
}
public Object removeUserParam(String key)
{
invalidateMergedParams();
return userParams.remove(key);
return context;
}
/**
* Returns only the context parameters the user sets.
* The returned map does not include the parameters that have been removed via {@link #removeUserParam}.
*
* Callers should use {@code getX} methods or {@link #getMergedParams()} instead to use the whole context params.
* Check if the given key is set. If the client will then fetch the value,
* consider using one of the {@code get<Type>(String key)} methods instead:
* they each return {@code null} if the value is not set.
*/
public Map<String, Object> getUserParams()
{
return userParams;
}
public boolean isDebug()
{
return getAsBoolean(QueryContexts.ENABLE_DEBUG, QueryContexts.DEFAULT_ENABLE_DEBUG);
}
public boolean isEnableJoinLeftScanDirect()
{
return getAsBoolean(
QueryContexts.SQL_JOIN_LEFT_SCAN_DIRECT,
QueryContexts.DEFAULT_ENABLE_SQL_JOIN_LEFT_SCAN_DIRECT
);
}
@SuppressWarnings("unused")
public boolean containsKey(String key)
{
return get(key) != null;
return context.containsKey(key);
}
/**
* Return a value as a generic {@code Object}, returning {@code null} if the
* context value is not set.
*/
@Nullable
public Object get(String key)
{
Object val = systemParams.get(key);
if (val != null) {
return val;
}
val = userParams.get(key);
return val == null ? defaultParams.get(key) : val;
return context.get(key);
}
@SuppressWarnings("unused")
public Object getOrDefault(String key, Object defaultValue)
/**
* Return a value as a generic {@code Object}, returning the default value if the
* context value is not set.
*/
public Object get(String key, Object defaultValue)
{
final Object val = get(key);
return val == null ? defaultValue : val;
}
/**
* Return a value as an {@code String}, returning {@link null} if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
@Nullable
public String getAsString(String key)
public String getString(String key)
{
Object val = get(key);
return val == null ? null : val.toString();
return getString(key, null);
}
public String getAsString(String key, String defaultValue)
public String getString(String key, String defaultValue)
{
Object val = get(key);
return val == null ? defaultValue : val.toString();
return QueryContexts.parseString(context, key, defaultValue);
}
@Nullable
public Boolean getAsBoolean(String key)
/**
* Return a value as an {@code Boolean}, returning {@link null} if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
public Boolean getBoolean(final String key)
{
return QueryContexts.getAsBoolean(key, get(key));
}
public boolean getAsBoolean(
final String key,
final boolean defaultValue
)
/**
* Return a value as an {@code boolean}, returning the default value if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
public boolean getBoolean(final String key, final boolean defaultValue)
{
return QueryContexts.getAsBoolean(key, get(key), defaultValue);
return QueryContexts.parseBoolean(context, key, defaultValue);
}
public Integer getAsInt(final String key)
/**
* Return a value as an {@code Integer}, returning {@link null} if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
public Integer getInt(final String key)
{
return QueryContexts.getAsInt(key, get(key));
}
public int getAsInt(
final String key,
final int defaultValue
)
/**
* Return a value as an {@code int}, returning the default value if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
public int getInt(final String key, final int defaultValue)
{
return QueryContexts.getAsInt(key, get(key), defaultValue);
return QueryContexts.parseInt(context, key, defaultValue);
}
public Long getAsLong(final String key)
/**
* Return a value as an {@code Long}, returning {@link null} if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
public Long getLong(final String key)
{
return QueryContexts.getAsLong(key, get(key));
}
public long getAsLong(final String key, final long defaultValue)
/**
* Return a value as an {@code long}, returning the default value if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
public long getLong(final String key, final long defaultValue)
{
return QueryContexts.getAsLong(key, get(key), defaultValue);
return QueryContexts.parseLong(context, key, defaultValue);
}
public HumanReadableBytes getAsHumanReadableBytes(final String key, final HumanReadableBytes defaultValue)
/**
* Return a value as an {@code Float}, returning {@link null} if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
@SuppressWarnings("unused")
public Float getFloat(final String key)
{
return QueryContexts.getAsHumanReadableBytes(key, get(key), defaultValue);
return QueryContexts.getAsFloat(key, get(key));
}
public float getAsFloat(final String key, final float defaultValue)
/**
* Return a value as an {@code float}, returning the default value if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
public float getFloat(final String key, final float defaultValue)
{
return QueryContexts.getAsFloat(key, get(key), defaultValue);
}
public <E extends Enum<E>> E getAsEnum(String key, Class<E> clazz, E defaultValue)
public HumanReadableBytes getHumanReadableBytes(final String key, final HumanReadableBytes defaultValue)
{
return QueryContexts.getAsHumanReadableBytes(key, get(key), defaultValue);
}
public <E extends Enum<E>> E getEnum(String key, Class<E> clazz, E defaultValue)
{
return QueryContexts.getAsEnum(key, get(key), clazz, defaultValue);
}
public Map<String, Object> getMergedParams()
public Granularity getGranularity(String key)
{
if (mergedParams == null) {
final Map<String, Object> merged = new TreeMap<>(defaultParams);
merged.putAll(userParams);
merged.putAll(systemParams);
mergedParams = Collections.unmodifiableMap(merged);
final Object value = get(key);
if (value == null) {
return null;
}
if (value instanceof Granularity) {
return (Granularity) value;
} else {
throw QueryContexts.badTypeException(key, "a Granularity", value);
}
return mergedParams;
}
public QueryContext copy()
public boolean isDebug()
{
return new QueryContext(
new TreeMap<>(defaultParams),
new TreeMap<>(userParams),
new TreeMap<>(systemParams)
return getBoolean(QueryContexts.ENABLE_DEBUG, QueryContexts.DEFAULT_ENABLE_DEBUG);
}
public boolean isBySegment()
{
return isBySegment(QueryContexts.DEFAULT_BY_SEGMENT);
}
public boolean isBySegment(boolean defaultValue)
{
return getBoolean(QueryContexts.BY_SEGMENT_KEY, defaultValue);
}
public boolean isPopulateCache()
{
return isPopulateCache(QueryContexts.DEFAULT_POPULATE_CACHE);
}
public boolean isPopulateCache(boolean defaultValue)
{
return getBoolean(QueryContexts.POPULATE_CACHE_KEY, defaultValue);
}
public boolean isUseCache()
{
return isUseCache(QueryContexts.DEFAULT_USE_CACHE);
}
public boolean isUseCache(boolean defaultValue)
{
return getBoolean(QueryContexts.USE_CACHE_KEY, defaultValue);
}
public boolean isPopulateResultLevelCache()
{
return isPopulateResultLevelCache(QueryContexts.DEFAULT_POPULATE_RESULTLEVEL_CACHE);
}
public boolean isPopulateResultLevelCache(boolean defaultValue)
{
return getBoolean(QueryContexts.POPULATE_RESULT_LEVEL_CACHE_KEY, defaultValue);
}
public boolean isUseResultLevelCache()
{
return isUseResultLevelCache(QueryContexts.DEFAULT_USE_RESULTLEVEL_CACHE);
}
public boolean isUseResultLevelCache(boolean defaultValue)
{
return getBoolean(QueryContexts.USE_RESULT_LEVEL_CACHE_KEY, defaultValue);
}
public boolean isFinalize(boolean defaultValue)
{
return getBoolean(QueryContexts.FINALIZE_KEY, defaultValue);
}
public boolean isSerializeDateTimeAsLong(boolean defaultValue)
{
return getBoolean(QueryContexts.SERIALIZE_DATE_TIME_AS_LONG_KEY, defaultValue);
}
public boolean isSerializeDateTimeAsLongInner(boolean defaultValue)
{
return getBoolean(QueryContexts.SERIALIZE_DATE_TIME_AS_LONG_INNER_KEY, defaultValue);
}
public Vectorize getVectorize()
{
return getVectorize(QueryContexts.DEFAULT_VECTORIZE);
}
public Vectorize getVectorize(Vectorize defaultValue)
{
return getEnum(QueryContexts.VECTORIZE_KEY, Vectorize.class, defaultValue);
}
public Vectorize getVectorizeVirtualColumns()
{
return getVectorizeVirtualColumns(QueryContexts.DEFAULT_VECTORIZE_VIRTUAL_COLUMN);
}
public Vectorize getVectorizeVirtualColumns(Vectorize defaultValue)
{
return getEnum(
QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY,
Vectorize.class,
defaultValue
);
}
public int getVectorSize()
{
return getVectorSize(QueryableIndexStorageAdapter.DEFAULT_VECTOR_SIZE);
}
public int getVectorSize(int defaultSize)
{
return getInt(QueryContexts.VECTOR_SIZE_KEY, defaultSize);
}
public int getMaxSubqueryRows(int defaultSize)
{
return getInt(QueryContexts.MAX_SUBQUERY_ROWS_KEY, defaultSize);
}
public int getUncoveredIntervalsLimit()
{
return getUncoveredIntervalsLimit(QueryContexts.DEFAULT_UNCOVERED_INTERVALS_LIMIT);
}
public int getUncoveredIntervalsLimit(int defaultValue)
{
return getInt(QueryContexts.UNCOVERED_INTERVALS_LIMIT_KEY, defaultValue);
}
public int getPriority()
{
return getPriority(QueryContexts.DEFAULT_PRIORITY);
}
public int getPriority(int defaultValue)
{
return getInt(QueryContexts.PRIORITY_KEY, defaultValue);
}
public String getLane()
{
return getString(QueryContexts.LANE_KEY);
}
public boolean getEnableParallelMerges()
{
return getBoolean(
QueryContexts.BROKER_PARALLEL_MERGE_KEY,
QueryContexts.DEFAULT_ENABLE_PARALLEL_MERGE
);
}
public int getParallelMergeInitialYieldRows(int defaultValue)
{
return getInt(QueryContexts.BROKER_PARALLEL_MERGE_INITIAL_YIELD_ROWS_KEY, defaultValue);
}
public int getParallelMergeSmallBatchRows(int defaultValue)
{
return getInt(QueryContexts.BROKER_PARALLEL_MERGE_SMALL_BATCH_ROWS_KEY, defaultValue);
}
public int getParallelMergeParallelism(int defaultValue)
{
return getInt(QueryContexts.BROKER_PARALLELISM, defaultValue);
}
public long getJoinFilterRewriteMaxSize()
{
return getLong(
QueryContexts.JOIN_FILTER_REWRITE_MAX_SIZE_KEY,
QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE
);
}
public boolean getEnableJoinFilterPushDown()
{
return getBoolean(
QueryContexts.JOIN_FILTER_PUSH_DOWN_KEY,
QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_PUSH_DOWN
);
}
public boolean getEnableJoinFilterRewrite()
{
return getBoolean(
QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY,
QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE
);
}
public boolean isSecondaryPartitionPruningEnabled()
{
return getBoolean(
QueryContexts.SECONDARY_PARTITION_PRUNING_KEY,
QueryContexts.DEFAULT_SECONDARY_PARTITION_PRUNING
);
}
public long getMaxQueuedBytes(long defaultValue)
{
return getLong(QueryContexts.MAX_QUEUED_BYTES_KEY, defaultValue);
}
public long getMaxScatterGatherBytes()
{
return getLong(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY, Long.MAX_VALUE);
}
public boolean hasTimeout()
{
return getTimeout() != QueryContexts.NO_TIMEOUT;
}
public long getTimeout()
{
return getTimeout(getDefaultTimeout());
}
public long getTimeout(long defaultTimeout)
{
final long timeout = getLong(QueryContexts.TIMEOUT_KEY, defaultTimeout);
if (timeout >= 0) {
return timeout;
}
throw new BadQueryContextException(
StringUtils.format(
"Timeout [%s] must be a non negative value, but was %d",
QueryContexts.TIMEOUT_KEY,
timeout
)
);
}
public long getDefaultTimeout()
{
final long defaultTimeout = getLong(QueryContexts.DEFAULT_TIMEOUT_KEY, QueryContexts.DEFAULT_TIMEOUT_MILLIS);
if (defaultTimeout >= 0) {
return defaultTimeout;
}
throw new BadQueryContextException(
StringUtils.format(
"Timeout [%s] must be a non negative value, but was %d",
QueryContexts.DEFAULT_TIMEOUT_KEY,
defaultTimeout
)
);
}
public void verifyMaxQueryTimeout(long maxQueryTimeout)
{
long timeout = getTimeout();
if (timeout > maxQueryTimeout) {
throw new BadQueryContextException(
StringUtils.format(
"Configured %s = %d is more than enforced limit of %d.",
QueryContexts.TIMEOUT_KEY,
timeout,
maxQueryTimeout
)
);
}
}
public void verifyMaxScatterGatherBytes(long maxScatterGatherBytesLimit)
{
long curr = getLong(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY, 0);
if (curr > maxScatterGatherBytesLimit) {
throw new BadQueryContextException(
StringUtils.format(
"Configured %s = %d is more than enforced limit of %d.",
QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY,
curr,
maxScatterGatherBytesLimit
)
);
}
}
public int getNumRetriesOnMissingSegments(int defaultValue)
{
return getInt(QueryContexts.NUM_RETRIES_ON_MISSING_SEGMENTS_KEY, defaultValue);
}
public boolean allowReturnPartialResults(boolean defaultValue)
{
return getBoolean(QueryContexts.RETURN_PARTIAL_RESULTS_KEY, defaultValue);
}
public boolean getEnableJoinFilterRewriteValueColumnFilters()
{
return getBoolean(
QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY,
QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS
);
}
public boolean getEnableRewriteJoinToFilter()
{
return getBoolean(
QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY,
QueryContexts.DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER
);
}
public boolean getEnableJoinLeftScanDirect()
{
return getBoolean(
QueryContexts.SQL_JOIN_LEFT_SCAN_DIRECT,
QueryContexts.DEFAULT_ENABLE_SQL_JOIN_LEFT_SCAN_DIRECT
);
}
public int getInSubQueryThreshold()
{
return getInSubQueryThreshold(QueryContexts.DEFAULT_IN_SUB_QUERY_THRESHOLD);
}
public int getInSubQueryThreshold(int defaultValue)
{
return getInt(
QueryContexts.IN_SUB_QUERY_THRESHOLD_KEY,
defaultValue
);
}
public boolean isTimeBoundaryPlanningEnabled()
{
return getBoolean(
QueryContexts.TIME_BOUNDARY_PLANNING_KEY,
QueryContexts.DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING
);
}
public String getBrokerServiceName()
{
return getString(QueryContexts.BROKER_SERVICE_NAME);
}
@Override
public boolean equals(Object o)
{
@ -259,23 +583,21 @@ public class QueryContext
if (o == null || getClass() != o.getClass()) {
return false;
}
QueryContext context = (QueryContext) o;
return getMergedParams().equals(context.getMergedParams());
QueryContext other = (QueryContext) o;
return context.equals(other.context);
}
@Override
public int hashCode()
{
return Objects.hash(getMergedParams());
return Objects.hash(context);
}
@Override
public String toString()
{
return "QueryContext{" +
"defaultParams=" + defaultParams +
", userParams=" + userParams +
", systemParams=" + systemParams +
"context=" + context +
'}';
}
}

View File

@ -21,19 +21,19 @@ package org.apache.druid.query;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.guice.annotations.PublicApi;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Numbers;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.QueryableIndexStorageAdapter;
import javax.annotation.Nullable;
import java.math.BigDecimal;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
import java.util.Map.Entry;
import java.util.concurrent.TimeUnit;
@PublicApi
@ -80,7 +80,13 @@ public class QueryContexts
public static final String SERIALIZE_DATE_TIME_AS_LONG_KEY = "serializeDateTimeAsLong";
public static final String SERIALIZE_DATE_TIME_AS_LONG_INNER_KEY = "serializeDateTimeAsLongInner";
public static final String UNCOVERED_INTERVALS_LIMIT_KEY = "uncoveredIntervalsLimit";
public static final String MIN_TOP_N_THRESHOLD = "minTopNThreshold";
// SQL query context keys
public static final String CTX_SQL_QUERY_ID = BaseQuery.SQL_QUERY_ID;
public static final String CTX_SQL_STRINGIFY_ARRAYS = "sqlStringifyArrays";
// Defaults
public static final boolean DEFAULT_BY_SEGMENT = false;
public static final boolean DEFAULT_POPULATE_CACHE = true;
public static final boolean DEFAULT_USE_CACHE = true;
@ -150,332 +156,42 @@ public class QueryContexts
}
}
public static <T> boolean isBySegment(Query<T> query)
private QueryContexts()
{
return isBySegment(query, DEFAULT_BY_SEGMENT);
}
public static <T> boolean isBySegment(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(BY_SEGMENT_KEY, defaultValue);
}
public static <T> boolean isPopulateCache(Query<T> query)
{
return isPopulateCache(query, DEFAULT_POPULATE_CACHE);
}
public static <T> boolean isPopulateCache(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(POPULATE_CACHE_KEY, defaultValue);
}
public static <T> boolean isUseCache(Query<T> query)
{
return isUseCache(query, DEFAULT_USE_CACHE);
}
public static <T> boolean isUseCache(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(USE_CACHE_KEY, defaultValue);
}
public static <T> boolean isPopulateResultLevelCache(Query<T> query)
{
return isPopulateResultLevelCache(query, DEFAULT_POPULATE_RESULTLEVEL_CACHE);
}
public static <T> boolean isPopulateResultLevelCache(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(POPULATE_RESULT_LEVEL_CACHE_KEY, defaultValue);
}
public static <T> boolean isUseResultLevelCache(Query<T> query)
{
return isUseResultLevelCache(query, DEFAULT_USE_RESULTLEVEL_CACHE);
}
public static <T> boolean isUseResultLevelCache(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(USE_RESULT_LEVEL_CACHE_KEY, defaultValue);
}
public static <T> boolean isFinalize(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(FINALIZE_KEY, defaultValue);
}
public static <T> boolean isSerializeDateTimeAsLong(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(SERIALIZE_DATE_TIME_AS_LONG_KEY, defaultValue);
}
public static <T> boolean isSerializeDateTimeAsLongInner(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(SERIALIZE_DATE_TIME_AS_LONG_INNER_KEY, defaultValue);
}
public static <T> Vectorize getVectorize(Query<T> query)
{
return getVectorize(query, QueryContexts.DEFAULT_VECTORIZE);
}
public static <T> Vectorize getVectorize(Query<T> query, Vectorize defaultValue)
{
return query.getQueryContext().getAsEnum(VECTORIZE_KEY, Vectorize.class, defaultValue);
}
public static <T> Vectorize getVectorizeVirtualColumns(Query<T> query)
{
return getVectorizeVirtualColumns(query, QueryContexts.DEFAULT_VECTORIZE_VIRTUAL_COLUMN);
}
public static <T> Vectorize getVectorizeVirtualColumns(Query<T> query, Vectorize defaultValue)
{
return query.getQueryContext().getAsEnum(VECTORIZE_VIRTUAL_COLUMNS_KEY, Vectorize.class, defaultValue);
}
public static <T> int getVectorSize(Query<T> query)
{
return getVectorSize(query, QueryableIndexStorageAdapter.DEFAULT_VECTOR_SIZE);
}
public static <T> int getVectorSize(Query<T> query, int defaultSize)
{
return query.getQueryContext().getAsInt(VECTOR_SIZE_KEY, defaultSize);
}
public static <T> int getMaxSubqueryRows(Query<T> query, int defaultSize)
{
return query.getQueryContext().getAsInt(MAX_SUBQUERY_ROWS_KEY, defaultSize);
}
public static <T> int getUncoveredIntervalsLimit(Query<T> query)
{
return getUncoveredIntervalsLimit(query, DEFAULT_UNCOVERED_INTERVALS_LIMIT);
}
public static <T> int getUncoveredIntervalsLimit(Query<T> query, int defaultValue)
{
return query.getQueryContext().getAsInt(UNCOVERED_INTERVALS_LIMIT_KEY, defaultValue);
}
public static <T> int getPriority(Query<T> query)
{
return getPriority(query, DEFAULT_PRIORITY);
}
public static <T> int getPriority(Query<T> query, int defaultValue)
{
return query.getQueryContext().getAsInt(PRIORITY_KEY, defaultValue);
}
public static <T> String getLane(Query<T> query)
{
return query.getQueryContext().getAsString(LANE_KEY);
}
public static <T> boolean getEnableParallelMerges(Query<T> query)
{
return query.getContextBoolean(BROKER_PARALLEL_MERGE_KEY, DEFAULT_ENABLE_PARALLEL_MERGE);
}
public static <T> int getParallelMergeInitialYieldRows(Query<T> query, int defaultValue)
{
return query.getQueryContext().getAsInt(BROKER_PARALLEL_MERGE_INITIAL_YIELD_ROWS_KEY, defaultValue);
}
public static <T> int getParallelMergeSmallBatchRows(Query<T> query, int defaultValue)
{
return query.getQueryContext().getAsInt(BROKER_PARALLEL_MERGE_SMALL_BATCH_ROWS_KEY, defaultValue);
}
public static <T> int getParallelMergeParallelism(Query<T> query, int defaultValue)
{
return query.getQueryContext().getAsInt(BROKER_PARALLELISM, defaultValue);
}
public static <T> boolean getEnableJoinFilterRewriteValueColumnFilters(Query<T> query)
{
return query.getContextBoolean(
JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY,
DEFAULT_ENABLE_JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS
);
}
public static <T> boolean getEnableRewriteJoinToFilter(Query<T> query)
{
return query.getContextBoolean(
REWRITE_JOIN_TO_FILTER_ENABLE_KEY,
DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER
);
}
public static <T> long getJoinFilterRewriteMaxSize(Query<T> query)
{
return query.getQueryContext().getAsLong(JOIN_FILTER_REWRITE_MAX_SIZE_KEY, DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE);
}
public static <T> boolean getEnableJoinFilterPushDown(Query<T> query)
{
return query.getContextBoolean(JOIN_FILTER_PUSH_DOWN_KEY, DEFAULT_ENABLE_JOIN_FILTER_PUSH_DOWN);
}
public static <T> boolean getEnableJoinFilterRewrite(Query<T> query)
{
return query.getContextBoolean(JOIN_FILTER_REWRITE_ENABLE_KEY, DEFAULT_ENABLE_JOIN_FILTER_REWRITE);
}
public static boolean getEnableJoinLeftScanDirect(Map<String, Object> context)
{
return parseBoolean(context, SQL_JOIN_LEFT_SCAN_DIRECT, DEFAULT_ENABLE_SQL_JOIN_LEFT_SCAN_DIRECT);
}
public static <T> boolean isSecondaryPartitionPruningEnabled(Query<T> query)
{
return query.getContextBoolean(SECONDARY_PARTITION_PRUNING_KEY, DEFAULT_SECONDARY_PARTITION_PRUNING);
}
public static <T> boolean isDebug(Query<T> query)
{
return query.getContextBoolean(ENABLE_DEBUG, DEFAULT_ENABLE_DEBUG);
}
public static boolean isDebug(Map<String, Object> queryContext)
{
return parseBoolean(queryContext, ENABLE_DEBUG, DEFAULT_ENABLE_DEBUG);
}
public static int getInSubQueryThreshold(Map<String, Object> context)
{
return getInSubQueryThreshold(context, DEFAULT_IN_SUB_QUERY_THRESHOLD);
}
public static int getInSubQueryThreshold(Map<String, Object> context, int defaultValue)
{
return parseInt(context, IN_SUB_QUERY_THRESHOLD_KEY, defaultValue);
}
public static boolean isTimeBoundaryPlanningEnabled(Map<String, Object> queryContext)
{
return parseBoolean(queryContext, TIME_BOUNDARY_PLANNING_KEY, DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING);
}
public static <T> Query<T> withMaxScatterGatherBytes(Query<T> query, long maxScatterGatherBytesLimit)
{
Long curr = query.getQueryContext().getAsLong(MAX_SCATTER_GATHER_BYTES_KEY);
if (curr == null) {
return query.withOverriddenContext(ImmutableMap.of(MAX_SCATTER_GATHER_BYTES_KEY, maxScatterGatherBytesLimit));
} else {
if (curr > maxScatterGatherBytesLimit) {
throw new IAE(
"configured [%s = %s] is more than enforced limit of [%s].",
MAX_SCATTER_GATHER_BYTES_KEY,
curr,
maxScatterGatherBytesLimit
);
} else {
return query;
}
}
}
public static <T> Query<T> verifyMaxQueryTimeout(Query<T> query, long maxQueryTimeout)
{
long timeout = getTimeout(query);
if (timeout > maxQueryTimeout) {
throw new IAE(
"configured [%s = %s] is more than enforced limit of maxQueryTimeout [%s].",
TIMEOUT_KEY,
timeout,
maxQueryTimeout
);
} else {
return query;
}
}
public static <T> long getMaxQueuedBytes(Query<T> query, long defaultValue)
{
return query.getQueryContext().getAsLong(MAX_QUEUED_BYTES_KEY, defaultValue);
}
public static <T> long getMaxScatterGatherBytes(Query<T> query)
{
return query.getQueryContext().getAsLong(MAX_SCATTER_GATHER_BYTES_KEY, Long.MAX_VALUE);
}
public static <T> boolean hasTimeout(Query<T> query)
{
return getTimeout(query) != NO_TIMEOUT;
}
public static <T> long getTimeout(Query<T> query)
{
return getTimeout(query, getDefaultTimeout(query));
}
public static <T> long getTimeout(Query<T> query, long defaultTimeout)
{
try {
final long timeout = query.getQueryContext().getAsLong(TIMEOUT_KEY, defaultTimeout);
Preconditions.checkState(timeout >= 0, "Timeout must be a non negative value, but was [%s]", timeout);
return timeout;
}
catch (IAE e) {
throw new BadQueryContextException(e);
}
}
public static <T> Query<T> withTimeout(Query<T> query, long timeout)
{
return query.withOverriddenContext(ImmutableMap.of(TIMEOUT_KEY, timeout));
}
public static <T> Query<T> withDefaultTimeout(Query<T> query, long defaultTimeout)
{
return query.withOverriddenContext(ImmutableMap.of(QueryContexts.DEFAULT_TIMEOUT_KEY, defaultTimeout));
}
static <T> long getDefaultTimeout(Query<T> query)
{
final long defaultTimeout = query.getQueryContext().getAsLong(DEFAULT_TIMEOUT_KEY, DEFAULT_TIMEOUT_MILLIS);
Preconditions.checkState(defaultTimeout >= 0, "Timeout must be a non negative value, but was [%s]", defaultTimeout);
return defaultTimeout;
}
public static <T> int getNumRetriesOnMissingSegments(Query<T> query, int defaultValue)
{
return query.getQueryContext().getAsInt(NUM_RETRIES_ON_MISSING_SEGMENTS_KEY, defaultValue);
}
public static <T> boolean allowReturnPartialResults(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(RETURN_PARTIAL_RESULTS_KEY, defaultValue);
}
public static String getBrokerServiceName(Map<String, Object> queryContext)
{
return queryContext == null ? null : (String) queryContext.get(BROKER_SERVICE_NAME);
}
@SuppressWarnings("unused")
static <T> long parseLong(Map<String, Object> context, String key, long defaultValue)
public static long parseLong(Map<String, Object> context, String key, long defaultValue)
{
return getAsLong(key, context.get(key), defaultValue);
}
static int parseInt(Map<String, Object> context, String key, int defaultValue)
public static int parseInt(Map<String, Object> context, String key, int defaultValue)
{
return getAsInt(key, context.get(key), defaultValue);
}
static boolean parseBoolean(Map<String, Object> context, String key, boolean defaultValue)
@Nullable
public static String parseString(Map<String, Object> context, String key)
{
return parseString(context, key, null);
}
public static boolean parseBoolean(Map<String, Object> context, String key, boolean defaultValue)
{
return getAsBoolean(key, context.get(key), defaultValue);
}
public static String parseString(Map<String, Object> context, String key, String defaultValue)
{
return getAsString(key, context.get(key), defaultValue);
}
@SuppressWarnings("unused") // To keep IntelliJ inspections happy
public static float parseFloat(Map<String, Object> context, String key, float defaultValue)
{
return getAsFloat(key, context.get(key), defaultValue);
}
public static String getAsString(
final String key,
final Object value,
@ -486,14 +202,13 @@ public class QueryContexts
return defaultValue;
} else if (value instanceof String) {
return (String) value;
} else {
throw new IAE("Expected key [%s] to be a String, but got [%s]", key, value.getClass().getName());
}
throw badTypeException(key, "a String", value);
}
@Nullable
public static Boolean getAsBoolean(
final String parameter,
final String key,
final Object value
)
{
@ -503,13 +218,12 @@ public class QueryContexts
return Boolean.parseBoolean((String) value);
} else if (value instanceof Boolean) {
return (Boolean) value;
} else {
throw new IAE("Expected parameter [%s] to be a Boolean, but got [%s]", parameter, value.getClass().getName());
}
throw badTypeException(key, "a Boolean", value);
}
/**
* Get the value of a parameter as a {@code boolean}. The parameter is expected
* Get the value of a context value as a {@code boolean}. The value is expected
* to be {@code null}, a string or a {@code Boolean} object.
*/
public static boolean getAsBoolean(
@ -534,24 +248,33 @@ public class QueryContexts
return Numbers.parseInt(value);
}
catch (NumberFormatException ignored) {
throw new IAE("Expected key [%s] in integer format, but got [%s]", key, value);
// Attempt to handle trivial decimal values: 12.00, etc.
// This mimics how Jackson will convert "12.00" to a Integer on request.
try {
return new BigDecimal((String) value).intValueExact();
}
catch (Exception nfe) {
// That didn't work either. Give up.
throw badValueException(key, "in integer format", value);
}
}
}
throw new IAE("Expected key [%s] to be an Integer, but got [%s]", key, value.getClass().getName());
throw badTypeException(key, "an Integer", value);
}
/**
* Get the value of a parameter as an {@code int}. The parameter is expected
* Get the value of a context value as an {@code int}. The value is expected
* to be {@code null}, a string or a {@code Number} object.
*/
public static int getAsInt(
final String ke,
final String key,
final Object value,
final int defaultValue
)
{
Integer val = getAsInt(ke, value);
Integer val = getAsInt(key, value);
return val == null ? defaultValue : val;
}
@ -567,14 +290,23 @@ public class QueryContexts
return Numbers.parseLong(value);
}
catch (NumberFormatException ignored) {
throw new IAE("Expected key [%s] in long format, but got [%s]", key, value);
// Attempt to handle trivial decimal values: 12.00, etc.
// This mimics how Jackson will convert "12.00" to a Long on request.
try {
return new BigDecimal((String) value).longValueExact();
}
catch (Exception nfe) {
// That didn't work either. Give up.
throw badValueException(key, "in long format", value);
}
}
}
throw new IAE("Expected key [%s] to be a Long, but got [%s]", key, value.getClass().getName());
throw badTypeException(key, "a Long", value);
}
/**
* Get the value of a parameter as an {@code long}. The parameter is expected
* Get the value of a context value as an {@code long}. The value is expected
* to be {@code null}, a string or a {@code Number} object.
*/
public static long getAsLong(
@ -587,8 +319,39 @@ public class QueryContexts
return val == null ? defaultValue : val;
}
/**
* Get the value of a context value as an {@code Float}. The value is expected
* to be {@code null}, a string or a {@code Number} object.
*/
public static Float getAsFloat(final String key, final Object value)
{
if (value == null) {
return null;
} else if (value instanceof Number) {
return ((Number) value).floatValue();
} else if (value instanceof String) {
try {
return Float.parseFloat((String) value);
}
catch (NumberFormatException ignored) {
throw badValueException(key, "in float format", value);
}
}
throw badTypeException(key, "a Float", value);
}
public static float getAsFloat(
final String key,
final Object value,
final float defaultValue
)
{
Float val = getAsFloat(key, value);
return val == null ? defaultValue : val;
}
public static HumanReadableBytes getAsHumanReadableBytes(
final String parameter,
final String key,
final Object value,
final HumanReadableBytes defaultValue
)
@ -602,73 +365,126 @@ public class QueryContexts
return HumanReadableBytes.valueOf(HumanReadableBytes.parse((String) value));
}
catch (IAE e) {
throw new IAE("Expected key [%s] in human readable format, but got [%s]", parameter, value);
throw badValueException(key, "a human readable number", value);
}
}
throw new IAE("Expected key [%s] to be a human readable number, but got [%s]", parameter, value.getClass().getName());
throw badTypeException(key, "a human readable number", value);
}
public static float getAsFloat(String key, Object value, float defaultValue)
/**
* Insert, update or remove a single key to produce an overridden context.
* Leaves the original context unchanged.
*
* @param context context to override
* @param key key to insert, update or remove
* @param value if {@code null}, remove the key. Otherwise, insert or replace
* the key.
* @return a new context map
*/
public static Map<String, Object> override(
final Map<String, Object> context,
final String key,
final Object value
)
{
if (null == value) {
return defaultValue;
} else if (value instanceof Number) {
return ((Number) value).floatValue();
} else if (value instanceof String) {
try {
return Float.parseFloat((String) value);
}
catch (NumberFormatException ignored) {
throw new IAE("Expected key [%s] in float format, but got [%s]", key, value);
}
Map<String, Object> overridden = new HashMap<>(context);
if (value == null) {
overridden.remove(key);
} else {
overridden.put(key, value);
}
throw new IAE("Expected key [%s] to be a Float, but got [%s]", key, value.getClass().getName());
return overridden;
}
/**
* Insert or replace multiple keys to produce an overridden context.
* Leaves the original context unchanged.
*
* @param context context to override
* @param overrides map of values to insert or replace
* @return a new context map
*/
public static Map<String, Object> override(
final Map<String, Object> context,
final Map<String, Object> overrides
)
{
Map<String, Object> overridden = new TreeMap<>();
Map<String, Object> overridden = new HashMap<>();
if (context != null) {
overridden.putAll(context);
}
overridden.putAll(overrides);
if (overrides != null) {
overridden.putAll(overrides);
}
return overridden;
}
private QueryContexts()
public static <E extends Enum<E>> E getAsEnum(String key, Object value, Class<E> clazz, E defaultValue)
{
}
public static <E extends Enum<E>> E getAsEnum(String key, Object val, Class<E> clazz, E defaultValue)
{
if (val == null) {
if (value == null) {
return defaultValue;
}
try {
if (val instanceof String) {
return Enum.valueOf(clazz, StringUtils.toUpperCase((String) val));
} else if (val instanceof Boolean) {
return Enum.valueOf(clazz, StringUtils.toUpperCase(String.valueOf(val)));
if (value instanceof String) {
return Enum.valueOf(clazz, StringUtils.toUpperCase((String) value));
} else if (value instanceof Boolean) {
return Enum.valueOf(clazz, StringUtils.toUpperCase(String.valueOf(value)));
}
}
catch (IllegalArgumentException e) {
throw new IAE("Expected key [%s] must be value of enum [%s], but got [%s].",
key,
clazz.getName(),
val.toString());
throw badValueException(
key,
StringUtils.format("a value of enum [%s]", clazz.getSimpleName()),
value
);
}
throw new ISE(
"Expected key [%s] must be type of [%s], actual type is [%s].",
throw badTypeException(
key,
clazz.getName(),
val.getClass()
StringUtils.format("of type [%s]", clazz.getSimpleName()),
value
);
}
public static BadQueryContextException badValueException(
final String key,
final String expected,
final Object actual
)
{
return new BadQueryContextException(
StringUtils.format(
"Expected key [%s] to be in %s, but got [%s]",
key,
expected,
actual
)
);
}
public static BadQueryContextException badTypeException(
final String key,
final String expected,
final Object actual
)
{
return new BadQueryContextException(
StringUtils.format(
"Expected key [%s] to be %s, but got [%s]",
key,
expected,
actual.getClass().getName()
)
);
}
public static void addDefaults(Map<String, Object> context, Map<String, Object> defaults)
{
for (Entry<String, Object> entry : defaults.entrySet()) {
context.putIfAbsent(entry.getKey(), entry.getValue());
}
}
}

View File

@ -41,7 +41,7 @@ public class SubqueryQueryRunner<T> implements QueryRunner<T>
{
DataSource dataSource = queryPlus.getQuery().getDataSource();
boolean forcePushDownNestedQuery = queryPlus.getQuery()
.getContextBoolean(
.context().getBoolean(
GroupByQueryConfig.CTX_KEY_FORCE_PUSH_DOWN_NESTED_QUERY,
false
);

View File

@ -450,7 +450,7 @@ public class GroupByQuery extends BaseQuery<ResultRow>
@JsonIgnore
public boolean getContextSortByDimsFirst()
{
return getContextBoolean(CTX_KEY_SORT_BY_DIMS_FIRST, false);
return context().getBoolean(CTX_KEY_SORT_BY_DIMS_FIRST, false);
}
@JsonIgnore
@ -465,7 +465,7 @@ public class GroupByQuery extends BaseQuery<ResultRow>
@JsonIgnore
public boolean getApplyLimitPushDownFromContext()
{
return getContextBoolean(GroupByQueryConfig.CTX_KEY_APPLY_LIMIT_PUSH_DOWN, true);
return context().getBoolean(GroupByQueryConfig.CTX_KEY_APPLY_LIMIT_PUSH_DOWN, true);
}
@Override
@ -487,7 +487,7 @@ public class GroupByQuery extends BaseQuery<ResultRow>
private boolean validateAndGetForceLimitPushDown()
{
final boolean forcePushDown = getContextBoolean(GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN, false);
final boolean forcePushDown = context().getBoolean(GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN, false);
if (forcePushDown) {
if (!(limitSpec instanceof DefaultLimitSpec)) {
throw new IAE("When forcing limit push down, a limit spec must be provided.");
@ -748,7 +748,7 @@ public class GroupByQuery extends BaseQuery<ResultRow>
@Nullable
private DateTime computeUniversalTimestamp()
{
final String timestampStringFromContext = getQueryContext().getAsString(CTX_KEY_FUDGE_TIMESTAMP, "");
final String timestampStringFromContext = context().getString(CTX_KEY_FUDGE_TIMESTAMP, "");
final Granularity granularity = getGranularity();
if (!timestampStringFromContext.isEmpty()) {

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.groupby.strategy.GroupByStrategySelector;
import org.apache.druid.utils.JvmUtils;
@ -335,25 +336,26 @@ public class GroupByQueryConfig
public GroupByQueryConfig withOverrides(final GroupByQuery query)
{
final GroupByQueryConfig newConfig = new GroupByQueryConfig();
newConfig.defaultStrategy = query.getQueryContext().getAsString(CTX_KEY_STRATEGY, getDefaultStrategy());
newConfig.singleThreaded = query.getQueryContext().getAsBoolean(CTX_KEY_IS_SINGLE_THREADED, isSingleThreaded());
final QueryContext queryContext = query.context();
newConfig.defaultStrategy = queryContext.getString(CTX_KEY_STRATEGY, getDefaultStrategy());
newConfig.singleThreaded = queryContext.getBoolean(CTX_KEY_IS_SINGLE_THREADED, isSingleThreaded());
newConfig.maxIntermediateRows = Math.min(
query.getQueryContext().getAsInt(CTX_KEY_MAX_INTERMEDIATE_ROWS, getMaxIntermediateRows()),
queryContext.getInt(CTX_KEY_MAX_INTERMEDIATE_ROWS, getMaxIntermediateRows()),
getMaxIntermediateRows()
);
newConfig.maxResults = Math.min(
query.getQueryContext().getAsInt(CTX_KEY_MAX_RESULTS, getMaxResults()),
queryContext.getInt(CTX_KEY_MAX_RESULTS, getMaxResults()),
getMaxResults()
);
newConfig.bufferGrouperMaxSize = Math.min(
query.getQueryContext().getAsInt(CTX_KEY_BUFFER_GROUPER_MAX_SIZE, getBufferGrouperMaxSize()),
queryContext.getInt(CTX_KEY_BUFFER_GROUPER_MAX_SIZE, getBufferGrouperMaxSize()),
getBufferGrouperMaxSize()
);
newConfig.bufferGrouperMaxLoadFactor = query.getQueryContext().getAsFloat(
newConfig.bufferGrouperMaxLoadFactor = queryContext.getFloat(
CTX_KEY_BUFFER_GROUPER_MAX_LOAD_FACTOR,
getBufferGrouperMaxLoadFactor()
);
newConfig.bufferGrouperInitialBuckets = query.getQueryContext().getAsInt(
newConfig.bufferGrouperInitialBuckets = queryContext.getInt(
CTX_KEY_BUFFER_GROUPER_INITIAL_BUCKETS,
getBufferGrouperInitialBuckets()
);
@ -362,33 +364,33 @@ public class GroupByQueryConfig
// choose a default value lower than the max allowed when the context key is missing in the client query.
newConfig.maxOnDiskStorage = HumanReadableBytes.valueOf(
Math.min(
query.getContextAsHumanReadableBytes(CTX_KEY_MAX_ON_DISK_STORAGE, getDefaultOnDiskStorage()).getBytes(),
queryContext.getHumanReadableBytes(CTX_KEY_MAX_ON_DISK_STORAGE, getDefaultOnDiskStorage()).getBytes(),
getMaxOnDiskStorage().getBytes()
)
);
newConfig.maxSelectorDictionarySize = maxSelectorDictionarySize; // No overrides
newConfig.maxMergingDictionarySize = maxMergingDictionarySize; // No overrides
newConfig.forcePushDownLimit = query.getContextBoolean(CTX_KEY_FORCE_LIMIT_PUSH_DOWN, isForcePushDownLimit());
newConfig.applyLimitPushDownToSegment = query.getContextBoolean(
newConfig.forcePushDownLimit = queryContext.getBoolean(CTX_KEY_FORCE_LIMIT_PUSH_DOWN, isForcePushDownLimit());
newConfig.applyLimitPushDownToSegment = queryContext.getBoolean(
CTX_KEY_APPLY_LIMIT_PUSH_DOWN_TO_SEGMENT,
isApplyLimitPushDownToSegment()
);
newConfig.forceHashAggregation = query.getContextBoolean(CTX_KEY_FORCE_HASH_AGGREGATION, isForceHashAggregation());
newConfig.forcePushDownNestedQuery = query.getContextBoolean(
newConfig.forceHashAggregation = queryContext.getBoolean(CTX_KEY_FORCE_HASH_AGGREGATION, isForceHashAggregation());
newConfig.forcePushDownNestedQuery = queryContext.getBoolean(
CTX_KEY_FORCE_PUSH_DOWN_NESTED_QUERY,
isForcePushDownNestedQuery()
);
newConfig.intermediateCombineDegree = query.getQueryContext().getAsInt(
newConfig.intermediateCombineDegree = queryContext.getInt(
CTX_KEY_INTERMEDIATE_COMBINE_DEGREE,
getIntermediateCombineDegree()
);
newConfig.numParallelCombineThreads = query.getQueryContext().getAsInt(
newConfig.numParallelCombineThreads = queryContext.getInt(
CTX_KEY_NUM_PARALLEL_COMBINE_THREADS,
getNumParallelCombineThreads()
);
newConfig.mergeThreadLocal = query.getContextBoolean(CTX_KEY_MERGE_THREAD_LOCAL, isMergeThreadLocal());
newConfig.vectorize = query.getContextBoolean(QueryContexts.VECTORIZE_KEY, isVectorize());
newConfig.enableMultiValueUnnesting = query.getContextBoolean(
newConfig.mergeThreadLocal = queryContext.getBoolean(CTX_KEY_MERGE_THREAD_LOCAL, isMergeThreadLocal());
newConfig.vectorize = queryContext.getBoolean(QueryContexts.VECTORIZE_KEY, isVectorize());
newConfig.enableMultiValueUnnesting = queryContext.getBoolean(
CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING,
isMultiValueUnnestingEnabled()
);

View File

@ -96,7 +96,7 @@ public class GroupByQueryEngine
"Null storage adapter found. Probably trying to issue a query against a segment being memory unmapped."
);
}
if (!query.getContextBoolean(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true)) {
if (!query.context().getBoolean(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true)) {
throw new UOE(
"GroupBy v1 does not support %s as false. Set %s to true or use groupBy v2",
GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING,

View File

@ -100,7 +100,7 @@ public class GroupByQueryHelper
);
final IncrementalIndex index;
final boolean sortResults = query.getContextBoolean(CTX_KEY_SORT_RESULTS, true);
final boolean sortResults = query.context().getBoolean(CTX_KEY_SORT_RESULTS, true);
// All groupBy dimensions are strings, for now.
final List<DimensionSchema> dimensionSchemas = new ArrayList<>();
@ -118,7 +118,7 @@ public class GroupByQueryHelper
final AppendableIndexBuilder indexBuilder;
if (query.getContextBoolean("useOffheap", false)) {
if (query.context().getBoolean("useOffheap", false)) {
throw new UnsupportedOperationException(
"The 'useOffheap' option is no longer available for groupBy v1. Please move to the newer groupBy engine, "
+ "which always operates off-heap, by removing any custom 'druid.query.groupBy.defaultStrategy' runtime "

View File

@ -45,7 +45,6 @@ import org.apache.druid.java.util.common.jackson.JacksonUtils;
import org.apache.druid.query.CacheStrategy;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
@ -118,7 +117,7 @@ public class GroupByQueryQueryToolChest extends QueryToolChest<ResultRow, GroupB
public QueryRunner<ResultRow> mergeResults(final QueryRunner<ResultRow> runner)
{
return (queryPlus, responseContext) -> {
if (QueryContexts.isBySegment(queryPlus.getQuery())) {
if (queryPlus.getQuery().context().isBySegment()) {
return runner.run(queryPlus, responseContext);
}
@ -304,7 +303,7 @@ public class GroupByQueryQueryToolChest extends QueryToolChest<ResultRow, GroupB
private Sequence<ResultRow> finalizeSubqueryResults(Sequence<ResultRow> subqueryResult, GroupByQuery subquery)
{
final Sequence<ResultRow> finalizingResults;
if (QueryContexts.isFinalize(subquery, false)) {
if (subquery.context().isFinalize(false)) {
finalizingResults = new MappedSequence<>(
subqueryResult,
makePreComputeManipulatorFn(
@ -321,7 +320,7 @@ public class GroupByQueryQueryToolChest extends QueryToolChest<ResultRow, GroupB
public static boolean isNestedQueryPushDown(GroupByQuery q, GroupByStrategy strategy)
{
return q.getDataSource() instanceof QueryDataSource
&& q.getContextBoolean(GroupByQueryConfig.CTX_KEY_FORCE_PUSH_DOWN_NESTED_QUERY, false)
&& q.context().getBoolean(GroupByQueryConfig.CTX_KEY_FORCE_PUSH_DOWN_NESTED_QUERY, false)
&& q.getSubtotalsSpec() == null
&& strategy.supportsNestedQueryPushDown();
}
@ -418,7 +417,7 @@ public class GroupByQueryQueryToolChest extends QueryToolChest<ResultRow, GroupB
@Override
public ObjectMapper decorateObjectMapper(final ObjectMapper objectMapper, final GroupByQuery query)
{
final boolean resultAsArray = query.getContextBoolean(GroupByQueryConfig.CTX_KEY_ARRAY_RESULT_ROWS, false);
final boolean resultAsArray = query.context().getBoolean(GroupByQueryConfig.CTX_KEY_ARRAY_RESULT_ROWS, false);
if (resultAsArray && !queryConfig.isIntermediateResultAsMapCompat()) {
// We can assume ResultRow are serialized and deserialized as arrays. No need for special decoration,

View File

@ -45,7 +45,7 @@ import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.AbstractPrioritizedQueryRunnerCallable;
import org.apache.druid.query.ChainedExecutionQueryRunner;
import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryProcessingPool;
@ -134,7 +134,7 @@ public class GroupByMergingQueryRunnerV2 implements QueryRunner<ResultRow>
// merge buffer, otherwise the query will allocate too many merge buffers. This is potentially sub-optimal as it
// will involve materializing the results for each sink before starting to feed them into the outer merge buffer.
// I'm not sure of a better way to do this without tweaking how realtime servers do queries.
final boolean forceChainedExecution = query.getContextBoolean(
final boolean forceChainedExecution = query.context().getBoolean(
CTX_KEY_MERGE_RUNNERS_USING_CHAINED_EXECUTION,
false
);
@ -144,7 +144,8 @@ public class GroupByMergingQueryRunnerV2 implements QueryRunner<ResultRow>
)
.withoutThreadUnsafeState();
if (QueryContexts.isBySegment(query) || forceChainedExecution) {
final QueryContext queryContext = query.context();
if (queryContext.isBySegment() || forceChainedExecution) {
ChainedExecutionQueryRunner<ResultRow> runner = new ChainedExecutionQueryRunner<>(queryProcessingPool, queryWatcher, queryables);
return runner.run(queryPlusForRunners, responseContext);
}
@ -156,12 +157,12 @@ public class GroupByMergingQueryRunnerV2 implements QueryRunner<ResultRow>
StringUtils.format("druid-groupBy-%s_%s", UUID.randomUUID(), query.getId())
);
final int priority = QueryContexts.getPriority(query);
final int priority = queryContext.getPriority();
// Figure out timeoutAt time now, so we can apply the timeout to both the mergeBufferPool.take and the actual
// query processing together.
final long queryTimeout = QueryContexts.getTimeout(query);
final boolean hasTimeout = QueryContexts.hasTimeout(query);
final long queryTimeout = queryContext.getTimeout();
final boolean hasTimeout = queryContext.hasTimeout();
final long timeoutAt = System.currentTimeMillis() + queryTimeout;
return new BaseSequence<>(

View File

@ -34,7 +34,6 @@ import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.ColumnSelectorPlus;
import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.aggregation.AggregatorAdapters;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.dimension.ColumnSelectorStrategyFactory;
@ -77,6 +76,7 @@ import org.joda.time.DateTime;
import org.joda.time.Interval;
import javax.annotation.Nullable;
import java.io.Closeable;
import java.nio.ByteBuffer;
import java.util.Iterator;
@ -141,7 +141,7 @@ public class GroupByQueryEngineV2
try {
final String fudgeTimestampString = NullHandling.emptyToNullIfNeeded(
query.getQueryContext().getAsString(GroupByStrategyV2.CTX_KEY_FUDGE_TIMESTAMP)
query.context().getString(GroupByStrategyV2.CTX_KEY_FUDGE_TIMESTAMP)
);
final DateTime fudgeTimestamp = fudgeTimestampString == null
@ -151,7 +151,7 @@ public class GroupByQueryEngineV2
final Filter filter = Filters.convertToCNFFromQueryContext(query, Filters.toFilter(query.getFilter()));
final Interval interval = Iterables.getOnlyElement(query.getIntervals());
final boolean doVectorize = QueryContexts.getVectorize(query).shouldVectorize(
final boolean doVectorize = query.context().getVectorize().shouldVectorize(
VectorGroupByEngine.canVectorize(query, storageAdapter, filter)
);
@ -496,7 +496,7 @@ public class GroupByQueryEngineV2
// Time is the same for every row in the cursor
this.timestamp = fudgeTimestamp != null ? fudgeTimestamp : cursor.getTime();
this.allSingleValueDims = allSingleValueDims;
this.allowMultiValueGrouping = query.getContextBoolean(
this.allowMultiValueGrouping = query.context().getBoolean(
GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING,
true
);

View File

@ -28,7 +28,6 @@ import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.parsers.CloseableIterator;
import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.aggregation.AggregatorAdapters;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.query.filter.Filter;
@ -56,6 +55,7 @@ import org.joda.time.DateTime;
import org.joda.time.Interval;
import javax.annotation.Nullable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collections;
@ -150,7 +150,7 @@ public class VectorGroupByEngine
interval,
query.getVirtualColumns(),
false,
QueryContexts.getVectorSize(query),
query.context().getVectorSize(),
groupByQueryMetrics
);

View File

@ -37,6 +37,7 @@ import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.guava.TopNSequence;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.dimension.DimensionSpec;
@ -232,9 +233,11 @@ public class DefaultLimitSpec implements LimitSpec
}
if (!sortingNeeded) {
String timestampField = query.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD);
final QueryContext queryContext = query.context();
String timestampField = queryContext.getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD);
if (timestampField != null && !timestampField.isEmpty()) {
int timestampResultFieldIndex = query.getQueryContext().getAsInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX);
// Will NPE if the key is not set
int timestampResultFieldIndex = queryContext.getInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX);
sortingNeeded = query.getContextSortByDimsFirst()
? timestampResultFieldIndex != query.getDimensions().size() - 1
: timestampResultFieldIndex != 0;

View File

@ -91,7 +91,7 @@ public class GroupByStrategyV1 implements GroupByStrategy
@Override
public boolean doMergeResults(final GroupByQuery query)
{
return query.getContextBoolean(GroupByQueryQueryToolChest.GROUP_BY_MERGE_KEY, true);
return query.context().getBoolean(GroupByQueryQueryToolChest.GROUP_BY_MERGE_KEY, true);
}
@Override

View File

@ -44,6 +44,7 @@ import org.apache.druid.query.DataSource;
import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryPlus;
@ -132,8 +133,9 @@ public class GroupByStrategyV2 implements GroupByStrategy
return new GroupByQueryResource();
} else {
final List<ReferenceCountingResourceHolder<ByteBuffer>> mergeBufferHolders;
if (QueryContexts.hasTimeout(query)) {
mergeBufferHolders = mergeBufferPool.takeBatch(requiredMergeBufferNum, QueryContexts.getTimeout(query));
final QueryContext context = query.context();
if (context.hasTimeout()) {
mergeBufferHolders = mergeBufferPool.takeBatch(requiredMergeBufferNum, context.getTimeout());
} else {
mergeBufferHolders = mergeBufferPool.takeBatch(requiredMergeBufferNum);
}
@ -221,9 +223,10 @@ public class GroupByStrategyV2 implements GroupByStrategy
Granularity granularity = query.getGranularity();
List<DimensionSpec> dimensionSpecs = query.getDimensions();
// the CTX_TIMESTAMP_RESULT_FIELD is set in DruidQuery.java
final String timestampResultField = query.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD);
final QueryContext queryContext = query.context();
final String timestampResultField = queryContext.getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD);
final boolean hasTimestampResultField = (timestampResultField != null && !timestampResultField.isEmpty())
&& query.getContextBoolean(CTX_KEY_OUTERMOST, true)
&& queryContext.getBoolean(CTX_KEY_OUTERMOST, true)
&& !query.isApplyLimitPushDown();
int timestampResultFieldIndex = 0;
if (hasTimestampResultField) {
@ -249,7 +252,7 @@ public class GroupByStrategyV2 implements GroupByStrategy
// the granularity and dimensions are slightly different.
// now, part of the query plan logic is handled in GroupByStrategyV2, not only in DruidQuery.toGroupByQuery()
final Granularity timestampResultFieldGranularity
= query.getContextValue(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_GRANULARITY);
= queryContext.getGranularity(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_GRANULARITY);
dimensionSpecs =
query.getDimensions()
.stream()
@ -258,7 +261,7 @@ public class GroupByStrategyV2 implements GroupByStrategy
granularity = timestampResultFieldGranularity;
// when timestampResultField is the last dimension, should set sortByDimsFirst=true,
// otherwise the downstream is sorted by row's timestamp first which makes the final ordering not as expected
timestampResultFieldIndex = query.getQueryContext().getAsInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX);
timestampResultFieldIndex = queryContext.getInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX);
if (!query.getContextSortByDimsFirst() && timestampResultFieldIndex == query.getDimensions().size() - 1) {
context.put(GroupByQuery.CTX_KEY_SORT_BY_DIMS_FIRST, true);
}
@ -312,8 +315,8 @@ public class GroupByStrategyV2 implements GroupByStrategy
// Apply postaggregators if this is the outermost mergeResults (CTX_KEY_OUTERMOST) and we are not executing a
// pushed-down subquery (CTX_KEY_EXECUTING_NESTED_QUERY).
if (!query.getContextBoolean(CTX_KEY_OUTERMOST, true)
|| query.getContextBoolean(GroupByQueryConfig.CTX_KEY_EXECUTING_NESTED_QUERY, false)) {
if (!queryContext.getBoolean(CTX_KEY_OUTERMOST, true)
|| queryContext.getBoolean(GroupByQueryConfig.CTX_KEY_EXECUTING_NESTED_QUERY, false)) {
return mergedResults;
} else if (query.getPostAggregatorSpecs().isEmpty()) {
if (!hasTimestampResultField) {
@ -405,7 +408,7 @@ public class GroupByStrategyV2 implements GroupByStrategy
public Sequence<ResultRow> applyPostProcessing(Sequence<ResultRow> results, GroupByQuery query)
{
// Don't apply limit here for inner results, that will be pushed down to the BufferHashGrouper
if (query.getContextBoolean(CTX_KEY_OUTERMOST, true)) {
if (query.context().getBoolean(CTX_KEY_OUTERMOST, true)) {
return query.postProcess(results);
} else {
return results;

View File

@ -31,7 +31,7 @@ import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.AbstractPrioritizedQueryRunnerCallable;
import org.apache.druid.query.ConcatQueryRunner;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryProcessingPool;
@ -205,7 +205,7 @@ public class SegmentMetadataQueryRunnerFactory implements QueryRunnerFactory<Seg
)
{
final Query<SegmentAnalysis> query = queryPlus.getQuery();
final int priority = QueryContexts.getPriority(query);
final int priority = query.context().getPriority();
final QueryPlus<SegmentAnalysis> threadSafeQueryPlus = queryPlus.withoutThreadUnsafeState();
ListenableFuture<Sequence<SegmentAnalysis>> future = queryProcessingPool.submitRunnerTask(
new AbstractPrioritizedQueryRunnerCallable<Sequence<SegmentAnalysis>, SegmentAnalysis>(priority, input)
@ -219,8 +219,9 @@ public class SegmentMetadataQueryRunnerFactory implements QueryRunnerFactory<Seg
);
try {
queryWatcher.registerQueryFuture(query, future);
if (QueryContexts.hasTimeout(query)) {
return future.get(QueryContexts.getTimeout(query), TimeUnit.MILLISECONDS);
final QueryContext context = query.context();
if (context.hasTimeout()) {
return future.get(context.getTimeout(), TimeUnit.MILLISECONDS);
} else {
return future.get();
}

View File

@ -264,7 +264,7 @@ public class ScanQuery extends BaseQuery<ScanResultValue>
private Integer validateAndGetMaxRowsQueuedForOrdering()
{
final Integer maxRowsQueuedForOrdering =
getQueryContext().getAsInt(ScanQueryConfig.CTX_KEY_MAX_ROWS_QUEUED_FOR_ORDERING);
context().getInt(ScanQueryConfig.CTX_KEY_MAX_ROWS_QUEUED_FOR_ORDERING);
Preconditions.checkArgument(
maxRowsQueuedForOrdering == null || maxRowsQueuedForOrdering > 0,
"maxRowsQueuedForOrdering must be greater than 0"
@ -275,7 +275,7 @@ public class ScanQuery extends BaseQuery<ScanResultValue>
private Integer validateAndGetMaxSegmentPartitionsOrderedInMemory()
{
final Integer maxSegmentPartitionsOrderedInMemory =
getQueryContext().getAsInt(ScanQueryConfig.CTX_KEY_MAX_SEGMENT_PARTITIONS_FOR_ORDERING);
context().getInt(ScanQueryConfig.CTX_KEY_MAX_SEGMENT_PARTITIONS_FOR_ORDERING);
Preconditions.checkArgument(
maxSegmentPartitionsOrderedInMemory == null || maxSegmentPartitionsOrderedInMemory > 0,
"maxRowsQueuedForOrdering must be greater than 0"

View File

@ -31,7 +31,6 @@ import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.BaseSequence;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryMetrics;
import org.apache.druid.query.QueryTimeoutException;
import org.apache.druid.query.context.ResponseContext;
@ -78,7 +77,7 @@ public class ScanQueryEngine
if (numScannedRows != null && numScannedRows >= query.getScanRowsLimit() && query.getTimeOrder().equals(ScanQuery.Order.NONE)) {
return Sequences.empty();
}
final boolean hasTimeout = QueryContexts.hasTimeout(query);
final boolean hasTimeout = query.context().hasTimeout();
final Long timeoutAt = responseContext.getTimeoutTime();
final StorageAdapter adapter = segment.asStorageAdapter();

View File

@ -99,7 +99,7 @@ public class ScanQueryLimitRowIterator implements CloseableIterator<ScanResultVa
// We want to perform multi-event ScanResultValue limiting if we are not time-ordering or are at the
// inner-level if we are time-ordering
if (query.getTimeOrder() == ScanQuery.Order.NONE ||
!query.getContextBoolean(ScanQuery.CTX_KEY_OUTERMOST, true)) {
!query.context().getBoolean(ScanQuery.CTX_KEY_OUTERMOST, true)) {
ScanResultValue batch = yielder.get();
List events = (List) batch.getEvents();
if (events.size() <= limit - count) {

View File

@ -33,7 +33,6 @@ import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.guava.Yielder;
import org.apache.druid.java.util.common.guava.Yielders;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryProcessingPool;
import org.apache.druid.query.QueryRunner;
@ -94,7 +93,7 @@ public class ScanQueryRunnerFactory implements QueryRunnerFactory<ScanResultValu
// Note: this variable is effective only when queryContext has a timeout.
// See the comment of ResponseContext.Key.TIMEOUT_AT.
final long timeoutAt = System.currentTimeMillis() + QueryContexts.getTimeout(queryPlus.getQuery());
final long timeoutAt = System.currentTimeMillis() + queryPlus.getQuery().context().getTimeout();
responseContext.putTimeoutTime(timeoutAt);
if (query.getTimeOrder().equals(ScanQuery.Order.NONE)) {

View File

@ -55,7 +55,7 @@ public class SearchQueryConfig
{
final SearchQueryConfig newConfig = new SearchQueryConfig();
newConfig.maxSearchLimit = query.getLimit();
newConfig.searchStrategy = query.getQueryContext().getAsString(CTX_KEY_STRATEGY, searchStrategy);
newConfig.searchStrategy = query.context().getString(CTX_KEY_STRATEGY, searchStrategy);
return newConfig;
}
}

View File

@ -34,7 +34,6 @@ import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.query.CacheStrategy;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryToolChest;
@ -329,7 +328,7 @@ public class SearchQueryQueryToolChest extends QueryToolChest<Result<SearchResul
return runner.run(queryPlus, responseContext);
}
final boolean isBySegment = QueryContexts.isBySegment(query);
final boolean isBySegment = query.context().isBySegment();
return Sequences.map(
runner.run(queryPlus.withQuery(query.withLimit(config.getMaxSearchLimit())), responseContext),

View File

@ -24,7 +24,6 @@ import com.google.common.collect.Ordering;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QuerySegmentWalker;
import org.apache.druid.query.filter.DimFilter;
@ -34,6 +33,7 @@ import org.joda.time.Duration;
import org.joda.time.Interval;
import javax.annotation.Nullable;
import java.util.List;
import java.util.Map;
@ -110,12 +110,6 @@ public class SelectQuery implements Query<Object>
throw new RuntimeException(REMOVED_ERROR_MESSAGE);
}
@Override
public QueryContext getQueryContext()
{
throw new RuntimeException(REMOVED_ERROR_MESSAGE);
}
@Override
public boolean isDescending()
{

View File

@ -68,7 +68,7 @@ public class SpecificSegmentQueryRunner<T> implements QueryRunner<T>
)
);
final boolean setName = input.getQuery().getContextBoolean(CTX_SET_THREAD_NAME, true);
final boolean setName = input.getQuery().context().getBoolean(CTX_SET_THREAD_NAME, true);
final Query<T> query = queryPlus.getQuery();

View File

@ -35,6 +35,7 @@ import org.apache.druid.query.CacheStrategy;
import org.apache.druid.query.DefaultGenericQueryMetricsFactory;
import org.apache.druid.query.GenericQueryMetricsFactory;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryMetrics;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
@ -232,9 +233,10 @@ public class TimeBoundaryQueryQueryToolChest
{
if (query.isMinTime() || query.isMaxTime()) {
RowSignature.Builder builder = RowSignature.builder();
final QueryContext queryContext = query.context();
String outputName = query.isMinTime() ?
query.getQueryContext().getAsString(TimeBoundaryQuery.MIN_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MIN_TIME) :
query.getQueryContext().getAsString(TimeBoundaryQuery.MAX_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MAX_TIME);
queryContext.getString(TimeBoundaryQuery.MIN_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MIN_TIME) :
queryContext.getString(TimeBoundaryQuery.MAX_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MAX_TIME);
return builder.add(outputName, ColumnType.LONG).build();
}
return super.resultArraySignature(query);

View File

@ -154,17 +154,17 @@ public class TimeseriesQuery extends BaseQuery<Result<TimeseriesResultValue>>
public boolean isGrandTotal()
{
return getContextBoolean(CTX_GRAND_TOTAL, false);
return context().getBoolean(CTX_GRAND_TOTAL, false);
}
public String getTimestampResultField()
{
return getQueryContext().getAsString(CTX_TIMESTAMP_RESULT_FIELD);
return context().getString(CTX_TIMESTAMP_RESULT_FIELD);
}
public boolean isSkipEmptyBuckets()
{
return getContextBoolean(SKIP_EMPTY_BUCKETS, false);
return context().getBoolean(SKIP_EMPTY_BUCKETS, false);
}
@Nullable

View File

@ -31,7 +31,6 @@ import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryRunnerHelper;
import org.apache.druid.query.Result;
import org.apache.druid.query.aggregation.Aggregator;
@ -101,7 +100,7 @@ public class TimeseriesQueryEngine
final ColumnInspector inspector = query.getVirtualColumns().wrapInspector(adapter);
final boolean doVectorize = QueryContexts.getVectorize(query).shouldVectorize(
final boolean doVectorize = query.context().getVectorize().shouldVectorize(
adapter.canVectorize(filter, query.getVirtualColumns(), descending)
&& VirtualColumns.shouldVectorize(query, query.getVirtualColumns(), adapter)
&& query.getAggregatorSpecs().stream().allMatch(aggregatorFactory -> aggregatorFactory.canVectorize(inspector))
@ -141,7 +140,7 @@ public class TimeseriesQueryEngine
queryInterval,
query.getVirtualColumns(),
descending,
QueryContexts.getVectorSize(query),
query.context().getVectorSize(),
timeseriesQueryMetrics
);

View File

@ -37,7 +37,6 @@ import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.query.CacheStrategy;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryToolChest;
@ -147,7 +146,7 @@ public class TimeseriesQueryQueryToolChest extends QueryToolChest<Result<Timeser
!query.isSkipEmptyBuckets() &&
// Returns empty sequence if bySegment is set because bySegment results are mostly used for
// caching in historicals or debugging where the exact results are preferred.
!QueryContexts.isBySegment(query)) {
!query.context().isBySegment()) {
// Usally it is NOT Okay to materialize results via toList(), but Granularity is ALL thus
// we have only one record.
final List<Result<TimeseriesResultValue>> val = baseResults.toList();

View File

@ -138,7 +138,7 @@ public class TopNQueryEngine
// if sorted by dimension we should aggregate all metrics in a single pass, use the regular pooled algorithm for
// this
topNAlgorithm = new PooledTopNAlgorithm(adapter, query, bufferPool);
} else if (selector.isAggregateTopNMetricFirst() || query.getContextBoolean("doAggregateTopNMetricFirst", false)) {
} else if (selector.isAggregateTopNMetricFirst() || query.context().getBoolean("doAggregateTopNMetricFirst", false)) {
// for high cardinality dimensions with larger result sets we aggregate with only the ordering aggregation to
// compute the first 'n' values, and then for the rest of the metrics but for only the 'n' values
topNAlgorithm = new AggregateTopNMetricFirstAlgorithm(adapter, query, bufferPool);

View File

@ -574,12 +574,12 @@ public class TopNQueryQueryToolChest extends QueryToolChest<Result<TopNResultVal
}
final TopNQuery query = (TopNQuery) input;
final int minTopNThreshold = query.getQueryContext().getAsInt("minTopNThreshold", config.getMinTopNThreshold());
final int minTopNThreshold = query.context().getInt(QueryContexts.MIN_TOP_N_THRESHOLD, config.getMinTopNThreshold());
if (query.getThreshold() > minTopNThreshold) {
return runner.run(queryPlus, responseContext);
}
final boolean isBySegment = QueryContexts.isBySegment(query);
final boolean isBySegment = query.context().isBySegment();
return Sequences.map(
runner.run(queryPlus.withQuery(query.withThreshold(minTopNThreshold)), responseContext),

View File

@ -31,7 +31,6 @@ import org.apache.druid.java.util.common.Cacheable;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnCapabilities;
@ -48,6 +47,7 @@ import org.apache.druid.segment.virtual.VirtualizedColumnInspector;
import org.apache.druid.segment.virtual.VirtualizedColumnSelectorFactory;
import javax.annotation.Nullable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -120,7 +120,7 @@ public class VirtualColumns implements Cacheable
public static boolean shouldVectorize(Query<?> query, VirtualColumns virtualColumns, ColumnInspector inspector)
{
if (virtualColumns.getVirtualColumns().length > 0) {
return QueryContexts.getVectorizeVirtualColumns(query).shouldVectorize(virtualColumns.canVectorize(inspector));
return query.context().getVectorizeVirtualColumns().shouldVectorize(virtualColumns.canVectorize(inspector));
} else {
return true;
}

View File

@ -215,7 +215,7 @@ public class Filters
if (filter == null) {
return null;
}
boolean useCNF = query.getContextBoolean(QueryContexts.USE_FILTER_CNF_KEY, QueryContexts.DEFAULT_USE_FILTER_CNF);
boolean useCNF = query.context().getBoolean(QueryContexts.USE_FILTER_CNF_KEY, QueryContexts.DEFAULT_USE_FILTER_CNF);
try {
return useCNF ? Filters.toCnf(filter) : filter;
}

View File

@ -20,7 +20,7 @@
package org.apache.druid.segment.join.filter.rewrite;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryContext;
import java.util.Objects;
@ -76,12 +76,13 @@ public class JoinFilterRewriteConfig
public static JoinFilterRewriteConfig forQuery(final Query<?> query)
{
QueryContext context = query.context();
return new JoinFilterRewriteConfig(
QueryContexts.getEnableJoinFilterPushDown(query),
QueryContexts.getEnableJoinFilterRewrite(query),
QueryContexts.getEnableJoinFilterRewriteValueColumnFilters(query),
QueryContexts.getEnableRewriteJoinToFilter(query),
QueryContexts.getJoinFilterRewriteMaxSize(query)
context.getEnableJoinFilterPushDown(),
context.getEnableJoinFilterRewrite(),
context.getEnableJoinFilterRewriteValueColumnFilters(),
context.getEnableRewriteJoinToFilter(),
context.getJoinFilterRewriteMaxSize()
);
}

View File

@ -19,31 +19,45 @@
package org.apache.druid.query;
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.exc.MismatchedInputException;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Ordering;
import nl.jqno.equalsverifier.EqualsVerifier;
import nl.jqno.equalsverifier.Warning;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.spec.QuerySegmentSpec;
import org.apache.druid.segment.DimensionHandlerUtils;
import org.joda.time.DateTimeZone;
import org.joda.time.Duration;
import org.joda.time.Interval;
import org.junit.Assert;
import org.junit.Test;
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
public class QueryContextTest
{
private static final ObjectMapper JSON_MAPPER = new ObjectMapper();
@Test
public void testEquals()
{
@ -51,63 +65,83 @@ public class QueryContextTest
.suppress(Warning.NONFINAL_FIELDS, Warning.ALL_FIELDS_SHOULD_BE_USED)
.usingGetClass()
.forClass(QueryContext.class)
.withNonnullFields("defaultParams", "userParams", "systemParams")
.withNonnullFields("context")
.verify();
}
/**
* Verify that a context with an null map is the same as a context with
* an empty map.
*/
@Test
public void testEmptyParam()
public void testEmptyContext()
{
final QueryContext context = new QueryContext();
Assert.assertEquals(ImmutableMap.of(), context.getMergedParams());
{
final QueryContext context = new QueryContext(null);
assertEquals(ImmutableMap.of(), context.asMap());
}
{
final QueryContext context = new QueryContext(new HashMap<>());
assertEquals(ImmutableMap.of(), context.asMap());
}
{
final QueryContext context = QueryContext.of(null);
assertEquals(ImmutableMap.of(), context.asMap());
}
{
final QueryContext context = QueryContext.of(new HashMap<>());
assertEquals(ImmutableMap.of(), context.asMap());
}
{
final QueryContext context = QueryContext.empty();
assertEquals(ImmutableMap.of(), context.asMap());
}
}
@Test
public void testIsEmpty()
{
Assert.assertTrue(new QueryContext().isEmpty());
Assert.assertFalse(new QueryContext(ImmutableMap.of("k", "v")).isEmpty());
QueryContext context = new QueryContext();
context.addDefaultParam("k", "v");
Assert.assertFalse(context.isEmpty());
context = new QueryContext();
context.addSystemParam("k", "v");
Assert.assertFalse(context.isEmpty());
assertTrue(QueryContext.empty().isEmpty());
assertFalse(QueryContext.of(ImmutableMap.of("k", "v")).isEmpty());
}
@Test
public void testGetString()
{
final QueryContext context = new QueryContext(
final QueryContext context = QueryContext.of(
ImmutableMap.of("key", "val",
"key2", 2)
);
Assert.assertEquals("val", context.get("key"));
Assert.assertEquals("val", context.getAsString("key"));
Assert.assertEquals("2", context.getAsString("key2"));
Assert.assertNull(context.getAsString("non-exist"));
assertEquals("val", context.get("key"));
assertEquals("val", context.getString("key"));
assertNull(context.getString("non-exist"));
assertEquals("foo", context.getString("non-exist", "foo"));
assertThrows(BadQueryContextException.class, () -> context.getString("key2"));
}
@Test
public void testGetBoolean()
{
final QueryContext context = new QueryContext(
final QueryContext context = QueryContext.of(
ImmutableMap.of(
"key1", "true",
"key2", true
)
);
Assert.assertTrue(context.getAsBoolean("key1", false));
Assert.assertTrue(context.getAsBoolean("key2", false));
Assert.assertFalse(context.getAsBoolean("non-exist", false));
assertTrue(context.getBoolean("key1", false));
assertTrue(context.getBoolean("key2", false));
assertTrue(context.getBoolean("key1"));
assertFalse(context.getBoolean("non-exist", false));
assertNull(context.getBoolean("non-exist"));
}
@Test
public void testGetInt()
{
final QueryContext context = new QueryContext(
final QueryContext context = QueryContext.of(
ImmutableMap.of(
"key1", "100",
"key2", 100,
@ -115,17 +149,17 @@ public class QueryContextTest
)
);
Assert.assertEquals(100, context.getAsInt("key1", 0));
Assert.assertEquals(100, context.getAsInt("key2", 0));
Assert.assertEquals(0, context.getAsInt("non-exist", 0));
assertEquals(100, context.getInt("key1", 0));
assertEquals(100, context.getInt("key2", 0));
assertEquals(0, context.getInt("non-exist", 0));
Assert.assertThrows(IAE.class, () -> context.getAsInt("key3", 5));
assertThrows(BadQueryContextException.class, () -> context.getInt("key3", 5));
}
@Test
public void testGetLong()
{
final QueryContext context = new QueryContext(
final QueryContext context = QueryContext.of(
ImmutableMap.of(
"key1", "100",
"key2", 100,
@ -133,17 +167,127 @@ public class QueryContextTest
)
);
Assert.assertEquals(100L, context.getAsLong("key1", 0));
Assert.assertEquals(100L, context.getAsLong("key2", 0));
Assert.assertEquals(0L, context.getAsLong("non-exist", 0));
assertEquals(100L, context.getLong("key1", 0));
assertEquals(100L, context.getLong("key2", 0));
assertEquals(0L, context.getLong("non-exist", 0));
Assert.assertThrows(IAE.class, () -> context.getAsLong("key3", 5));
assertThrows(BadQueryContextException.class, () -> context.getLong("key3", 5));
}
/**
* Tests the several ways that Druid code parses context strings into Long
* values. The desired behavior is that "x" is parsed exactly the same as Jackson
* would parse x (where x is a valid number.) The context methods must emulate
* Jackson. The dimension utility method is included because some code used that
* for long parsing, and we must maintain backward compatibility.
* <p>
* The exceptions in the {@code assertThrows} are not critical: the key thing is
* that we're documenting what works and what doesn't. If an exception changes,
* just update the tests. If something no longer throws an exception, we'll want
* to verify that we support the new use case consistently in all three paths.
*/
@Test
public void testGetLongCompatibility() throws JsonProcessingException
{
{
String value = null;
// Only the context methods allow {"foo": null} to be parsed as a null Long.
assertNull(getContextLong(value));
// Nulls not legal on this path.
assertThrows(NullPointerException.class, () -> getDimensionLong(value));
// Nulls not legal on this path.
assertThrows(IllegalArgumentException.class, () -> getJsonLong(value));
}
{
String value = "";
// Blank string not legal on this path.
assertThrows(BadQueryContextException.class, () -> getContextLong(value));
assertNull(getDimensionLong(value));
// Blank string not allowed where a value is expected.
assertThrows(MismatchedInputException.class, () -> getJsonLong(value));
}
{
String value = "0";
assertEquals(0L, (long) getContextLong(value));
assertEquals(0L, (long) getDimensionLong(value));
assertEquals(0L, (long) getJsonLong(value));
}
{
String value = "+1";
assertEquals(1L, (long) getContextLong(value));
assertEquals(1L, (long) getDimensionLong(value));
assertThrows(JsonParseException.class, () -> getJsonLong(value));
}
{
String value = "-1";
assertEquals(-1L, (long) getContextLong(value));
assertEquals(-1L, (long) getDimensionLong(value));
assertEquals(-1L, (long) getJsonLong(value));
}
{
// Hexadecimal numbers are not supported in JSON. Druid also does not support
// them in strings.
String value = "0xabcd";
assertThrows(BadQueryContextException.class, () -> getContextLong(value));
// The dimension utils have a funny way of handling hex: they return null
assertNull(getDimensionLong(value));
assertThrows(JsonParseException.class, () -> getJsonLong(value));
}
{
// Leading zeros supported by Druid parsing, but not by JSON.
String value = "05";
assertEquals(5L, (long) getContextLong(value));
assertEquals(5L, (long) getDimensionLong(value));
assertThrows(JsonParseException.class, () -> getJsonLong(value));
}
{
// The dimension utils allow a float where a long is expected.
// Jackson can do this conversion. This test verifies that the context
// functions can handle the same conversion.
String value = "10.00";
assertEquals(10L, (long) getContextLong(value));
assertEquals(10L, (long) getDimensionLong(value));
assertEquals(10L, (long) getJsonLong(value));
}
{
// None of the conversion methods allow a (thousands) separator. The comma
// would be ambiguous in JSON. Java allows the underscore, but JSON does
// not support this syntax, and neither does Druid's string-to-long conversion.
String value = "1_234";
assertThrows(BadQueryContextException.class, () -> getContextLong(value));
assertNull(getDimensionLong(value));
assertThrows(JsonParseException.class, () -> getJsonLong(value));
}
}
private static Long getContextLong(String value)
{
return QueryContexts.getAsLong("dummy", value);
}
private static Long getJsonLong(String value) throws JsonProcessingException
{
return JSON_MAPPER.readValue(value, Long.class);
}
private static Long getDimensionLong(String value)
{
return DimensionHandlerUtils.getExactLongFromDecimalString(value);
}
@Test
public void testGetFloat()
{
final QueryContext context = new QueryContext(
final QueryContext context = QueryContext.of(
ImmutableMap.of(
"f1", "500",
"f2", 500,
@ -152,11 +296,11 @@ public class QueryContextTest
)
);
Assert.assertEquals(0, Float.compare(500, context.getAsFloat("f1", 100)));
Assert.assertEquals(0, Float.compare(500, context.getAsFloat("f2", 100)));
Assert.assertEquals(0, Float.compare(500.1f, context.getAsFloat("f3", 100)));
assertEquals(0, Float.compare(500, context.getFloat("f1", 100)));
assertEquals(0, Float.compare(500, context.getFloat("f2", 100)));
assertEquals(0, Float.compare(500.1f, context.getFloat("f3", 100)));
Assert.assertThrows(IAE.class, () -> context.getAsLong("f4", 5));
assertThrows(BadQueryContextException.class, () -> context.getFloat("f4", 5));
}
@Test
@ -172,167 +316,30 @@ public class QueryContextTest
.put("m6", "abc")
.build()
);
Assert.assertEquals(500_000_000, context.getAsHumanReadableBytes("m1", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(500_000_000, context.getAsHumanReadableBytes("m2", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(500 * 1024 * 1024L, context.getAsHumanReadableBytes("m3", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(500 * 1024 * 1024L, context.getAsHumanReadableBytes("m4", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(500_000_000, context.getAsHumanReadableBytes("m5", HumanReadableBytes.ZERO).getBytes());
assertEquals(500_000_000, context.getHumanReadableBytes("m1", HumanReadableBytes.ZERO).getBytes());
assertEquals(500_000_000, context.getHumanReadableBytes("m2", HumanReadableBytes.ZERO).getBytes());
assertEquals(500 * 1024 * 1024L, context.getHumanReadableBytes("m3", HumanReadableBytes.ZERO).getBytes());
assertEquals(500 * 1024 * 1024L, context.getHumanReadableBytes("m4", HumanReadableBytes.ZERO).getBytes());
assertEquals(500_000_000, context.getHumanReadableBytes("m5", HumanReadableBytes.ZERO).getBytes());
Assert.assertThrows(IAE.class, () -> context.getAsHumanReadableBytes("m6", HumanReadableBytes.ZERO));
assertThrows(BadQueryContextException.class, () -> context.getHumanReadableBytes("m6", HumanReadableBytes.ZERO));
}
@Test
public void testAddSystemParamOverrideUserParam()
public void testDefaultEnableQueryDebugging()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
)
);
context.addSystemParam("sys1", "sysVal1");
context.addSystemParam("conflict", "sysVal2");
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
),
context.getUserParams()
);
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"sys1", "sysVal1",
"conflict", "sysVal2"
),
context.getMergedParams()
);
}
@Test
public void testUserParamOverrideDefaultParam()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
)
);
context.addDefaultParams(
ImmutableMap.of(
"default1", "defaultVal1"
)
);
context.addDefaultParam("conflict", "defaultVal2");
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
),
context.getUserParams()
);
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"default1", "defaultVal1",
"conflict", "userVal2"
),
context.getMergedParams()
);
}
@Test
public void testRemoveUserParam()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
)
);
context.addDefaultParams(
ImmutableMap.of(
"default1", "defaultVal1",
"conflict", "defaultVal2"
)
);
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"default1", "defaultVal1",
"conflict", "userVal2"
),
context.getMergedParams()
);
Assert.assertEquals("userVal2", context.removeUserParam("conflict"));
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"default1", "defaultVal1",
"conflict", "defaultVal2"
),
context.getMergedParams()
);
}
@Test
public void testGetMergedParams()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
)
);
context.addDefaultParams(
ImmutableMap.of(
"default1", "defaultVal1",
"conflict", "defaultVal2"
)
);
Assert.assertSame(context.getMergedParams(), context.getMergedParams());
}
@Test
public void testCopy()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
)
);
context.addDefaultParams(
ImmutableMap.of(
"default1", "defaultVal1",
"conflict", "defaultVal2"
)
);
context.addSystemParam("sys1", "val1");
final Map<String, Object> merged = ImmutableMap.copyOf(context.getMergedParams());
final QueryContext context2 = context.copy();
context2.removeUserParam("conflict");
context2.addSystemParam("sys2", "val2");
context2.addDefaultParam("default3", "defaultVal3");
Assert.assertEquals(merged, context.getMergedParams());
assertFalse(QueryContext.empty().isDebug());
assertTrue(QueryContext.of(ImmutableMap.of(QueryContexts.ENABLE_DEBUG, true)).isDebug());
}
// This test is a bit silly. It is retained because another test uses the
// LegacyContextQuery test.
@Test
public void testLegacyReturnsLegacy()
{
Query<?> legacy = new LegacyContextQuery(ImmutableMap.of("foo", "bar"));
Assert.assertNull(legacy.getQueryContext());
Map<String, Object> context = ImmutableMap.of("foo", "bar");
Query<?> legacy = new LegacyContextQuery(context);
assertEquals(context, legacy.getContext());
}
@Test
@ -345,10 +352,10 @@ public class QueryContextTest
.aggregators(Collections.singletonList(new CountAggregatorFactory("theCount")))
.context(ImmutableMap.of("foo", "bar"))
.build();
Assert.assertNotNull(timeseries.getQueryContext());
assertNotNull(timeseries.getContext());
}
public static class LegacyContextQuery implements Query
public static class LegacyContextQuery implements Query<Integer>
{
private final Map<String, Object> context;
@ -382,9 +389,9 @@ public class QueryContextTest
}
@Override
public QueryRunner getRunner(QuerySegmentWalker walker)
public QueryRunner<Integer> getRunner(QuerySegmentWalker walker)
{
return new NoopQueryRunner();
return new NoopQueryRunner<>();
}
@Override
@ -417,31 +424,6 @@ public class QueryContextTest
return context;
}
@Override
public boolean getContextBoolean(String key, boolean defaultValue)
{
if (context == null || !context.containsKey(key)) {
return defaultValue;
}
return (boolean) context.get(key);
}
@Override
public HumanReadableBytes getContextAsHumanReadableBytes(String key, HumanReadableBytes defaultValue)
{
if (null == context || !context.containsKey(key)) {
return defaultValue;
}
Object value = context.get(key);
if (value instanceof Number) {
return HumanReadableBytes.valueOf(((Number) value).longValue());
} else if (value instanceof String) {
return new HumanReadableBytes((String) value);
} else {
throw new IAE("Expected parameter [%s] to be in human readable format", key);
}
}
@Override
public boolean isDescending()
{
@ -449,19 +431,19 @@ public class QueryContextTest
}
@Override
public Ordering getResultOrdering()
public Ordering<Integer> getResultOrdering()
{
return Ordering.natural();
}
@Override
public Query withQuerySegmentSpec(QuerySegmentSpec spec)
public Query<Integer> withQuerySegmentSpec(QuerySegmentSpec spec)
{
return new LegacyContextQuery(context);
}
@Override
public Query withId(String id)
public Query<Integer> withId(String id)
{
context.put(BaseQuery.QUERY_ID, id);
return this;
@ -475,7 +457,7 @@ public class QueryContextTest
}
@Override
public Query withSubQueryId(String subQueryId)
public Query<Integer> withSubQueryId(String subQueryId)
{
context.put(BaseQuery.SUB_QUERY_ID, subQueryId);
return this;
@ -489,21 +471,15 @@ public class QueryContextTest
}
@Override
public Query withDataSource(DataSource dataSource)
public Query<Integer> withDataSource(DataSource dataSource)
{
return this;
}
@Override
public Query withOverriddenContext(Map contextOverride)
public Query<Integer> withOverriddenContext(Map<String, Object> contextOverride)
{
return new LegacyContextQuery(contextOverride);
}
@Override
public Object getContextValue(String key)
{
return context.get(key);
}
}
}

View File

@ -22,7 +22,6 @@ package org.apache.druid.query;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.junit.Assert;
@ -47,7 +46,7 @@ public class QueryContextsTest
false,
new HashMap<>()
);
Assert.assertEquals(300_000, QueryContexts.getDefaultTimeout(query));
Assert.assertEquals(300_000, query.context().getDefaultTimeout());
}
@Test
@ -59,10 +58,10 @@ public class QueryContextsTest
false,
new HashMap<>()
);
Assert.assertEquals(300_000, QueryContexts.getTimeout(query));
Assert.assertEquals(300_000, query.context().getTimeout());
query = QueryContexts.withDefaultTimeout(query, 60_000);
Assert.assertEquals(60_000, QueryContexts.getTimeout(query));
query = Queries.withDefaultTimeout(query, 60_000);
Assert.assertEquals(60_000, query.context().getTimeout());
}
@Test
@ -74,17 +73,17 @@ public class QueryContextsTest
false,
ImmutableMap.of(QueryContexts.TIMEOUT_KEY, 1000)
);
Assert.assertEquals(1000, QueryContexts.getTimeout(query));
Assert.assertEquals(1000, query.context().getTimeout());
query = QueryContexts.withDefaultTimeout(query, 1_000_000);
Assert.assertEquals(1000, QueryContexts.getTimeout(query));
query = Queries.withDefaultTimeout(query, 1_000_000);
Assert.assertEquals(1000, query.context().getTimeout());
}
@Test
public void testQueryMaxTimeout()
{
exception.expect(IAE.class);
exception.expectMessage("configured [timeout = 1000] is more than enforced limit of maxQueryTimeout [100].");
exception.expect(BadQueryContextException.class);
exception.expectMessage("Configured timeout = 1000 is more than enforced limit of 100.");
Query<?> query = new TestQuery(
new TableDataSource("test"),
new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))),
@ -92,14 +91,14 @@ public class QueryContextsTest
ImmutableMap.of(QueryContexts.TIMEOUT_KEY, 1000)
);
QueryContexts.verifyMaxQueryTimeout(query, 100);
query.context().verifyMaxQueryTimeout(100);
}
@Test
public void testMaxScatterGatherBytes()
{
exception.expect(IAE.class);
exception.expectMessage("configured [maxScatterGatherBytes = 1000] is more than enforced limit of [100].");
exception.expect(BadQueryContextException.class);
exception.expectMessage("Configured maxScatterGatherBytes = 1000 is more than enforced limit of 100.");
Query<?> query = new TestQuery(
new TableDataSource("test"),
new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))),
@ -107,7 +106,7 @@ public class QueryContextsTest
ImmutableMap.of(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY, 1000)
);
QueryContexts.withMaxScatterGatherBytes(query, 100);
Queries.withMaxScatterGatherBytes(query, 100);
}
@Test
@ -119,7 +118,7 @@ public class QueryContextsTest
false,
ImmutableMap.of(QueryContexts.SECONDARY_PARTITION_PRUNING_KEY, false)
);
Assert.assertFalse(QueryContexts.isSecondaryPartitionPruningEnabled(query));
Assert.assertFalse(query.context().isSecondaryPartitionPruningEnabled());
}
@Test
@ -131,7 +130,7 @@ public class QueryContextsTest
false,
ImmutableMap.of()
);
Assert.assertTrue(QueryContexts.isSecondaryPartitionPruningEnabled(query));
Assert.assertTrue(query.context().isSecondaryPartitionPruningEnabled());
}
@Test
@ -139,7 +138,7 @@ public class QueryContextsTest
{
Assert.assertEquals(
QueryContexts.DEFAULT_IN_SUB_QUERY_THRESHOLD,
QueryContexts.getInSubQueryThreshold(ImmutableMap.of())
QueryContext.empty().getInSubQueryThreshold()
);
}
@ -148,32 +147,32 @@ public class QueryContextsTest
{
Assert.assertEquals(
QueryContexts.DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING,
QueryContexts.isTimeBoundaryPlanningEnabled(ImmutableMap.of())
QueryContext.empty().isTimeBoundaryPlanningEnabled()
);
}
@Test
public void testGetEnableJoinLeftScanDirect()
{
Assert.assertFalse(QueryContexts.getEnableJoinLeftScanDirect(ImmutableMap.of()));
Assert.assertTrue(QueryContexts.getEnableJoinLeftScanDirect(ImmutableMap.of(
Assert.assertFalse(QueryContext.empty().getEnableJoinLeftScanDirect());
Assert.assertTrue(QueryContext.of(ImmutableMap.of(
QueryContexts.SQL_JOIN_LEFT_SCAN_DIRECT,
true
)));
Assert.assertFalse(QueryContexts.getEnableJoinLeftScanDirect(ImmutableMap.of(
)).getEnableJoinLeftScanDirect());
Assert.assertFalse(QueryContext.of(ImmutableMap.of(
QueryContexts.SQL_JOIN_LEFT_SCAN_DIRECT,
false
)));
)).getEnableJoinLeftScanDirect());
}
@Test
public void testGetBrokerServiceName()
{
Map<String, Object> queryContext = new HashMap<>();
Assert.assertNull(QueryContexts.getBrokerServiceName(queryContext));
Assert.assertNull(QueryContext.of(queryContext).getBrokerServiceName());
queryContext.put(QueryContexts.BROKER_SERVICE_NAME, "hotBroker");
Assert.assertEquals("hotBroker", QueryContexts.getBrokerServiceName(queryContext));
Assert.assertEquals("hotBroker", QueryContext.of(queryContext).getBrokerServiceName());
}
@Test
@ -182,8 +181,8 @@ public class QueryContextsTest
Map<String, Object> queryContext = new HashMap<>();
queryContext.put(QueryContexts.BROKER_SERVICE_NAME, 100);
exception.expect(ClassCastException.class);
QueryContexts.getBrokerServiceName(queryContext);
exception.expect(BadQueryContextException.class);
QueryContext.of(queryContext).getBrokerServiceName();
}
@Test
@ -193,38 +192,12 @@ public class QueryContextsTest
queryContext.put(QueryContexts.TIMEOUT_KEY, "2000'");
exception.expect(BadQueryContextException.class);
QueryContexts.getTimeout(new TestQuery(
new TestQuery(
new TableDataSource("test"),
new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))),
false,
queryContext
));
}
@Test
public void testDefaultEnableQueryDebugging()
{
Query<?> query = new TestQuery(
new TableDataSource("test"),
new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))),
false,
ImmutableMap.of()
);
Assert.assertFalse(QueryContexts.isDebug(query));
Assert.assertFalse(QueryContexts.isDebug(query.getContext()));
}
@Test
public void testEnableQueryDebuggingSetToTrue()
{
Query<?> query = new TestQuery(
new TableDataSource("test"),
new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))),
false,
ImmutableMap.of(QueryContexts.ENABLE_DEBUG, true)
);
Assert.assertTrue(QueryContexts.isDebug(query));
Assert.assertTrue(QueryContexts.isDebug(query.getContext()));
).context().getTimeout();
}
@Test
@ -237,7 +210,7 @@ public class QueryContextsTest
QueryContexts.getAsString("foo", 10, null);
Assert.fail();
}
catch (IAE e) {
catch (BadQueryContextException e) {
// Expected
}
@ -249,7 +222,7 @@ public class QueryContextsTest
QueryContexts.getAsBoolean("foo", 10, false);
Assert.fail();
}
catch (IAE e) {
catch (BadQueryContextException e) {
// Expected
}
@ -262,7 +235,7 @@ public class QueryContextsTest
QueryContexts.getAsInt("foo", true, 20);
Assert.fail();
}
catch (IAE e) {
catch (BadQueryContextException e) {
// Expected
}
@ -275,7 +248,7 @@ public class QueryContextsTest
QueryContexts.getAsLong("foo", true, 20);
Assert.fail();
}
catch (IAE e) {
catch (BadQueryContextException e) {
// Expected
}
}
@ -314,12 +287,12 @@ public class QueryContextsTest
Assert.assertEquals(
QueryContexts.Vectorize.FORCE,
query.getQueryContext().getAsEnum("e1", QueryContexts.Vectorize.class, QueryContexts.Vectorize.FALSE)
query.context().getEnum("e1", QueryContexts.Vectorize.class, QueryContexts.Vectorize.FALSE)
);
Assert.assertThrows(
IAE.class,
() -> query.getQueryContext().getAsEnum("e2", QueryContexts.Vectorize.class, QueryContexts.Vectorize.FALSE)
BadQueryContextException.class,
() -> query.context().getEnum("e2", QueryContexts.Vectorize.class, QueryContexts.Vectorize.FALSE)
);
}
}

View File

@ -31,6 +31,7 @@ import org.apache.druid.query.DefaultGenericQueryMetricsFactory;
import org.apache.druid.query.Druids;
import org.apache.druid.query.GenericQueryMetricsFactory;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
@ -102,13 +103,14 @@ public class DataSourceMetadataQueryTest
), Query.class
);
Assert.assertEquals((Integer) 1, serdeQuery.getQueryContext().getAsInt(QueryContexts.PRIORITY_KEY));
Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.USE_CACHE_KEY));
Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.POPULATE_CACHE_KEY));
Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.FINALIZE_KEY));
Assert.assertEquals(true, serdeQuery.getContextBoolean(QueryContexts.USE_CACHE_KEY, false));
Assert.assertEquals(true, serdeQuery.getContextBoolean(QueryContexts.POPULATE_CACHE_KEY, false));
Assert.assertEquals(true, serdeQuery.getContextBoolean(QueryContexts.FINALIZE_KEY, false));
final QueryContext queryContext = serdeQuery.context();
Assert.assertEquals(1, (int) queryContext.getInt(QueryContexts.PRIORITY_KEY));
Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.USE_CACHE_KEY));
Assert.assertEquals("true", queryContext.getString(QueryContexts.POPULATE_CACHE_KEY));
Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.FINALIZE_KEY));
Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.USE_CACHE_KEY, false));
Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.POPULATE_CACHE_KEY, false));
Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.FINALIZE_KEY, false));
}
@Test

View File

@ -20,7 +20,6 @@
package org.apache.druid.query.groupby.epinephelinae.vector;
import org.apache.commons.lang3.mutable.MutableObject;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryRunnerTestHelper;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
@ -68,7 +67,7 @@ public class VectorGroupByEngineIteratorTest extends InitializedNullHandlingTest
interval,
query.getVirtualColumns(),
false,
QueryContexts.getVectorSize(query),
query.context().getVectorSize(),
null
);
final List<GroupByVectorColumnSelector> dimensions = query.getDimensions().stream().map(

View File

@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableMap;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.query.Druids;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.junit.Assert;
import org.junit.Test;
@ -78,10 +79,11 @@ public class TimeBoundaryQueryTest
), TimeBoundaryQuery.class
);
Assert.assertEquals(new Integer(1), serdeQuery.getQueryContext().getAsInt(QueryContexts.PRIORITY_KEY));
Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.USE_CACHE_KEY));
Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.POPULATE_CACHE_KEY));
Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.FINALIZE_KEY));
final QueryContext queryContext = query.context();
Assert.assertEquals(1, (int) queryContext.getInt(QueryContexts.PRIORITY_KEY));
Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.USE_CACHE_KEY));
Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.POPULATE_CACHE_KEY));
Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.FINALIZE_KEY));
}
@Test
@ -116,9 +118,10 @@ public class TimeBoundaryQueryTest
);
Assert.assertEquals("1", serdeQuery.getQueryContext().getAsString(QueryContexts.PRIORITY_KEY));
Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.USE_CACHE_KEY));
Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.POPULATE_CACHE_KEY));
Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.FINALIZE_KEY));
final QueryContext queryContext = query.context();
Assert.assertEquals("1", queryContext.get(QueryContexts.PRIORITY_KEY));
Assert.assertEquals("true", queryContext.get(QueryContexts.USE_CACHE_KEY));
Assert.assertEquals("true", queryContext.get(QueryContexts.POPULATE_CACHE_KEY));
Assert.assertEquals("true", queryContext.get(QueryContexts.FINALIZE_KEY));
}
}

View File

@ -24,12 +24,12 @@ import org.apache.druid.client.cache.CacheConfig;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.CacheStrategy;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryToolChest;
import org.apache.druid.query.SegmentDescriptor;
import org.joda.time.Interval;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
public class CacheUtil
@ -109,7 +109,7 @@ public class CacheUtil
)
{
return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType)
&& QueryContexts.isUseCache(query)
&& query.context().isUseCache()
&& cacheConfig.isUseCache();
}
@ -129,7 +129,7 @@ public class CacheUtil
)
{
return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType)
&& QueryContexts.isPopulateCache(query)
&& query.context().isPopulateCache()
&& cacheConfig.isPopulateCache();
}
@ -149,7 +149,7 @@ public class CacheUtil
)
{
return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType)
&& QueryContexts.isUseResultLevelCache(query)
&& query.context().isUseResultLevelCache()
&& cacheConfig.isUseResultLevelCache();
}
@ -169,7 +169,7 @@ public class CacheUtil
)
{
return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType)
&& QueryContexts.isPopulateResultLevelCache(query)
&& query.context().isPopulateResultLevelCache()
&& cacheConfig.isPopulateResultLevelCache();
}

View File

@ -60,6 +60,7 @@ import org.apache.druid.query.CacheStrategy;
import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.Queries;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryMetrics;
import org.apache.druid.query.QueryPlus;
@ -282,10 +283,11 @@ public class CachingClusteredClient implements QuerySegmentWalker
this.useCache = CacheUtil.isUseSegmentCache(query, strategy, cacheConfig, CacheUtil.ServerType.BROKER);
this.populateCache = CacheUtil.isPopulateSegmentCache(query, strategy, cacheConfig, CacheUtil.ServerType.BROKER);
this.isBySegment = QueryContexts.isBySegment(query);
final QueryContext queryContext = query.context();
this.isBySegment = queryContext.isBySegment();
// Note that enabling this leads to putting uncovered intervals information in the response headers
// and might blow up in some cases https://github.com/apache/druid/issues/2108
this.uncoveredIntervalsLimit = QueryContexts.getUncoveredIntervalsLimit(query);
this.uncoveredIntervalsLimit = queryContext.getUncoveredIntervalsLimit();
// For nested queries, we need to look at the intervals of the inner most query.
this.intervals = dataSourceAnalysis.getBaseQuerySegmentSpec()
.map(QuerySegmentSpec::getIntervals)
@ -304,9 +306,10 @@ public class CachingClusteredClient implements QuerySegmentWalker
{
final ImmutableMap.Builder<String, Object> contextBuilder = new ImmutableMap.Builder<>();
final int priority = QueryContexts.getPriority(query);
final QueryContext queryContext = query.context();
final int priority = queryContext.getPriority();
contextBuilder.put(QueryContexts.PRIORITY_KEY, priority);
final String lane = QueryContexts.getLane(query);
final String lane = queryContext.getLane();
if (lane != null) {
contextBuilder.put(QueryContexts.LANE_KEY, lane);
}
@ -384,18 +387,19 @@ public class CachingClusteredClient implements QuerySegmentWalker
private Sequence<T> merge(List<Sequence<T>> sequencesByInterval)
{
BinaryOperator<T> mergeFn = toolChest.createMergeFn(query);
if (processingConfig.useParallelMergePool() && QueryContexts.getEnableParallelMerges(query) && mergeFn != null) {
final QueryContext queryContext = query.context();
if (processingConfig.useParallelMergePool() && queryContext.getEnableParallelMerges() && mergeFn != null) {
return new ParallelMergeCombiningSequence<>(
pool,
sequencesByInterval,
query.getResultOrdering(),
mergeFn,
QueryContexts.hasTimeout(query),
QueryContexts.getTimeout(query),
QueryContexts.getPriority(query),
QueryContexts.getParallelMergeParallelism(query, processingConfig.getMergePoolDefaultMaxQueryParallelism()),
QueryContexts.getParallelMergeInitialYieldRows(query, processingConfig.getMergePoolTaskInitialYieldRows()),
QueryContexts.getParallelMergeSmallBatchRows(query, processingConfig.getMergePoolSmallBatchRows()),
queryContext.hasTimeout(),
queryContext.getTimeout(),
queryContext.getPriority(),
queryContext.getParallelMergeParallelism(processingConfig.getMergePoolDefaultMaxQueryParallelism()),
queryContext.getParallelMergeInitialYieldRows(processingConfig.getMergePoolTaskInitialYieldRows()),
queryContext.getParallelMergeSmallBatchRows(processingConfig.getMergePoolSmallBatchRows()),
processingConfig.getMergePoolTargetTaskRunTimeMillis(),
reportMetrics -> {
QueryMetrics<?> queryMetrics = queryPlus.getQueryMetrics();
@ -437,7 +441,7 @@ public class CachingClusteredClient implements QuerySegmentWalker
// Filter unneeded chunks based on partition dimension
for (TimelineObjectHolder<String, ServerSelector> holder : serversLookup) {
final Set<PartitionChunk<ServerSelector>> filteredChunks;
if (QueryContexts.isSecondaryPartitionPruningEnabled(query)) {
if (query.context().isSecondaryPartitionPruningEnabled()) {
filteredChunks = DimFilterUtils.filterShards(
query.getFilter(),
holder.getObject(),
@ -652,12 +656,12 @@ public class CachingClusteredClient implements QuerySegmentWalker
final QueryRunner serverRunner = serverView.getQueryRunner(server);
if (serverRunner == null) {
log.error("Server[%s] doesn't have a query runner", server.getName());
log.error("Server [%s] doesn't have a query runner", server.getName());
return;
}
// Divide user-provided maxQueuedBytes by the number of servers, and limit each server to that much.
final long maxQueuedBytes = QueryContexts.getMaxQueuedBytes(query, httpClientConfig.getMaxQueuedBytes());
final long maxQueuedBytes = query.context().getMaxQueuedBytes(httpClientConfig.getMaxQueuedBytes());
final long maxQueuedBytesPerServer = maxQueuedBytes / segmentsByServer.size();
final Sequence<T> serverResults;
@ -776,7 +780,7 @@ public class CachingClusteredClient implements QuerySegmentWalker
this.dataSourceAnalysis = dataSourceAnalysis;
this.joinableFactoryWrapper = joinableFactoryWrapper;
this.isSegmentLevelCachingEnable = ((populateCache || useCache)
&& !QueryContexts.isBySegment(query)); // explicit bySegment queries are never cached
&& !query.context().isBySegment()); // explicit bySegment queries are never cached
}

View File

@ -41,8 +41,9 @@ import org.apache.druid.java.util.http.client.response.ClientResponse;
import org.apache.druid.java.util.http.client.response.HttpResponseHandler;
import org.apache.druid.java.util.http.client.response.StatusResponseHandler;
import org.apache.druid.java.util.http.client.response.StatusResponseHolder;
import org.apache.druid.query.Queries;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryMetrics;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
@ -152,7 +153,7 @@ public class DirectDruidClient<T> implements QueryRunner<T>
{
final Query<T> query = queryPlus.getQuery();
QueryToolChest<T, Query<T>> toolChest = warehouse.getToolChest(query);
boolean isBySegment = QueryContexts.isBySegment(query);
boolean isBySegment = query.context().isBySegment();
final JavaType queryResultType = isBySegment ? toolChest.getBySegmentResultType() : toolChest.getBaseResultType();
final ListenableFuture<InputStream> future;
@ -160,13 +161,15 @@ public class DirectDruidClient<T> implements QueryRunner<T>
final String cancelUrl = url + query.getId();
try {
log.debug("Querying queryId[%s] url[%s]", query.getId(), url);
log.debug("Querying queryId [%s] url [%s]", query.getId(), url);
final long requestStartTimeNs = System.nanoTime();
final long timeoutAt = query.getQueryContext().getAsLong(QUERY_FAIL_TIME);
final long maxScatterGatherBytes = QueryContexts.getMaxScatterGatherBytes(query);
final QueryContext queryContext = query.context();
// Will NPE if the value is not set.
final long timeoutAt = queryContext.getLong(QUERY_FAIL_TIME);
final long maxScatterGatherBytes = queryContext.getMaxScatterGatherBytes();
final AtomicLong totalBytesGathered = context.getTotalBytes();
final long maxQueuedBytes = QueryContexts.getMaxQueuedBytes(query, 0);
final long maxQueuedBytes = queryContext.getMaxQueuedBytes(0);
final boolean usingBackpressure = maxQueuedBytes > 0;
final HttpResponseHandler<InputStream, InputStream> responseHandler = new HttpResponseHandler<InputStream, InputStream>()
@ -454,7 +457,7 @@ public class DirectDruidClient<T> implements QueryRunner<T>
new Request(
HttpMethod.POST,
new URL(url)
).setContent(objectMapper.writeValueAsBytes(QueryContexts.withTimeout(query, timeLeft)))
).setContent(objectMapper.writeValueAsBytes(Queries.withTimeout(query, timeLeft)))
.setHeader(
HttpHeaders.Names.CONTENT_TYPE,
isSmile ? SmileMediaTypes.APPLICATION_JACKSON_SMILE : MediaType.APPLICATION_JSON

View File

@ -75,7 +75,7 @@ public class JsonParserIterator<T> implements Iterator<T>, Closeable
this.future = future;
this.url = url;
if (query != null) {
this.timeoutAt = query.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME, -1L);
this.timeoutAt = query.context().getLong(DirectDruidClient.QUERY_FAIL_TIME, -1L);
this.queryId = query.getId();
} else {
this.timeoutAt = -1;

View File

@ -215,15 +215,15 @@ public class RetryQueryRunner<T> implements QueryRunner<T>
if (sequence != null) {
return true;
} else {
final QueryContext queryContext = queryPlus.getQuery().context();
final List<SegmentDescriptor> missingSegments = getMissingSegments(queryPlus, context);
final int maxNumRetries = QueryContexts.getNumRetriesOnMissingSegments(
queryPlus.getQuery(),
final int maxNumRetries = queryContext.getNumRetriesOnMissingSegments(
config.getNumTries()
);
if (missingSegments.isEmpty()) {
return false;
} else if (retryCount >= maxNumRetries) {
if (!QueryContexts.allowReturnPartialResults(queryPlus.getQuery(), config.isReturnPartialResults())) {
if (!queryContext.allowReturnPartialResults(config.isReturnPartialResults())) {
throw new SegmentMissingException("No results found for segments[%s]", missingSegments);
} else {
return false;

View File

@ -161,7 +161,7 @@ public class SinkQuerySegmentWalker implements QuerySegmentWalker
}
final QueryToolChest<T, Query<T>> toolChest = factory.getToolchest();
final boolean skipIncrementalSegment = query.getContextBoolean(CONTEXT_SKIP_INCREMENTAL_SEGMENT, false);
final boolean skipIncrementalSegment = query.context().getBoolean(CONTEXT_SKIP_INCREMENTAL_SEGMENT, false);
final AtomicLong cpuTimeAccumulator = new AtomicLong(0L);
// Make sure this query type can handle the subquery, if present.

View File

@ -39,7 +39,6 @@ import org.apache.druid.query.GlobalTableDataSource;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.PostProcessingOperator;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
@ -163,7 +162,7 @@ public class ClientQuerySegmentWalker implements QuerySegmentWalker
final DataSource freeTradeDataSource = globalizeIfPossible(newQuery.getDataSource());
// do an inlining dry run to see if any inlining is necessary, without actually running the queries.
final int maxSubqueryRows = QueryContexts.getMaxSubqueryRows(query, serverConfig.getMaxSubqueryRows());
final int maxSubqueryRows = query.context().getMaxSubqueryRows(serverConfig.getMaxSubqueryRows());
final DataSource inlineDryRun = inlineIfNecessary(
freeTradeDataSource,
@ -431,7 +430,7 @@ public class ClientQuerySegmentWalker implements QuerySegmentWalker
.emitCPUTimeMetric(emitter)
.postProcess(
objectMapper.convertValue(
query.getQueryContext().getAsString("postProcessing"),
query.context().getString("postProcessing"),
new TypeReference<PostProcessingOperator<T>>() {}
)
)

View File

@ -21,6 +21,7 @@ package org.apache.druid.server;
import com.fasterxml.jackson.databind.ObjectWriter;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.Iterables;
import org.apache.druid.client.DirectDruidClient;
import org.apache.druid.java.util.common.DateTimes;
@ -61,7 +62,8 @@ import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
@ -102,6 +104,8 @@ public class QueryLifecycle
@MonotonicNonNull
private Query<?> baseQuery;
@MonotonicNonNull
private Set<String> userContextKeys;
public QueryLifecycle(
final QueryToolChestWarehouse warehouse,
@ -195,17 +199,15 @@ public class QueryLifecycle
{
transition(State.NEW, State.INITIALIZED);
if (baseQuery.getQueryContext() == null) {
QueryContext context = new QueryContext(baseQuery.getContext());
context.addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString());
context.addDefaultParams(defaultQueryConfig.getContext());
this.baseQuery = baseQuery.withOverriddenContext(context.getMergedParams());
} else {
baseQuery.getQueryContext().addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString());
baseQuery.getQueryContext().addDefaultParams(defaultQueryConfig.getContext());
this.baseQuery = baseQuery;
userContextKeys = new HashSet<>(baseQuery.getContext().keySet());
String queryId = baseQuery.getId();
if (Strings.isNullOrEmpty(queryId)) {
queryId = UUID.randomUUID().toString();
}
Map<String, Object> mergedUserAndConfigContext = QueryContexts.override(defaultQueryConfig.getContext(), baseQuery.getContext());
mergedUserAndConfigContext.put(BaseQuery.QUERY_ID, queryId);
this.baseQuery = baseQuery.withOverriddenContext(mergedUserAndConfigContext);
this.toolChest = warehouse.getToolChest(this.baseQuery);
}
@ -220,23 +222,15 @@ public class QueryLifecycle
public Access authorize(HttpServletRequest req)
{
transition(State.INITIALIZED, State.AUTHORIZING);
final Set<String> contextKeys;
if (baseQuery.getQueryContext() == null) {
contextKeys = baseQuery.getContext().keySet();
} else {
contextKeys = baseQuery.getQueryContext().getUserParams().keySet();
}
final Iterable<ResourceAction> resourcesToAuthorize = Iterables.concat(
Iterables.transform(
baseQuery.getDataSource().getTableNames(),
AuthorizationUtils.DATASOURCE_READ_RA_GENERATOR
),
authConfig.authorizeQueryContextParams()
? Iterables.transform(
contextKeys,
Iterables.transform(
authConfig.contextKeysToAuthorize(userContextKeys),
contextParam -> new ResourceAction(new Resource(contextParam, ResourceType.QUERY_CONTEXT), Action.WRITE)
)
: Collections.emptyList()
);
return doAuthorize(
AuthorizationUtils.authenticationResultFromRequest(req),
@ -353,7 +347,7 @@ public class QueryLifecycle
if (e != null) {
statsMap.put("exception", e.toString());
if (QueryContexts.isDebug(baseQuery)) {
if (baseQuery.context().isDebug()) {
log.warn(e, "Exception while processing queryId [%s]", baseQuery.getId());
} else {
log.noStackTrace().warn(e, "Exception while processing queryId [%s]", baseQuery.getId());
@ -403,9 +397,10 @@ public class QueryLifecycle
private boolean isSerializeDateTimeAsLong()
{
final boolean shouldFinalize = QueryContexts.isFinalize(baseQuery, true);
return QueryContexts.isSerializeDateTimeAsLong(baseQuery, false)
|| (!shouldFinalize && QueryContexts.isSerializeDateTimeAsLongInner(baseQuery, false));
final QueryContext queryContext = baseQuery.context();
final boolean shouldFinalize = queryContext.isFinalize(true);
return queryContext.isSerializeDateTimeAsLong(false)
|| (!shouldFinalize && queryContext.isSerializeDateTimeAsLongInner(false));
}
public ObjectWriter newOutputWriter(ResourceIOReaderWriter ioReaderWriter)

View File

@ -46,7 +46,7 @@ import org.apache.druid.query.BadJsonQueryException;
import org.apache.druid.query.BadQueryException;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryException;
import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryTimeoutException;
@ -383,20 +383,19 @@ public class QueryResource implements QueryCountStatsProvider
catch (JsonParseException e) {
throw new BadJsonQueryException(e);
}
String prevEtag = getPreviousEtag(req);
if (prevEtag != null) {
if (baseQuery.getQueryContext() == null) {
QueryContext context = new QueryContext(baseQuery.getContext());
context.addSystemParam(HEADER_IF_NONE_MATCH, prevEtag);
return baseQuery.withOverriddenContext(context.getMergedParams());
} else {
baseQuery.getQueryContext().addSystemParam(HEADER_IF_NONE_MATCH, prevEtag);
}
if (prevEtag == null) {
return baseQuery;
}
return baseQuery;
return baseQuery.withOverriddenContext(
QueryContexts.override(
baseQuery.getContext(),
HEADER_IF_NONE_MATCH,
prevEtag
)
);
}
private static String getPreviousEtag(final HttpServletRequest req)

View File

@ -38,7 +38,6 @@ import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.java.util.emitter.service.ServiceMetricEvent;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryWatcher;
@ -254,7 +253,7 @@ public class QueryScheduler implements QueryWatcher
@VisibleForTesting
List<Bulkhead> acquireLanes(Query<?> query)
{
final String lane = QueryContexts.getLane(query);
final String lane = query.context().getLane();
final Optional<BulkheadConfig> laneConfig = lane == null ? Optional.empty() : laneRegistry.getConfiguration(lane);
final Optional<BulkheadConfig> totalConfig = laneRegistry.getConfiguration(TOTAL);
List<Bulkhead> hallPasses = new ArrayList<>(2);

View File

@ -22,8 +22,9 @@ package org.apache.druid.server;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.client.DirectDruidClient;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.query.Queries;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.context.ResponseContext;
@ -56,21 +57,23 @@ public class SetAndVerifyContextQueryRunner<T> implements QueryRunner<T>
public Query<T> withTimeoutAndMaxScatterGatherBytes(Query<T> query, ServerConfig serverConfig)
{
Query<T> newQuery = QueryContexts.verifyMaxQueryTimeout(
QueryContexts.withMaxScatterGatherBytes(
QueryContexts.withDefaultTimeout(
Query<T> newQuery =
Queries.withMaxScatterGatherBytes(
Queries.withDefaultTimeout(
query,
Math.min(serverConfig.getDefaultQueryTimeout(), serverConfig.getMaxQueryTimeout())
),
serverConfig.getMaxScatterGatherBytes()
),
);
newQuery.context().verifyMaxQueryTimeout(
serverConfig.getMaxQueryTimeout()
);
// DirectDruidClient.QUERY_FAIL_TIME is used by DirectDruidClient and JsonParserIterator to determine when to
// fail with a timeout exception
final long failTime;
if (QueryContexts.hasTimeout(newQuery)) {
failTime = this.startTimeMillis + QueryContexts.getTimeout(newQuery);
final QueryContext context = newQuery.context();
if (context.hasTimeout()) {
failTime = this.startTimeMillis + context.getTimeout();
} else {
failTime = this.startTimeMillis + serverConfig.getMaxQueryTimeout();
}

View File

@ -26,6 +26,7 @@ import it.unimi.dsi.fastutil.objects.Object2IntArrayMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import org.apache.druid.client.SegmentServerSelector;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.server.QueryLaningStrategy;
@ -70,10 +71,11 @@ public class HiLoQueryLaningStrategy implements QueryLaningStrategy
// QueryContexts.getPriority gives a default, but it can parse the value to integer. Before calling QueryContexts.getPriority
// we make sure that priority has been set.
Integer priority = null;
if (theQuery.getContextValue(QueryContexts.PRIORITY_KEY) != null) {
priority = QueryContexts.getPriority(theQuery);
final QueryContext queryContext = theQuery.context();
if (null != queryContext.get(QueryContexts.PRIORITY_KEY)) {
priority = queryContext.getPriority();
}
final String lane = theQuery.getQueryContext().getAsString(QueryContexts.LANE_KEY);
final String lane = queryContext.getLane();
if (lane == null && priority != null && priority < 0) {
return Optional.of(LOW);
}

View File

@ -25,12 +25,12 @@ import com.google.common.base.Preconditions;
import it.unimi.dsi.fastutil.objects.Object2IntArrayMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import org.apache.druid.client.SegmentServerSelector;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.server.QueryLaningStrategy;
import org.apache.druid.server.QueryScheduler;
import javax.annotation.Nullable;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
@ -84,6 +84,6 @@ public class ManualQueryLaningStrategy implements QueryLaningStrategy
@Override
public <T> Optional<String> computeLane(QueryPlus<T> query, Set<SegmentServerSelector> segments)
{
return Optional.ofNullable(QueryContexts.getLane(query.getQuery()));
return Optional.ofNullable(query.getQuery().context().getLane());
}
}

View File

@ -22,7 +22,6 @@ package org.apache.druid.server.scheduling;
import it.unimi.dsi.fastutil.objects.Object2IntArrayMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import org.apache.druid.client.SegmentServerSelector;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.server.QueryLaningStrategy;
@ -47,6 +46,6 @@ public class NoQueryLaningStrategy implements QueryLaningStrategy
@Override
public <T> Optional<String> computeLane(QueryPlus<T> query, Set<SegmentServerSelector> segments)
{
return Optional.ofNullable(QueryContexts.getLane(query.getQuery()));
return Optional.ofNullable(query.getQuery().context().getLane());
}
}

View File

@ -25,7 +25,6 @@ import com.google.common.base.Preconditions;
import org.apache.druid.client.SegmentServerSelector;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.server.QueryPrioritizationStrategy;
import org.joda.time.DateTime;
@ -33,6 +32,7 @@ import org.joda.time.Duration;
import org.joda.time.Period;
import javax.annotation.Nullable;
import java.util.Optional;
import java.util.Set;
@ -87,7 +87,7 @@ public class ThresholdBasedQueryPrioritizationStrategy implements QueryPrioritiz
boolean violatesSegmentThreshold = segments.size() > segmentCountThreshold;
if (violatesPeriodThreshold || violatesDurationThreshold || violatesSegmentThreshold) {
final int adjustedPriority = QueryContexts.getPriority(theQuery) - adjustment;
final int adjustedPriority = theQuery.context().getPriority() - adjustment;
return Optional.of(adjustedPriority);
}
return Optional.empty();

View File

@ -27,6 +27,7 @@ public class Access
static final String DEFAULT_ERROR_MESSAGE = "Unauthorized";
public static final Access OK = new Access(true);
public static final Access DENIED = new Access(false);
private final boolean allowed;
private final String message;

View File

@ -21,10 +21,14 @@ package org.apache.druid.server.security;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.utils.CollectionUtils;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
public class AuthConfig
{
@ -46,25 +50,20 @@ public class AuthConfig
public static final String TRUSTED_DOMAIN_NAME = "trustedDomain";
/**
* Set of context keys which are always permissible because something in the Druid
* code itself sets the key before the security check.
*/
public static final Set<String> ALLOWED_CONTEXT_KEYS = ImmutableSet.of(
// Set in the Avatica server path
QueryContexts.CTX_SQL_STRINGIFY_ARRAYS,
// Set by the Router
QueryContexts.CTX_SQL_QUERY_ID
);
public AuthConfig()
{
this(null, null, null, false, false);
}
@JsonCreator
public AuthConfig(
@JsonProperty("authenticatorChain") List<String> authenticatorChain,
@JsonProperty("authorizers") List<String> authorizers,
@JsonProperty("unsecuredPaths") List<String> unsecuredPaths,
@JsonProperty("allowUnauthenticatedHttpOptions") boolean allowUnauthenticatedHttpOptions,
@JsonProperty("authorizeQueryContextParams") boolean authorizeQueryContextParams
)
{
this.authenticatorChain = authenticatorChain;
this.authorizers = authorizers;
this.unsecuredPaths = unsecuredPaths == null ? Collections.emptyList() : unsecuredPaths;
this.allowUnauthenticatedHttpOptions = allowUnauthenticatedHttpOptions;
this.authorizeQueryContextParams = authorizeQueryContextParams;
this(null, null, null, false, false, null, null);
}
@JsonProperty
@ -82,6 +81,44 @@ public class AuthConfig
@JsonProperty
private final boolean authorizeQueryContextParams;
/**
* The set of query context keys that are allowed, even when security is
* enabled. A null value is the same as an empty set.
*/
@JsonProperty
private final Set<String> unsecuredContextKeys;
/**
* The set of query context keys to secure, when context security is
* enabled. Null has a special meaning: it means to ignore this set.
* Else, only the keys in this set are subject to security. If set,
* the unsecured list is ignored.
*/
@JsonProperty
private final Set<String> securedContextKeys;
@JsonCreator
public AuthConfig(
@JsonProperty("authenticatorChain") List<String> authenticatorChain,
@JsonProperty("authorizers") List<String> authorizers,
@JsonProperty("unsecuredPaths") List<String> unsecuredPaths,
@JsonProperty("allowUnauthenticatedHttpOptions") boolean allowUnauthenticatedHttpOptions,
@JsonProperty("authorizeQueryContextParams") boolean authorizeQueryContextParams,
@JsonProperty("unsecuredContextKeys") Set<String> unsecuredContextKeys,
@JsonProperty("securedContextKeys") Set<String> securedContextKeys
)
{
this.authenticatorChain = authenticatorChain;
this.authorizers = authorizers;
this.unsecuredPaths = unsecuredPaths == null ? Collections.emptyList() : unsecuredPaths;
this.allowUnauthenticatedHttpOptions = allowUnauthenticatedHttpOptions;
this.authorizeQueryContextParams = authorizeQueryContextParams;
this.unsecuredContextKeys = unsecuredContextKeys == null
? Collections.emptySet()
: unsecuredContextKeys;
this.securedContextKeys = securedContextKeys;
}
public List<String> getAuthenticatorChain()
{
return authenticatorChain;
@ -107,6 +144,36 @@ public class AuthConfig
return authorizeQueryContextParams;
}
/**
* Filter the user-supplied context keys based on the context key security
* rules. If context key security is disabled, then allow all keys. Else,
* apply the three key lists defined here.
* <ul>
* <li>Allow Druid-defined keys.</li>
* <li>Allow anything not in the secured context key list.</li>
* <li>Allow anything in the config-defined unsecured key list.</li>
* </ul>
* In the typical case, a site defines either the secured key list
* (to handle a few keys that are <i>are not</i> allowed) or the unsecured key
* list (to enumerate a few that <i>are</i> allowed.) If both lists
* are given, think of the secured list as exceptions to the unsecured
* key list.
*
* @return the list of secured keys to check via authentication
*/
public Set<String> contextKeysToAuthorize(final Set<String> userKeys)
{
if (!authorizeQueryContextParams) {
return ImmutableSet.of();
}
Set<String> keysToCheck = CollectionUtils.subtract(userKeys, ALLOWED_CONTEXT_KEYS);
keysToCheck = CollectionUtils.subtract(keysToCheck, unsecuredContextKeys);
if (securedContextKeys != null) {
keysToCheck = CollectionUtils.intersect(keysToCheck, securedContextKeys);
}
return keysToCheck;
}
@Override
public boolean equals(Object o)
{
@ -121,7 +188,9 @@ public class AuthConfig
&& authorizeQueryContextParams == that.authorizeQueryContextParams
&& Objects.equals(authenticatorChain, that.authenticatorChain)
&& Objects.equals(authorizers, that.authorizers)
&& Objects.equals(unsecuredPaths, that.unsecuredPaths);
&& Objects.equals(unsecuredPaths, that.unsecuredPaths)
&& Objects.equals(unsecuredContextKeys, that.unsecuredContextKeys)
&& Objects.equals(securedContextKeys, that.securedContextKeys);
}
@Override
@ -132,7 +201,9 @@ public class AuthConfig
authorizers,
unsecuredPaths,
allowUnauthenticatedHttpOptions,
authorizeQueryContextParams
authorizeQueryContextParams,
unsecuredContextKeys,
securedContextKeys
);
}
@ -145,6 +216,8 @@ public class AuthConfig
", unsecuredPaths=" + unsecuredPaths +
", allowUnauthenticatedHttpOptions=" + allowUnauthenticatedHttpOptions +
", enableQueryContextAuthorization=" + authorizeQueryContextParams +
", unsecuredContextKeys=" + unsecuredContextKeys +
", securedContextKeys=" + securedContextKeys +
'}';
}
@ -163,6 +236,8 @@ public class AuthConfig
private List<String> unsecuredPaths;
private boolean allowUnauthenticatedHttpOptions;
private boolean authorizeQueryContextParams;
private Set<String> unsecuredContextKeys;
private Set<String> securedContextKeys;
public Builder setAuthenticatorChain(List<String> authenticatorChain)
{
@ -194,6 +269,18 @@ public class AuthConfig
return this;
}
public Builder setUnsecuredContextKeys(Set<String> unsecuredContextKeys)
{
this.unsecuredContextKeys = unsecuredContextKeys;
return this;
}
public Builder setSecuredContextKeys(Set<String> securedContextKeys)
{
this.securedContextKeys = securedContextKeys;
return this;
}
public AuthConfig build()
{
return new AuthConfig(
@ -201,7 +288,9 @@ public class AuthConfig
authorizers,
unsecuredPaths,
allowUnauthenticatedHttpOptions,
authorizeQueryContextParams
authorizeQueryContextParams,
unsecuredContextKeys,
securedContextKeys
);
}
}

View File

@ -19,12 +19,14 @@
package org.apache.druid.client;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.Bytes;
import org.apache.druid.client.selector.QueryableDruidServer;
import org.apache.druid.client.selector.ServerSelector;
import org.apache.druid.query.CacheStrategy;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.planning.DataSourceAnalysis;
import org.apache.druid.segment.join.JoinableFactoryWrapper;
@ -43,7 +45,6 @@ import org.junit.runner.RunWith;
import java.util.Optional;
import java.util.Set;
import static org.apache.druid.query.QueryContexts.DEFAULT_BY_SEGMENT;
import static org.easymock.EasyMock.expect;
import static org.easymock.EasyMock.replay;
import static org.easymock.EasyMock.reset;
@ -67,7 +68,7 @@ public class CachingClusteredClientCacheKeyManagerTest extends EasyMockSupport
public void setup()
{
expect(strategy.computeCacheKey(query)).andReturn(QUERY_CACHE_KEY).anyTimes();
expect(query.getContextBoolean(QueryContexts.BY_SEGMENT_KEY, DEFAULT_BY_SEGMENT)).andReturn(false).anyTimes();
expect(query.context()).andReturn(QueryContext.of(ImmutableMap.of(QueryContexts.BY_SEGMENT_KEY, false))).anyTimes();
}
@After
@ -203,7 +204,7 @@ public class CachingClusteredClientCacheKeyManagerTest extends EasyMockSupport
{
expect(dataSourceAnalysis.isJoin()).andReturn(false);
reset(query);
expect(query.getContextBoolean(QueryContexts.BY_SEGMENT_KEY, DEFAULT_BY_SEGMENT)).andReturn(true).anyTimes();
expect(query.context()).andReturn(QueryContext.of(ImmutableMap.of(QueryContexts.BY_SEGMENT_KEY, true))).anyTimes();
replayAll();
CachingClusteredClient.CacheKeyManager<Object> keyManager = makeKeyManager();
Set<SegmentServerSelector> selectors = ImmutableSet.of(
@ -272,7 +273,7 @@ public class CachingClusteredClientCacheKeyManagerTest extends EasyMockSupport
public void testSegmentQueryCacheKey_noCachingIfBySegment()
{
reset(query);
expect(query.getContextBoolean(QueryContexts.BY_SEGMENT_KEY, DEFAULT_BY_SEGMENT)).andReturn(true).anyTimes();
expect(query.context()).andReturn(QueryContext.of(ImmutableMap.of(QueryContexts.BY_SEGMENT_KEY, true))).anyTimes();
replayAll();
byte[] cacheKey = makeKeyManager().computeSegmentLevelQueryCacheKey();
Assert.assertNull(cacheKey);

View File

@ -72,6 +72,7 @@ import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.Druids;
import org.apache.druid.query.FinalizeResultsQueryRunner;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
@ -2297,12 +2298,13 @@ public class CachingClusteredClientTest
for (Capture queryCapture : queryCaptures) {
QueryPlus capturedQueryPlus = (QueryPlus) queryCapture.getValue();
Query capturedQuery = capturedQueryPlus.getQuery();
final QueryContext queryContext = capturedQuery.context();
if (expectBySegment) {
Assert.assertEquals(true, capturedQuery.getQueryContext().getAsBoolean(QueryContexts.BY_SEGMENT_KEY));
Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.BY_SEGMENT_KEY));
} else {
Assert.assertTrue(
capturedQuery.getContextValue(QueryContexts.BY_SEGMENT_KEY) == null ||
capturedQuery.getQueryContext().getAsBoolean(QueryContexts.BY_SEGMENT_KEY).equals(false)
queryContext.get(QueryContexts.BY_SEGMENT_KEY) == null ||
!queryContext.getBoolean(QueryContexts.BY_SEGMENT_KEY)
);
}
}

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.AbstractFuture;
import com.google.common.util.concurrent.Futures;
import org.apache.druid.jackson.DefaultObjectMapper;
@ -309,13 +310,8 @@ public class JsonParserIteratorTest
Query<?> query = Mockito.mock(Query.class);
QueryContext context = Mockito.mock(QueryContext.class);
Mockito.when(query.getId()).thenReturn(queryId);
Mockito.when(query.getQueryContext()).thenReturn(context);
Mockito.when(
context.getAsLong(
ArgumentMatchers.eq(DirectDruidClient.QUERY_FAIL_TIME),
ArgumentMatchers.eq(-1L)
)
).thenReturn(timeoutAt);
Mockito.when(query.context()).thenReturn(
QueryContext.of(ImmutableMap.of(DirectDruidClient.QUERY_FAIL_TIME, timeoutAt)));
return query;
}
}

View File

@ -119,8 +119,6 @@ public class UnifiedIndexerAppenderatorsManagerTest extends InitializedNullHandl
@Test
public void test_getBundle_knownDataSource()
{
final UnifiedIndexerAppenderatorsManager.DatasourceBundle bundle = manager.getBundle(
Druids.newScanQueryBuilder()
.dataSource(appenderator.getDataSource())

View File

@ -21,6 +21,7 @@ package org.apache.druid.server;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.guava.Sequences;
@ -55,6 +56,9 @@ import org.junit.rules.ExpectedException;
import javax.servlet.http.HttpServletRequest;
import java.util.HashMap;
import java.util.Map;
public class QueryLifecycleTest
{
private static final String DATASOURCE = "some_datasource";
@ -73,9 +77,6 @@ public class QueryLifecycleTest
RequestLogger requestLogger;
AuthorizerMapper authzMapper;
DefaultQueryConfig queryConfig;
AuthConfig authConfig;
QueryLifecycle lifecycle;
QueryToolChest toolChest;
QueryRunner runner;
@ -97,11 +98,18 @@ public class QueryLifecycleTest
authorizer = EasyMock.createMock(Authorizer.class);
authzMapper = new AuthorizerMapper(ImmutableMap.of(AUTHORIZER, authorizer));
queryConfig = EasyMock.createMock(DefaultQueryConfig.class);
authConfig = EasyMock.createMock(AuthConfig.class);
toolChest = EasyMock.createMock(QueryToolChest.class);
runner = EasyMock.createMock(QueryRunner.class);
metrics = EasyMock.createNiceMock(QueryMetrics.class);
authenticationResult = EasyMock.createMock(AuthenticationResult.class);
}
private QueryLifecycle createLifecycle(AuthConfig authConfig)
{
long nanos = System.nanoTime();
long millis = System.currentTimeMillis();
lifecycle = new QueryLifecycle(
return new QueryLifecycle(
toolChestWarehouse,
texasRanger,
metricsFactory,
@ -113,11 +121,6 @@ public class QueryLifecycleTest
millis,
nanos
);
toolChest = EasyMock.createMock(QueryToolChest.class);
runner = EasyMock.createMock(QueryRunner.class);
metrics = EasyMock.createNiceMock(QueryMetrics.class);
authenticationResult = EasyMock.createMock(AuthenticationResult.class);
}
@After
@ -151,9 +154,9 @@ public class QueryLifecycleTest
.once();
EasyMock.expect(runner.run(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(Sequences.empty()).once();
replayAll();
QueryLifecycle lifecycle = createLifecycle(new AuthConfig());
lifecycle.runSimple(query, authenticationResult, Access.OK);
}
@ -174,6 +177,7 @@ public class QueryLifecycleTest
replayAll();
QueryLifecycle lifecycle = createLifecycle(new AuthConfig());
lifecycle.runSimple(query, authenticationResult, new Access(false));
}
@ -181,7 +185,6 @@ public class QueryLifecycleTest
public void testAuthorizeQueryContext_authorized()
{
EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
EasyMock.expect(authConfig.authorizeQueryContextParams()).andReturn(true).anyTimes();
EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ))
@ -197,21 +200,27 @@ public class QueryLifecycleTest
replayAll();
final Map<String, Object> userContext = ImmutableMap.of("foo", "bar", "baz", "qux");
final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource(DATASOURCE)
.intervals(ImmutableList.of(Intervals.ETERNITY))
.aggregators(new CountAggregatorFactory("chocula"))
.context(ImmutableMap.of("foo", "bar", "baz", "qux"))
.context(userContext)
.build();
AuthConfig authConfig = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.build();
QueryLifecycle lifecycle = createLifecycle(authConfig);
lifecycle.initialize(query);
Assert.assertEquals(
ImmutableMap.of("foo", "bar", "baz", "qux"),
lifecycle.getQuery().getQueryContext().getUserParams()
);
Assert.assertTrue(lifecycle.getQuery().getQueryContext().getMergedParams().containsKey("queryId"));
final Map<String, Object> revisedContext = new HashMap<>(lifecycle.getQuery().getContext());
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId"));
revisedContext.remove("queryId");
Assert.assertEquals(
userContext,
revisedContext
);
Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed());
}
@ -220,13 +229,12 @@ public class QueryLifecycleTest
public void testAuthorizeQueryContext_notAuthorized()
{
EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
EasyMock.expect(authConfig.authorizeQueryContextParams()).andReturn(true).anyTimes();
EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ))
.andReturn(Access.OK);
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE))
.andReturn(new Access(false));
.andReturn(Access.DENIED);
EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
.andReturn(toolChest)
@ -241,6 +249,128 @@ public class QueryLifecycleTest
.context(ImmutableMap.of("foo", "bar"))
.build();
AuthConfig authConfig = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.build();
QueryLifecycle lifecycle = createLifecycle(authConfig);
lifecycle.initialize(query);
Assert.assertFalse(lifecycle.authorize(mockRequest()).isAllowed());
}
@Test
public void testAuthorizeQueryContext_unsecuredKeys()
{
EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ))
.andReturn(Access.OK);
EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
.andReturn(toolChest)
.once();
replayAll();
final Map<String, Object> userContext = ImmutableMap.of("foo", "bar", "baz", "qux");
final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource(DATASOURCE)
.intervals(ImmutableList.of(Intervals.ETERNITY))
.aggregators(new CountAggregatorFactory("chocula"))
.context(userContext)
.build();
AuthConfig authConfig = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.setUnsecuredContextKeys(ImmutableSet.of("foo", "baz"))
.build();
QueryLifecycle lifecycle = createLifecycle(authConfig);
lifecycle.initialize(query);
final Map<String, Object> revisedContext = new HashMap<>(lifecycle.getQuery().getContext());
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId"));
revisedContext.remove("queryId");
Assert.assertEquals(
userContext,
revisedContext
);
Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed());
}
@Test
public void testAuthorizeQueryContext_securedKeys()
{
EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ))
.andReturn(Access.OK);
EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
.andReturn(toolChest)
.once();
replayAll();
final Map<String, Object> userContext = ImmutableMap.of("foo", "bar", "baz", "qux");
final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource(DATASOURCE)
.intervals(ImmutableList.of(Intervals.ETERNITY))
.aggregators(new CountAggregatorFactory("chocula"))
.context(userContext)
.build();
AuthConfig authConfig = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
// We have secured keys, just not what the user gave.
.setSecuredContextKeys(ImmutableSet.of("foo2", "baz2"))
.build();
QueryLifecycle lifecycle = createLifecycle(authConfig);
lifecycle.initialize(query);
final Map<String, Object> revisedContext = new HashMap<>(lifecycle.getQuery().getContext());
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId"));
revisedContext.remove("queryId");
Assert.assertEquals(
userContext,
revisedContext
);
Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed());
}
@Test
public void testAuthorizeQueryContext_securedKeysNotAuthorized()
{
EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ))
.andReturn(Access.OK);
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE))
.andReturn(Access.DENIED);
EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
.andReturn(toolChest)
.once();
replayAll();
final Map<String, Object> userContext = ImmutableMap.of("foo", "bar", "baz", "qux");
final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource(DATASOURCE)
.intervals(ImmutableList.of(Intervals.ETERNITY))
.aggregators(new CountAggregatorFactory("chocula"))
.context(userContext)
.build();
AuthConfig authConfig = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
// We have secured keys. User used one of them.
.setSecuredContextKeys(ImmutableSet.of("foo", "baz2"))
.build();
QueryLifecycle lifecycle = createLifecycle(authConfig);
lifecycle.initialize(query);
Assert.assertFalse(lifecycle.authorize(mockRequest()).isAllowed());
}
@ -249,7 +379,6 @@ public class QueryLifecycleTest
public void testAuthorizeLegacyQueryContext_authorized()
{
EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
EasyMock.expect(authConfig.authorizeQueryContextParams()).andReturn(true).anyTimes();
EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("fake", ResourceType.DATASOURCE), Action.READ))
@ -257,9 +386,6 @@ public class QueryLifecycleTest
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE))
.andReturn(Access.OK);
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("baz", ResourceType.QUERY_CONTEXT), Action.WRITE)).andReturn(Access.OK);
// to use legacy query context with context authorization, even system generated things like queryId need to be explicitly added
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("queryId", ResourceType.QUERY_CONTEXT), Action.WRITE))
.andReturn(Access.OK);
EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
.andReturn(toolChest)
@ -269,12 +395,17 @@ public class QueryLifecycleTest
final QueryContextTest.LegacyContextQuery query = new QueryContextTest.LegacyContextQuery(ImmutableMap.of("foo", "bar", "baz", "qux"));
AuthConfig authConfig = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.build();
QueryLifecycle lifecycle = createLifecycle(authConfig);
lifecycle.initialize(query);
Assert.assertNull(lifecycle.getQuery().getQueryContext());
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("foo"));
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("baz"));
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId"));
final Map<String, Object> revisedContext = lifecycle.getQuery().getContext();
Assert.assertNotNull(revisedContext);
Assert.assertTrue(revisedContext.containsKey("foo"));
Assert.assertTrue(revisedContext.containsKey("baz"));
Assert.assertTrue(revisedContext.containsKey("queryId"));
Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed());
}
@ -301,7 +432,6 @@ public class QueryLifecycleTest
emitter,
requestLogger,
queryConfig,
authConfig,
toolChest,
runner,
metrics,

View File

@ -48,7 +48,6 @@ import org.apache.druid.java.util.emitter.core.NoopEmitter;
import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.topn.TopNQuery;
@ -150,7 +149,7 @@ public class QuerySchedulerTest
try {
Query<?> scheduledReport = scheduler.prioritizeAndLaneQuery(QueryPlus.wrap(report), ImmutableSet.of());
Assert.assertNotNull(scheduledReport);
Assert.assertEquals(HiLoQueryLaningStrategy.LOW, QueryContexts.getLane(scheduledReport));
Assert.assertEquals(HiLoQueryLaningStrategy.LOW, scheduledReport.context().getLane());
Sequence<Integer> underlyingSequence = makeSequence(10);
underlyingSequence = Sequences.wrap(underlyingSequence, new SequenceWrapper()
@ -412,8 +411,8 @@ public class QuerySchedulerTest
EasyMock.createMock(SegmentServerSelector.class)
)
);
Assert.assertEquals(-5, QueryContexts.getPriority(query));
Assert.assertEquals(HiLoQueryLaningStrategy.LOW, QueryContexts.getLane(query));
Assert.assertEquals(-5, query.context().getPriority());
Assert.assertEquals(HiLoQueryLaningStrategy.LOW, query.context().getLane());
}
@Test

View File

@ -36,7 +36,6 @@ import org.junit.Test;
public class SetAndVerifyContextQueryRunnerTest
{
@Test
public void testTimeoutIsUsedIfTimeoutIsNonZero() throws InterruptedException
{
@ -58,7 +57,7 @@ public class SetAndVerifyContextQueryRunnerTest
// time + 1 at the time the method was called
// this means that after sleeping for 1 millis, the fail time should be less than the current time when checking
Assert.assertTrue(
System.currentTimeMillis() > transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME)
System.currentTimeMillis() > transformed.context().getLong(DirectDruidClient.QUERY_FAIL_TIME)
);
}
@ -85,7 +84,7 @@ public class SetAndVerifyContextQueryRunnerTest
Query<ScanResultValue> transformed = queryRunner.withTimeoutAndMaxScatterGatherBytes(query, defaultConfig);
// timeout is not set, default timeout has been set to long.max, make sure timeout is still in the future
Assert.assertEquals((Long) Long.MAX_VALUE, transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME));
Assert.assertEquals(Long.MAX_VALUE, (long) transformed.context().getLong(DirectDruidClient.QUERY_FAIL_TIME));
}
@Test
@ -107,7 +106,7 @@ public class SetAndVerifyContextQueryRunnerTest
// timeout is set to 0, so withTimeoutAndMaxScatterGatherBytes should set QUERY_FAIL_TIME to be the current
// time + max query timeout at the time the method was called
// since default is long max, expect long max since current time would overflow
Assert.assertEquals((Long) Long.MAX_VALUE, transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME));
Assert.assertEquals(Long.MAX_VALUE, (long) transformed.context().getLong(DirectDruidClient.QUERY_FAIL_TIME));
}
@Test
@ -137,7 +136,7 @@ public class SetAndVerifyContextQueryRunnerTest
// time + max query timeout at the time the method was called
// this means that the fail time should be greater than the current time when checking
Assert.assertTrue(
System.currentTimeMillis() < (Long) transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME)
System.currentTimeMillis() < transformed.context().getLong(DirectDruidClient.QUERY_FAIL_TIME)
);
}
}

View File

@ -19,9 +19,16 @@
package org.apache.druid.server.security;
import com.google.common.collect.ImmutableSet;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.query.QueryContexts;
import org.junit.Test;
import java.util.Set;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class AuthConfigTest
{
@Test
@ -29,4 +36,55 @@ public class AuthConfigTest
{
EqualsVerifier.configure().usingGetClass().forClass(AuthConfig.class).verify();
}
@Test
public void testContextSecurity()
{
// No security
{
AuthConfig config = new AuthConfig();
Set<String> keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID);
assertTrue(config.contextKeysToAuthorize(keys).isEmpty());
}
// Default security
{
AuthConfig config = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.build();
Set<String> keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID);
assertEquals(ImmutableSet.of("a", "b"), config.contextKeysToAuthorize(keys));
}
// Specify unsecured keys (white-list)
{
AuthConfig config = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.setUnsecuredContextKeys(ImmutableSet.of("a"))
.build();
Set<String> keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID);
assertEquals(ImmutableSet.of("b"), config.contextKeysToAuthorize(keys));
}
// Specify secured keys (black-list)
{
AuthConfig config = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.setSecuredContextKeys(ImmutableSet.of("a"))
.build();
Set<String> keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID);
assertEquals(ImmutableSet.of("a"), config.contextKeysToAuthorize(keys));
}
// Specify both
{
AuthConfig config = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.setUnsecuredContextKeys(ImmutableSet.of("a", "b"))
.setSecuredContextKeys(ImmutableSet.of("b", "c"))
.build();
Set<String> keys = ImmutableSet.of("a", "b", "c", "d", QueryContexts.CTX_SQL_QUERY_ID);
assertEquals(ImmutableSet.of("c"), config.contextKeysToAuthorize(keys));
}
}
}

Some files were not shown because too many files have changed in this diff Show More