Redesign QueryContext class (#13071)

We introduce two new configuration keys that refine the query context security model controlled by druid.auth.authorizeQueryContextParams. When that value is set to true then two other configuration options become available:

druid.auth.unsecuredContextKeys: The set of query context keys that do not require a security check. Use this for the "white-list" of key to allow. All other keys go through the existing context key security checks.
druid.auth.securedContextKeys: The set of query context keys that do require a security check. Use this when you want to allow all but a specific set of keys: only these keys go through the existing context key security checks.
Both are set using JSON list format:

druid.auth.securedContextKeys=["secretKey1", "secretKey2"]
You generally set one or the other values. If both are set, unsecuredContextKeys acts as exceptions to securedContextKeys.

In addition, Druid defines two query context keys which always bypass checks because Druid uses them internally:

sqlQueryId
sqlStringifyArrays
This commit is contained in:
Paul Rogers 2022-10-14 22:32:11 -07:00 committed by GitHub
parent 6332c571bd
commit f4dcc52dac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
147 changed files with 2312 additions and 1710 deletions

View File

@ -29,7 +29,6 @@ import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.query.aggregation.datasketches.hll.sql.HllSketchApproxCountDistinctSqlAggregator; import org.apache.druid.query.aggregation.datasketches.hll.sql.HllSketchApproxCountDistinctSqlAggregator;
@ -516,7 +515,7 @@ public class SqlBenchmark
QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize
); );
final String sql = QUERIES.get(Integer.parseInt(query)); final String sql = QUERIES.get(Integer.parseInt(query));
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, new QueryContext(context))) { try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, context)) {
final PlannerResult plannerResult = planner.plan(); final PlannerResult plannerResult = planner.plan();
final Sequence<Object[]> resultSequence = plannerResult.run().getResults(); final Sequence<Object[]> resultSequence = plannerResult.run().getResults();
final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in); final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in);
@ -534,7 +533,7 @@ public class SqlBenchmark
QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize
); );
final String sql = QUERIES.get(Integer.parseInt(query)); final String sql = QUERIES.get(Integer.parseInt(query));
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, new QueryContext(context))) { try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, context)) {
final PlannerResult plannerResult = planner.plan(); final PlannerResult plannerResult = planner.plan();
blackhole.consume(plannerResult); blackhole.consume(plannerResult);
} }

View File

@ -29,7 +29,6 @@ import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.math.expr.ExpressionProcessing; import org.apache.druid.math.expr.ExpressionProcessing;
import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.segment.QueryableIndex; import org.apache.druid.segment.QueryableIndex;
@ -352,7 +351,7 @@ public class SqlExpressionBenchmark
QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize
); );
final String sql = QUERIES.get(Integer.parseInt(query)); final String sql = QUERIES.get(Integer.parseInt(query));
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, new QueryContext(context))) { try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, context)) {
final PlannerResult plannerResult = planner.plan(); final PlannerResult plannerResult = planner.plan();
final Sequence<Object[]> resultSequence = plannerResult.run().getResults(); final Sequence<Object[]> resultSequence = plannerResult.run().getResults();
final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in); final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in);

View File

@ -31,7 +31,6 @@ import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.math.expr.ExpressionProcessing; import org.apache.druid.math.expr.ExpressionProcessing;
import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.query.expression.TestExprMacroTable; import org.apache.druid.query.expression.TestExprMacroTable;
@ -318,7 +317,7 @@ public class SqlNestedDataBenchmark
QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize
); );
final String sql = QUERIES.get(Integer.parseInt(query)); final String sql = QUERIES.get(Integer.parseInt(query));
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, new QueryContext(context))) { try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, context)) {
final PlannerResult plannerResult = planner.plan(); final PlannerResult plannerResult = planner.plan();
final Sequence<Object[]> resultSequence = plannerResult.run().getResults(); final Sequence<Object[]> resultSequence = plannerResult.run().getResults();
final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in); final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in);

View File

@ -26,7 +26,6 @@ import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory;
@ -66,6 +65,7 @@ import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup; import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole; import org.openjdk.jmh.infra.Blackhole;
import java.util.Collections;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
/** /**
@ -167,7 +167,7 @@ public class SqlVsNativeBenchmark
@OutputTimeUnit(TimeUnit.MILLISECONDS) @OutputTimeUnit(TimeUnit.MILLISECONDS)
public void queryPlanner(Blackhole blackhole) throws Exception public void queryPlanner(Blackhole blackhole) throws Exception
{ {
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sqlQuery, new QueryContext())) { try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sqlQuery, Collections.emptyMap())) {
final PlannerResult plannerResult = planner.plan(); final PlannerResult plannerResult = planner.plan();
final Sequence<Object[]> resultSequence = plannerResult.run().getResults(); final Sequence<Object[]> resultSequence = plannerResult.run().getResults();
final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in); final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in);

View File

@ -29,9 +29,11 @@ import java.util.AbstractCollection;
import java.util.Collection; import java.util.Collection;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.Spliterator; import java.util.Spliterator;
import java.util.TreeSet; import java.util.TreeSet;
import java.util.function.Function; import java.util.function.Function;
@ -148,6 +150,40 @@ public final class CollectionUtils
return list == null || list.isEmpty(); return list == null || list.isEmpty();
} }
/**
* Subtract one set from another: {@code C = A - B}.
*/
public static <T> Set<T> subtract(Set<T> left, Set<T> right)
{
Set<T> result = new HashSet<>(left);
result.removeAll(right);
return result;
}
/**
* Intersection of two sets: {@code C = A B}.
*/
public static <T> Set<T> intersect(Set<T> left, Set<T> right)
{
Set<T> result = new HashSet<>();
for (T key : left) {
if (right.contains(key)) {
result.add(key);
}
}
return result;
}
/**
* Intersection of two sets: {@code C = A B}.
*/
public static <T> Set<T> union(Set<T> left, Set<T> right)
{
Set<T> result = new HashSet<>(left);
result.addAll(right);
return result;
}
private CollectionUtils() private CollectionUtils()
{ {
} }

View File

@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.utils;
import com.google.common.collect.ImmutableSet;
import org.junit.Test;
import java.util.Set;
import static org.junit.Assert.assertEquals;
public class CollectionUtilsTest
{
// When Java 9 is allowed, use Set.of().
Set<String> empty = ImmutableSet.of();
Set<String> abc = ImmutableSet.of("a", "b", "c");
Set<String> bcd = ImmutableSet.of("b", "c", "d");
Set<String> efg = ImmutableSet.of("e", "f", "g");
@Test
public void testSubtract()
{
assertEquals(empty, CollectionUtils.subtract(empty, empty));
assertEquals(abc, CollectionUtils.subtract(abc, empty));
assertEquals(empty, CollectionUtils.subtract(abc, abc));
assertEquals(abc, CollectionUtils.subtract(abc, efg));
assertEquals(ImmutableSet.of("a"), CollectionUtils.subtract(abc, bcd));
}
@Test
public void testIntersect()
{
assertEquals(empty, CollectionUtils.intersect(empty, empty));
assertEquals(abc, CollectionUtils.intersect(abc, abc));
assertEquals(empty, CollectionUtils.intersect(abc, efg));
assertEquals(ImmutableSet.of("b", "c"), CollectionUtils.intersect(abc, bcd));
}
@Test
public void testUnion()
{
assertEquals(empty, CollectionUtils.union(empty, empty));
assertEquals(abc, CollectionUtils.union(abc, abc));
assertEquals(ImmutableSet.of("a", "b", "c", "e", "f", "g"), CollectionUtils.union(abc, efg));
assertEquals(ImmutableSet.of("a", "b", "c", "d"), CollectionUtils.union(abc, bcd));
}
}

View File

@ -28,7 +28,6 @@ import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.BaseQuery; import org.apache.druid.query.BaseQuery;
import org.apache.druid.query.DataSource; import org.apache.druid.query.DataSource;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QuerySegmentWalker; import org.apache.druid.query.QuerySegmentWalker;
import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.filter.DimFilter;
@ -41,6 +40,7 @@ import org.joda.time.Duration;
import org.joda.time.Interval; import org.joda.time.Interval;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -146,12 +146,6 @@ public class MaterializedViewQuery<T> implements Query<T>
return query.getContext(); return query.getContext();
} }
@Override
public QueryContext getQueryContext()
{
return query.getQueryContext();
}
@Override @Override
public boolean isDescending() public boolean isDescending()
{ {

View File

@ -121,7 +121,6 @@ public class MaterializedViewQueryTest
.postAggregators(QueryRunnerTestHelper.ADD_ROWS_INDEX_CONSTANT) .postAggregators(QueryRunnerTestHelper.ADD_ROWS_INDEX_CONSTANT)
.build(); .build();
MaterializedViewQuery query = new MaterializedViewQuery(topNQuery, optimizer); MaterializedViewQuery query = new MaterializedViewQuery(topNQuery, optimizer);
Assert.assertEquals(20_000_000, query.getContextAsHumanReadableBytes("maxOnDiskStorage", HumanReadableBytes.ZERO).getBytes()); Assert.assertEquals(20_000_000, query.context().getHumanReadableBytes("maxOnDiskStorage", HumanReadableBytes.ZERO).getBytes());
} }
} }

View File

@ -237,7 +237,7 @@ public class MovingAverageQuery extends BaseQuery<Row>
@JsonIgnore @JsonIgnore
public boolean getContextSortByDimsFirst() public boolean getContextSortByDimsFirst()
{ {
return getContextBoolean(CTX_KEY_SORT_BY_DIMS_FIRST, false); return context().getBoolean(CTX_KEY_SORT_BY_DIMS_FIRST, false);
} }
@Override @Override

View File

@ -30,7 +30,6 @@ import org.apache.druid.java.util.common.granularity.PeriodGranularity;
import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.query.DataSource; import org.apache.druid.query.DataSource;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
@ -52,6 +51,7 @@ import org.joda.time.Interval;
import org.joda.time.Period; import org.joda.time.Period;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
@ -124,7 +124,7 @@ public class MovingAverageQueryRunner implements QueryRunner<Row>
ResponseContext gbqResponseContext = ResponseContext.createEmpty(); ResponseContext gbqResponseContext = ResponseContext.createEmpty();
gbqResponseContext.merge(responseContext); gbqResponseContext.merge(responseContext);
gbqResponseContext.putQueryFailDeadlineMs( gbqResponseContext.putQueryFailDeadlineMs(
System.currentTimeMillis() + QueryContexts.getTimeout(gbq) System.currentTimeMillis() + gbq.context().getTimeout()
); );
Sequence<ResultRow> results = gbq.getRunner(walker).run(QueryPlus.wrap(gbq), gbqResponseContext); Sequence<ResultRow> results = gbq.getRunner(walker).run(QueryPlus.wrap(gbq), gbqResponseContext);
@ -164,7 +164,7 @@ public class MovingAverageQueryRunner implements QueryRunner<Row>
ResponseContext tsqResponseContext = ResponseContext.createEmpty(); ResponseContext tsqResponseContext = ResponseContext.createEmpty();
tsqResponseContext.merge(responseContext); tsqResponseContext.merge(responseContext);
tsqResponseContext.putQueryFailDeadlineMs( tsqResponseContext.putQueryFailDeadlineMs(
System.currentTimeMillis() + QueryContexts.getTimeout(tsq) System.currentTimeMillis() + tsq.context().getTimeout()
); );
Sequence<Result<TimeseriesResultValue>> results = tsq.getRunner(walker).run(QueryPlus.wrap(tsq), tsqResponseContext); Sequence<Result<TimeseriesResultValue>> results = tsq.getRunner(walker).run(QueryPlus.wrap(tsq), tsqResponseContext);

View File

@ -49,6 +49,7 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
@ -171,7 +172,7 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
histogramName, histogramName,
input.getDirectColumn(), input.getDirectColumn(),
k, k,
getMaxStreamLengthFromQueryContext(plannerContext.getQueryContext()) getMaxStreamLengthFromQueryContext(plannerContext.queryContext())
); );
} else { } else {
String virtualColumnName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression( String virtualColumnName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(
@ -182,7 +183,7 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
histogramName, histogramName,
virtualColumnName, virtualColumnName,
k, k,
getMaxStreamLengthFromQueryContext(plannerContext.getQueryContext()) getMaxStreamLengthFromQueryContext(plannerContext.queryContext())
); );
} }
@ -201,7 +202,7 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
static long getMaxStreamLengthFromQueryContext(QueryContext queryContext) static long getMaxStreamLengthFromQueryContext(QueryContext queryContext)
{ {
return queryContext.getAsLong( return queryContext.getLong(
CTX_APPROX_QUANTILE_DS_MAX_STREAM_LENGTH, CTX_APPROX_QUANTILE_DS_MAX_STREAM_LENGTH,
DoublesSketchAggregatorFactory.DEFAULT_MAX_STREAM_LENGTH DoublesSketchAggregatorFactory.DEFAULT_MAX_STREAM_LENGTH
); );

View File

@ -46,6 +46,7 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
public class DoublesSketchObjectSqlAggregator implements SqlAggregator public class DoublesSketchObjectSqlAggregator implements SqlAggregator
@ -113,7 +114,7 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
histogramName, histogramName,
input.getDirectColumn(), input.getDirectColumn(),
k, k,
DoublesSketchApproxQuantileSqlAggregator.getMaxStreamLengthFromQueryContext(plannerContext.getQueryContext()) DoublesSketchApproxQuantileSqlAggregator.getMaxStreamLengthFromQueryContext(plannerContext.queryContext())
); );
} else { } else {
String virtualColumnName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression( String virtualColumnName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(
@ -124,7 +125,7 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
histogramName, histogramName,
virtualColumnName, virtualColumnName,
k, k,
DoublesSketchApproxQuantileSqlAggregator.getMaxStreamLengthFromQueryContext(plannerContext.getQueryContext()) DoublesSketchApproxQuantileSqlAggregator.getMaxStreamLengthFromQueryContext(plannerContext.queryContext())
); );
} }
@ -136,7 +137,6 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
private static class DoublesSketchSqlAggFunction extends SqlAggFunction private static class DoublesSketchSqlAggFunction extends SqlAggFunction
{ {
private static final String SIGNATURE1 = "'" + NAME + "(column)'\n";
private static final String SIGNATURE2 = "'" + NAME + "(column, k)'\n"; private static final String SIGNATURE2 = "'" + NAME + "(column, k)'\n";
DoublesSketchSqlAggFunction() DoublesSketchSqlAggFunction()

View File

@ -27,6 +27,7 @@ import com.google.common.collect.Iterables;
import org.apache.druid.common.config.NullHandling; import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.query.Druids; import org.apache.druid.query.Druids;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
@ -53,7 +54,6 @@ import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFacto
import org.apache.druid.sql.calcite.BaseCalciteQueryTest; import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.filtration.Filtration; import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.planner.DruidOperatorTable; import org.apache.druid.sql.calcite.planner.DruidOperatorTable;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.util.CalciteTests; import org.apache.druid.sql.calcite.util.CalciteTests;
import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker; import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.DataSegment;
@ -324,7 +324,7 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ
new QuantilePostAggregator("a6", "a6:agg", 0.999f), new QuantilePostAggregator("a6", "a6:agg", 0.999f),
new QuantilePostAggregator("a7", "a5:agg", 0.999f) new QuantilePostAggregator("a7", "a5:agg", 0.999f)
) )
.context(ImmutableMap.of(PlannerContext.CTX_SQL_QUERY_ID, "dummy")) .context(ImmutableMap.of(QueryContexts.CTX_SQL_QUERY_ID, "dummy"))
.build() .build()
), ),
ImmutableList.of( ImmutableList.of(

View File

@ -518,7 +518,7 @@ public class ControllerImpl implements Controller
closer.register(netClient::close); closer.register(netClient::close);
final boolean isDurableStorageEnabled = final boolean isDurableStorageEnabled =
MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().getContext()); MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().context());
final QueryDefinition queryDef = makeQueryDefinition( final QueryDefinition queryDef = makeQueryDefinition(
id(), id(),
@ -1191,7 +1191,7 @@ public class ControllerImpl implements Controller
final InputChannelFactory inputChannelFactory; final InputChannelFactory inputChannelFactory;
if (MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().getContext())) { if (MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().context())) {
inputChannelFactory = DurableStorageInputChannelFactory.createStandardImplementation( inputChannelFactory = DurableStorageInputChannelFactory.createStandardImplementation(
id(), id(),
() -> taskIds, () -> taskIds,
@ -1294,7 +1294,7 @@ public class ControllerImpl implements Controller
*/ */
private void cleanUpDurableStorageIfNeeded() private void cleanUpDurableStorageIfNeeded()
{ {
if (MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().getContext())) { if (MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().context())) {
final String controllerDirName = DurableStorageOutputChannelFactory.getControllerDirectory(task.getId()); final String controllerDirName = DurableStorageOutputChannelFactory.getControllerDirectory(task.getId());
try { try {
// Delete all temporary files as a failsafe // Delete all temporary files as a failsafe
@ -1454,7 +1454,7 @@ public class ControllerImpl implements Controller
) )
{ {
if (isRollupQuery) { if (isRollupQuery) {
final String queryGranularity = query.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_GRANULARITY, ""); final String queryGranularity = query.context().getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_GRANULARITY, "");
if (timeIsGroupByDimension((GroupByQuery) query, columnMappings) && !queryGranularity.isEmpty()) { if (timeIsGroupByDimension((GroupByQuery) query, columnMappings) && !queryGranularity.isEmpty()) {
return new ArbitraryGranularitySpec( return new ArbitraryGranularitySpec(
@ -1483,7 +1483,7 @@ public class ControllerImpl implements Controller
{ {
if (columnMappings.hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) { if (columnMappings.hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) {
final String queryTimeColumn = columnMappings.getQueryColumnForOutputColumn(ColumnHolder.TIME_COLUMN_NAME); final String queryTimeColumn = columnMappings.getQueryColumnForOutputColumn(ColumnHolder.TIME_COLUMN_NAME);
return queryTimeColumn.equals(groupByQuery.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD)); return queryTimeColumn.equals(groupByQuery.context().getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD));
} else { } else {
return false; return false;
} }
@ -1505,8 +1505,8 @@ public class ControllerImpl implements Controller
private static boolean isRollupQuery(Query<?> query) private static boolean isRollupQuery(Query<?> query)
{ {
return query instanceof GroupByQuery return query instanceof GroupByQuery
&& !MultiStageQueryContext.isFinalizeAggregations(query.getQueryContext()) && !MultiStageQueryContext.isFinalizeAggregations(query.context())
&& !query.getContextBoolean(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true); && !query.context().getBoolean(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true);
} }
private static boolean isInlineResults(final MSQSpec querySpec) private static boolean isInlineResults(final MSQSpec querySpec)

View File

@ -106,6 +106,7 @@ import org.apache.druid.msq.util.DecoratedExecutorService;
import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.PrioritizedCallable; import org.apache.druid.query.PrioritizedCallable;
import org.apache.druid.query.PrioritizedRunnable; import org.apache.druid.query.PrioritizedRunnable;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.QueryProcessingPool;
import org.apache.druid.server.DruidNode; import org.apache.druid.server.DruidNode;
@ -177,7 +178,9 @@ public class WorkerImpl implements Worker
this.context = context; this.context = context;
this.selfDruidNode = context.selfNode(); this.selfDruidNode = context.selfNode();
this.processorBouncer = context.processorBouncer(); this.processorBouncer = context.processorBouncer();
this.durableStageStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled(task.getContext()); this.durableStageStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled(
QueryContext.of(task.getContext())
);
} }
@Override @Override

View File

@ -106,7 +106,7 @@ public class MSQControllerTask extends AbstractTask
this.sqlQueryContext = sqlQueryContext; this.sqlQueryContext = sqlQueryContext;
this.sqlTypeNames = sqlTypeNames; this.sqlTypeNames = sqlTypeNames;
if (MultiStageQueryContext.isDurableStorageEnabled(querySpec.getQuery().getContext())) { if (MultiStageQueryContext.isDurableStorageEnabled(querySpec.getQuery().context())) {
this.remoteFetchExecutorService = this.remoteFetchExecutorService =
Executors.newCachedThreadPool(Execs.makeThreadFactory(getId() + "-remote-fetcher-%d")); Executors.newCachedThreadPool(Execs.makeThreadFactory(getId() + "-remote-fetcher-%d"));
} else { } else {

View File

@ -191,7 +191,7 @@ public class QueryKitUtils
public static VirtualColumn makeSegmentGranularityVirtualColumn(final Query<?> query) public static VirtualColumn makeSegmentGranularityVirtualColumn(final Query<?> query)
{ {
final Granularity segmentGranularity = QueryKitUtils.getSegmentGranularityFromContext(query.getContext()); final Granularity segmentGranularity = QueryKitUtils.getSegmentGranularityFromContext(query.getContext());
final String timeColumnName = query.getQueryContext().getAsString(QueryKitUtils.CTX_TIME_COLUMN_NAME); final String timeColumnName = query.context().getString(QueryKitUtils.CTX_TIME_COLUMN_NAME);
if (timeColumnName == null || Granularities.ALL.equals(segmentGranularity)) { if (timeColumnName == null || Granularities.ALL.equals(segmentGranularity)) {
return null; return null;

View File

@ -37,7 +37,6 @@ import org.apache.druid.msq.querykit.ShuffleSpecFactories;
import org.apache.druid.msq.querykit.ShuffleSpecFactory; import org.apache.druid.msq.querykit.ShuffleSpecFactory;
import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory; import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.having.AlwaysHavingSpec; import org.apache.druid.query.groupby.having.AlwaysHavingSpec;
@ -205,7 +204,7 @@ public class GroupByQueryKit implements QueryKit<GroupByQuery>
*/ */
static boolean isFinalize(final GroupByQuery query) static boolean isFinalize(final GroupByQuery query)
{ {
return QueryContexts.isFinalize(query, true); return query.context().isFinalize(true);
} }
/** /**

View File

@ -57,7 +57,7 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
{ {
RowSignature scanSignature; RowSignature scanSignature;
try { try {
final String s = scanQuery.getQueryContext().getAsString(DruidQuery.CTX_SCAN_SIGNATURE); final String s = scanQuery.context().getString(DruidQuery.CTX_SCAN_SIGNATURE);
scanSignature = jsonMapper.readValue(s, RowSignature.class); scanSignature = jsonMapper.readValue(s, RowSignature.class);
} }
catch (JsonProcessingException e) { catch (JsonProcessingException e) {
@ -74,7 +74,7 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
* 2. This is an offset which means everything gets funneled into a single partition hence we use MaxCountShuffleSpec * 2. This is an offset which means everything gets funneled into a single partition hence we use MaxCountShuffleSpec
*/ */
// No ordering, but there is a limit or an offset. These work by funneling everything through a single partition. // No ordering, but there is a limit or an offset. These work by funneling everything through a single partition.
// So there is no point in forcing any particular partitioning. Since everything is funnelled into a single // So there is no point in forcing any particular partitioning. Since everything is funneled into a single
// partition without a ClusterBy, we don't need to necessarily create it via the resultShuffleSpecFactory provided // partition without a ClusterBy, we don't need to necessarily create it via the resultShuffleSpecFactory provided
@Override @Override
public QueryDefinition makeQueryDefinition( public QueryDefinition makeQueryDefinition(

View File

@ -23,9 +23,10 @@ import com.google.common.collect.ImmutableMap;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.indexing.error.MSQWarnings; import org.apache.druid.msq.indexing.error.MSQWarnings;
import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -64,7 +65,7 @@ public enum MSQMode
return value; return value;
} }
public static void populateDefaultQueryContext(final String modeStr, final QueryContext originalQueryContext) public static void populateDefaultQueryContext(final String modeStr, final Map<String, Object> originalQueryContext)
{ {
MSQMode mode = MSQMode.fromString(modeStr); MSQMode mode = MSQMode.fromString(modeStr);
if (mode == null) { if (mode == null) {
@ -74,8 +75,7 @@ public enum MSQMode
Arrays.stream(MSQMode.values()).map(m -> m.value).collect(Collectors.toList()) Arrays.stream(MSQMode.values()).map(m -> m.value).collect(Collectors.toList())
); );
} }
Map<String, Object> defaultQueryContext = mode.defaultQueryContext; log.debug("Populating default query context with %s for the %s multi stage query mode", mode.defaultQueryContext, mode);
log.debug("Populating default query context with %s for the %s multi stage query mode", defaultQueryContext, mode); QueryContexts.addDefaults(originalQueryContext, mode.defaultQueryContext);
originalQueryContext.addDefaultParams(defaultQueryContext);
} }
} }

View File

@ -42,6 +42,7 @@ import org.apache.druid.msq.indexing.MSQSpec;
import org.apache.druid.msq.indexing.MSQTuningConfig; import org.apache.druid.msq.indexing.MSQTuningConfig;
import org.apache.druid.msq.indexing.TaskReportMSQDestination; import org.apache.druid.msq.indexing.TaskReportMSQDestination;
import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.rpc.indexing.OverlordClient; import org.apache.druid.rpc.indexing.OverlordClient;
@ -59,6 +60,7 @@ import org.apache.druid.sql.calcite.table.RowSignatures;
import org.joda.time.Interval; import org.joda.time.Interval;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
@ -109,17 +111,18 @@ public class MSQTaskQueryMaker implements QueryMaker
{ {
String taskId = MSQTasks.controllerTaskId(plannerContext.getSqlQueryId()); String taskId = MSQTasks.controllerTaskId(plannerContext.getSqlQueryId());
String msqMode = MultiStageQueryContext.getMSQMode(plannerContext.getQueryContext()); QueryContext queryContext = plannerContext.queryContext();
String msqMode = MultiStageQueryContext.getMSQMode(queryContext);
if (msqMode != null) { if (msqMode != null) {
MSQMode.populateDefaultQueryContext(msqMode, plannerContext.getQueryContext()); MSQMode.populateDefaultQueryContext(msqMode, plannerContext.queryContextMap());
} }
final String ctxDestination = final String ctxDestination =
DimensionHandlerUtils.convertObjectToString(MultiStageQueryContext.getDestination(plannerContext.getQueryContext())); DimensionHandlerUtils.convertObjectToString(MultiStageQueryContext.getDestination(queryContext));
Object segmentGranularity; Object segmentGranularity;
try { try {
segmentGranularity = Optional.ofNullable(plannerContext.getQueryContext() segmentGranularity = Optional.ofNullable(plannerContext.queryContext()
.get(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) .get(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY))
.orElse(jsonMapper.writeValueAsString(DEFAULT_SEGMENT_GRANULARITY)); .orElse(jsonMapper.writeValueAsString(DEFAULT_SEGMENT_GRANULARITY));
} }
@ -128,7 +131,7 @@ public class MSQTaskQueryMaker implements QueryMaker
+ "segment graularity"); + "segment graularity");
} }
final int maxNumTasks = MultiStageQueryContext.getMaxNumTasks(plannerContext.getQueryContext()); final int maxNumTasks = MultiStageQueryContext.getMaxNumTasks(queryContext);
if (maxNumTasks < 2) { if (maxNumTasks < 2) {
throw new IAE(MultiStageQueryContext.CTX_MAX_NUM_TASKS throw new IAE(MultiStageQueryContext.CTX_MAX_NUM_TASKS
@ -139,19 +142,19 @@ public class MSQTaskQueryMaker implements QueryMaker
final int maxNumWorkers = maxNumTasks - 1; final int maxNumWorkers = maxNumTasks - 1;
final int rowsPerSegment = MultiStageQueryContext.getRowsPerSegment( final int rowsPerSegment = MultiStageQueryContext.getRowsPerSegment(
plannerContext.getQueryContext(), queryContext,
DEFAULT_ROWS_PER_SEGMENT DEFAULT_ROWS_PER_SEGMENT
); );
final int maxRowsInMemory = MultiStageQueryContext.getRowsInMemory( final int maxRowsInMemory = MultiStageQueryContext.getRowsInMemory(
plannerContext.getQueryContext(), queryContext,
DEFAULT_ROWS_IN_MEMORY DEFAULT_ROWS_IN_MEMORY
); );
final boolean finalizeAggregations = MultiStageQueryContext.isFinalizeAggregations(plannerContext.getQueryContext()); final boolean finalizeAggregations = MultiStageQueryContext.isFinalizeAggregations(queryContext);
final List<Interval> replaceTimeChunks = final List<Interval> replaceTimeChunks =
Optional.ofNullable(plannerContext.getQueryContext().get(DruidSqlReplace.SQL_REPLACE_TIME_CHUNKS)) Optional.ofNullable(plannerContext.queryContext().get(DruidSqlReplace.SQL_REPLACE_TIME_CHUNKS))
.map( .map(
s -> { s -> {
if (s instanceof String && "all".equals(StringUtils.toLowerCase((String) s))) { if (s instanceof String && "all".equals(StringUtils.toLowerCase((String) s))) {
@ -213,7 +216,7 @@ public class MSQTaskQueryMaker implements QueryMaker
} }
final List<String> segmentSortOrder = MultiStageQueryContext.decodeSortOrder( final List<String> segmentSortOrder = MultiStageQueryContext.decodeSortOrder(
MultiStageQueryContext.getSortOrder(plannerContext.getQueryContext()) MultiStageQueryContext.getSortOrder(queryContext)
); );
validateSegmentSortOrder( validateSegmentSortOrder(
@ -245,7 +248,7 @@ public class MSQTaskQueryMaker implements QueryMaker
.query(druidQuery.getQuery().withOverriddenContext(nativeQueryContextOverrides)) .query(druidQuery.getQuery().withOverriddenContext(nativeQueryContextOverrides))
.columnMappings(new ColumnMappings(columnMappings)) .columnMappings(new ColumnMappings(columnMappings))
.destination(destination) .destination(destination)
.assignmentStrategy(MultiStageQueryContext.getAssignmentStrategy(plannerContext.getQueryContext())) .assignmentStrategy(MultiStageQueryContext.getAssignmentStrategy(queryContext))
.tuningConfig(new MSQTuningConfig(maxNumWorkers, maxRowsInMemory, rowsPerSegment)) .tuningConfig(new MSQTuningConfig(maxNumWorkers, maxRowsInMemory, rowsPerSegment))
.build(); .build();
@ -253,7 +256,7 @@ public class MSQTaskQueryMaker implements QueryMaker
taskId, taskId,
querySpec, querySpec,
plannerContext.getSql(), plannerContext.getSql(),
plannerContext.getQueryContext().getMergedParams(), plannerContext.queryContextMap(),
sqlTypeNames, sqlTypeNames,
null null
); );

View File

@ -38,7 +38,6 @@ import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity; import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.msq.querykit.QueryKitUtils; import org.apache.druid.msq.querykit.QueryKitUtils;
import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.QueryContext;
import org.apache.druid.rpc.indexing.OverlordClient; import org.apache.druid.rpc.indexing.OverlordClient;
import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.sql.calcite.parser.DruidSqlInsert; import org.apache.druid.sql.calcite.parser.DruidSqlInsert;
@ -52,6 +51,7 @@ import org.apache.druid.sql.calcite.run.SqlEngines;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set; import java.util.Set;
public class MSQTaskSqlEngine implements SqlEngine public class MSQTaskSqlEngine implements SqlEngine
@ -86,7 +86,7 @@ public class MSQTaskSqlEngine implements SqlEngine
} }
@Override @Override
public void validateContext(QueryContext queryContext) throws ValidationException public void validateContext(Map<String, Object> queryContext) throws ValidationException
{ {
SqlEngines.validateNoSpecialContextKeys(queryContext, SYSTEM_CONTEXT_PARAMETERS); SqlEngines.validateNoSpecialContextKeys(queryContext, SYSTEM_CONTEXT_PARAMETERS);
} }
@ -166,7 +166,7 @@ public class MSQTaskSqlEngine implements SqlEngine
{ {
validateNoDuplicateAliases(fieldMappings); validateNoDuplicateAliases(fieldMappings);
if (plannerContext.getQueryContext().containsKey(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) { if (plannerContext.queryContext().containsKey(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) {
throw new ValidationException( throw new ValidationException(
StringUtils.format("Cannot use \"%s\" without INSERT", DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY) StringUtils.format("Cannot use \"%s\" without INSERT", DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)
); );
@ -207,14 +207,14 @@ public class MSQTaskSqlEngine implements SqlEngine
try { try {
segmentGranularity = QueryKitUtils.getSegmentGranularityFromContext( segmentGranularity = QueryKitUtils.getSegmentGranularityFromContext(
plannerContext.getQueryContext().getMergedParams() plannerContext.queryContextMap()
); );
} }
catch (Exception e) { catch (Exception e) {
throw new ValidationException( throw new ValidationException(
StringUtils.format( StringUtils.format(
"Invalid segmentGranularity: %s", "Invalid segmentGranularity: %s",
plannerContext.getQueryContext().get(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY) plannerContext.queryContext().get(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)
), ),
e e
); );

View File

@ -26,12 +26,12 @@ import com.google.common.annotations.VisibleForTesting;
import com.opencsv.RFC4180Parser; import com.opencsv.RFC4180Parser;
import com.opencsv.RFC4180ParserBuilder; import com.opencsv.RFC4180ParserBuilder;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Numbers;
import org.apache.druid.msq.kernel.WorkerAssignmentStrategy; import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
import org.apache.druid.msq.sql.MSQMode; import org.apache.druid.msq.sql.MSQMode;
import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContext;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -59,7 +59,7 @@ public class MultiStageQueryContext
private static final boolean DEFAULT_FINALIZE_AGGREGATIONS = true; private static final boolean DEFAULT_FINALIZE_AGGREGATIONS = true;
public static final String CTX_ENABLE_DURABLE_SHUFFLE_STORAGE = "durableShuffleStorage"; public static final String CTX_ENABLE_DURABLE_SHUFFLE_STORAGE = "durableShuffleStorage";
private static final String DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE = "false"; private static final boolean DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE = false;
public static final String CTX_DESTINATION = "destination"; public static final String CTX_DESTINATION = "destination";
private static final String DEFAULT_DESTINATION = null; private static final String DEFAULT_DESTINATION = null;
@ -77,48 +77,34 @@ public class MultiStageQueryContext
private static final Pattern LOOKS_LIKE_JSON_ARRAY = Pattern.compile("^\\s*\\[.*", Pattern.DOTALL); private static final Pattern LOOKS_LIKE_JSON_ARRAY = Pattern.compile("^\\s*\\[.*", Pattern.DOTALL);
public static String getMSQMode(QueryContext queryContext) public static String getMSQMode(final QueryContext queryContext)
{ {
return (String) MultiStageQueryContext.getValueFromPropertyMap( return queryContext.getString(
queryContext.getMergedParams(),
CTX_MSQ_MODE, CTX_MSQ_MODE,
null,
DEFAULT_MSQ_MODE DEFAULT_MSQ_MODE
); );
} }
public static boolean isDurableStorageEnabled(Map<String, Object> propertyMap) public static boolean isDurableStorageEnabled(final QueryContext queryContext)
{ {
return Boolean.parseBoolean( return queryContext.getBoolean(
String.valueOf( CTX_ENABLE_DURABLE_SHUFFLE_STORAGE,
getValueFromPropertyMap( DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE
propertyMap,
CTX_ENABLE_DURABLE_SHUFFLE_STORAGE,
null,
DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE
)
)
); );
} }
public static boolean isFinalizeAggregations(final QueryContext queryContext) public static boolean isFinalizeAggregations(final QueryContext queryContext)
{ {
return Numbers.parseBoolean( return queryContext.getBoolean(
getValueFromPropertyMap( CTX_FINALIZE_AGGREGATIONS,
queryContext.getMergedParams(), DEFAULT_FINALIZE_AGGREGATIONS
CTX_FINALIZE_AGGREGATIONS,
null,
DEFAULT_FINALIZE_AGGREGATIONS
)
); );
} }
public static WorkerAssignmentStrategy getAssignmentStrategy(final QueryContext queryContext) public static WorkerAssignmentStrategy getAssignmentStrategy(final QueryContext queryContext)
{ {
String assignmentStrategyString = (String) getValueFromPropertyMap( String assignmentStrategyString = queryContext.getString(
queryContext.getMergedParams(),
CTX_TASK_ASSIGNMENT_STRATEGY, CTX_TASK_ASSIGNMENT_STRATEGY,
null,
DEFAULT_TASK_ASSIGNMENT_STRATEGY DEFAULT_TASK_ASSIGNMENT_STRATEGY
); );
@ -127,47 +113,33 @@ public class MultiStageQueryContext
public static int getMaxNumTasks(final QueryContext queryContext) public static int getMaxNumTasks(final QueryContext queryContext)
{ {
return Numbers.parseInt( return queryContext.getInt(
getValueFromPropertyMap( CTX_MAX_NUM_TASKS,
queryContext.getMergedParams(), DEFAULT_MAX_NUM_TASKS
CTX_MAX_NUM_TASKS,
null,
DEFAULT_MAX_NUM_TASKS
)
); );
} }
public static Object getDestination(final QueryContext queryContext) public static Object getDestination(final QueryContext queryContext)
{ {
return getValueFromPropertyMap( return queryContext.get(
queryContext.getMergedParams(),
CTX_DESTINATION, CTX_DESTINATION,
null,
DEFAULT_DESTINATION DEFAULT_DESTINATION
); );
} }
public static int getRowsPerSegment(final QueryContext queryContext, int defaultRowsPerSegment) public static int getRowsPerSegment(final QueryContext queryContext, int defaultRowsPerSegment)
{ {
return Numbers.parseInt( return queryContext.getInt(
getValueFromPropertyMap( CTX_ROWS_PER_SEGMENT,
queryContext.getMergedParams(), defaultRowsPerSegment
CTX_ROWS_PER_SEGMENT,
null,
defaultRowsPerSegment
)
); );
} }
public static int getRowsInMemory(final QueryContext queryContext, int defaultRowsInMemory) public static int getRowsInMemory(final QueryContext queryContext, int defaultRowsInMemory)
{ {
return Numbers.parseInt( return queryContext.getInt(
getValueFromPropertyMap( CTX_ROWS_IN_MEMORY,
queryContext.getMergedParams(), defaultRowsInMemory
CTX_ROWS_IN_MEMORY,
null,
defaultRowsInMemory
)
); );
} }
@ -196,10 +168,8 @@ public class MultiStageQueryContext
public static String getSortOrder(final QueryContext queryContext) public static String getSortOrder(final QueryContext queryContext)
{ {
return (String) getValueFromPropertyMap( return queryContext.getString(
queryContext.getMergedParams(),
CTX_SORT_ORDER, CTX_SORT_ORDER,
null,
DEFAULT_SORT_ORDER DEFAULT_SORT_ORDER
); );
} }

View File

@ -22,34 +22,36 @@ package org.apache.druid.msq.sql;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.indexing.error.MSQWarnings; import org.apache.druid.msq.indexing.error.MSQWarnings;
import org.apache.druid.query.QueryContext;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.util.HashMap;
import java.util.Map;
public class MSQModeTest public class MSQModeTest
{ {
@Test @Test
public void testPopulateQueryContextWhenNoSupercedingValuePresent() public void testPopulateQueryContextWhenNoSupercedingValuePresent()
{ {
QueryContext originalQueryContext = new QueryContext(); Map<String, Object> originalQueryContext = new HashMap<>();
MSQMode.populateDefaultQueryContext("strict", originalQueryContext); MSQMode.populateDefaultQueryContext("strict", originalQueryContext);
Assert.assertEquals(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 0), originalQueryContext.getMergedParams()); Assert.assertEquals(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 0), originalQueryContext);
} }
@Test @Test
public void testPopulateQueryContextWhenSupercedingValuePresent() public void testPopulateQueryContextWhenSupercedingValuePresent()
{ {
QueryContext originalQueryContext = new QueryContext(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 10)); Map<String, Object> originalQueryContext = new HashMap<>();
originalQueryContext.put(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 10);
MSQMode.populateDefaultQueryContext("strict", originalQueryContext); MSQMode.populateDefaultQueryContext("strict", originalQueryContext);
Assert.assertEquals(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 10), originalQueryContext.getMergedParams()); Assert.assertEquals(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 10), originalQueryContext);
} }
@Test @Test
public void testPopulateQueryContextWhenInvalidMode() public void testPopulateQueryContextWhenInvalidMode()
{ {
QueryContext originalQueryContext = new QueryContext(); Map<String, Object> originalQueryContext = new HashMap<>();
Assert.assertThrows(ISE.class, () -> { Assert.assertThrows(ISE.class, () -> {
MSQMode.populateDefaultQueryContext("fake_mode", originalQueryContext); MSQMode.populateDefaultQueryContext("fake_mode", originalQueryContext);
}); });

View File

@ -89,7 +89,6 @@ import org.apache.druid.msq.sql.MSQTaskSqlEngine;
import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.ForwardingQueryProcessingPool; import org.apache.druid.query.ForwardingQueryProcessingPool;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.QueryProcessingPool;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
@ -132,7 +131,6 @@ import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.external.ExternalDataSource; import org.apache.druid.sql.calcite.external.ExternalDataSource;
import org.apache.druid.sql.calcite.planner.CalciteRulesManager; import org.apache.druid.sql.calcite.planner.CalciteRulesManager;
import org.apache.druid.sql.calcite.planner.PlannerConfig; import org.apache.druid.sql.calcite.planner.PlannerConfig;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.planner.PlannerFactory; import org.apache.druid.sql.calcite.planner.PlannerFactory;
import org.apache.druid.sql.calcite.rel.DruidQuery; import org.apache.druid.sql.calcite.rel.DruidQuery;
import org.apache.druid.sql.calcite.run.SqlEngine; import org.apache.druid.sql.calcite.run.SqlEngine;
@ -162,6 +160,7 @@ import org.mockito.Mockito;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.io.Closeable; import java.io.Closeable;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
@ -207,7 +206,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
public static final Map<String, Object> DEFAULT_MSQ_CONTEXT = public static final Map<String, Object> DEFAULT_MSQ_CONTEXT =
ImmutableMap.<String, Object>builder() ImmutableMap.<String, Object>builder()
.put(MultiStageQueryContext.CTX_ENABLE_DURABLE_SHUFFLE_STORAGE, true) .put(MultiStageQueryContext.CTX_ENABLE_DURABLE_SHUFFLE_STORAGE, true)
.put(PlannerContext.CTX_SQL_QUERY_ID, "test-query") .put(QueryContexts.CTX_SQL_QUERY_ID, "test-query")
.put(QueryContexts.FINALIZE_KEY, true) .put(QueryContexts.FINALIZE_KEY, true)
.build(); .build();
@ -587,7 +586,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
final DirectStatement stmt = sqlStatementFactory.directStatement( final DirectStatement stmt = sqlStatementFactory.directStatement(
new SqlQueryPlus( new SqlQueryPlus(
query, query,
new QueryContext(context), context,
Collections.emptyList(), Collections.emptyList(),
CalciteTests.REGULAR_USER_AUTH_RESULT CalciteTests.REGULAR_USER_AUTH_RESULT
) )

View File

@ -27,6 +27,7 @@ import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -46,33 +47,33 @@ public class MultiStageQueryContextTest
@Test @Test
public void isDurableStorageEnabled_noParameterSetReturnsDefaultValue() public void isDurableStorageEnabled_noParameterSetReturnsDefaultValue()
{ {
Assert.assertFalse(MultiStageQueryContext.isDurableStorageEnabled(ImmutableMap.of())); Assert.assertFalse(MultiStageQueryContext.isDurableStorageEnabled(QueryContext.empty()));
} }
@Test @Test
public void isDurableStorageEnabled_parameterSetReturnsCorrectValue() public void isDurableStorageEnabled_parameterSetReturnsCorrectValue()
{ {
Map<String, Object> propertyMap = ImmutableMap.of(CTX_ENABLE_DURABLE_SHUFFLE_STORAGE, "true"); Map<String, Object> propertyMap = ImmutableMap.of(CTX_ENABLE_DURABLE_SHUFFLE_STORAGE, "true");
Assert.assertTrue(MultiStageQueryContext.isDurableStorageEnabled(propertyMap)); Assert.assertTrue(MultiStageQueryContext.isDurableStorageEnabled(QueryContext.of(propertyMap)));
} }
@Test @Test
public void isFinalizeAggregations_noParameterSetReturnsDefaultValue() public void isFinalizeAggregations_noParameterSetReturnsDefaultValue()
{ {
Assert.assertTrue(MultiStageQueryContext.isFinalizeAggregations(new QueryContext())); Assert.assertTrue(MultiStageQueryContext.isFinalizeAggregations(QueryContext.empty()));
} }
@Test @Test
public void isFinalizeAggregations_parameterSetReturnsCorrectValue() public void isFinalizeAggregations_parameterSetReturnsCorrectValue()
{ {
Map<String, Object> propertyMap = ImmutableMap.of(CTX_FINALIZE_AGGREGATIONS, "false"); Map<String, Object> propertyMap = ImmutableMap.of(CTX_FINALIZE_AGGREGATIONS, "false");
Assert.assertFalse(MultiStageQueryContext.isFinalizeAggregations(new QueryContext(propertyMap))); Assert.assertFalse(MultiStageQueryContext.isFinalizeAggregations(QueryContext.of(propertyMap)));
} }
@Test @Test
public void getAssignmentStrategy_noParameterSetReturnsDefaultValue() public void getAssignmentStrategy_noParameterSetReturnsDefaultValue()
{ {
Assert.assertEquals(WorkerAssignmentStrategy.MAX, MultiStageQueryContext.getAssignmentStrategy(new QueryContext())); Assert.assertEquals(WorkerAssignmentStrategy.MAX, MultiStageQueryContext.getAssignmentStrategy(QueryContext.empty()));
} }
@Test @Test
@ -81,67 +82,67 @@ public class MultiStageQueryContextTest
Map<String, Object> propertyMap = ImmutableMap.of(CTX_TASK_ASSIGNMENT_STRATEGY, "AUTO"); Map<String, Object> propertyMap = ImmutableMap.of(CTX_TASK_ASSIGNMENT_STRATEGY, "AUTO");
Assert.assertEquals( Assert.assertEquals(
WorkerAssignmentStrategy.AUTO, WorkerAssignmentStrategy.AUTO,
MultiStageQueryContext.getAssignmentStrategy(new QueryContext(propertyMap)) MultiStageQueryContext.getAssignmentStrategy(QueryContext.of(propertyMap))
); );
} }
@Test @Test
public void getMaxNumTasks_noParameterSetReturnsDefaultValue() public void getMaxNumTasks_noParameterSetReturnsDefaultValue()
{ {
Assert.assertEquals(DEFAULT_MAX_NUM_TASKS, MultiStageQueryContext.getMaxNumTasks(new QueryContext())); Assert.assertEquals(DEFAULT_MAX_NUM_TASKS, MultiStageQueryContext.getMaxNumTasks(QueryContext.empty()));
} }
@Test @Test
public void getMaxNumTasks_parameterSetReturnsCorrectValue() public void getMaxNumTasks_parameterSetReturnsCorrectValue()
{ {
Map<String, Object> propertyMap = ImmutableMap.of(CTX_MAX_NUM_TASKS, 101); Map<String, Object> propertyMap = ImmutableMap.of(CTX_MAX_NUM_TASKS, 101);
Assert.assertEquals(101, MultiStageQueryContext.getMaxNumTasks(new QueryContext(propertyMap))); Assert.assertEquals(101, MultiStageQueryContext.getMaxNumTasks(QueryContext.of(propertyMap)));
} }
@Test @Test
public void getMaxNumTasks_legacyParameterSetReturnsCorrectValue() public void getMaxNumTasks_legacyParameterSetReturnsCorrectValue()
{ {
Map<String, Object> propertyMap = ImmutableMap.of(CTX_MAX_NUM_TASKS, 101); Map<String, Object> propertyMap = ImmutableMap.of(CTX_MAX_NUM_TASKS, 101);
Assert.assertEquals(101, MultiStageQueryContext.getMaxNumTasks(new QueryContext(propertyMap))); Assert.assertEquals(101, MultiStageQueryContext.getMaxNumTasks(QueryContext.of(propertyMap)));
} }
@Test @Test
public void getDestination_noParameterSetReturnsDefaultValue() public void getDestination_noParameterSetReturnsDefaultValue()
{ {
Assert.assertNull(MultiStageQueryContext.getDestination(new QueryContext())); Assert.assertNull(MultiStageQueryContext.getDestination(QueryContext.empty()));
} }
@Test @Test
public void getDestination_parameterSetReturnsCorrectValue() public void getDestination_parameterSetReturnsCorrectValue()
{ {
Map<String, Object> propertyMap = ImmutableMap.of(CTX_DESTINATION, "dataSource"); Map<String, Object> propertyMap = ImmutableMap.of(CTX_DESTINATION, "dataSource");
Assert.assertEquals("dataSource", MultiStageQueryContext.getDestination(new QueryContext(propertyMap))); Assert.assertEquals("dataSource", MultiStageQueryContext.getDestination(QueryContext.of(propertyMap)));
} }
@Test @Test
public void getRowsPerSegment_noParameterSetReturnsDefaultValue() public void getRowsPerSegment_noParameterSetReturnsDefaultValue()
{ {
Assert.assertEquals(1000, MultiStageQueryContext.getRowsPerSegment(new QueryContext(), 1000)); Assert.assertEquals(1000, MultiStageQueryContext.getRowsPerSegment(QueryContext.empty(), 1000));
} }
@Test @Test
public void getRowsPerSegment_parameterSetReturnsCorrectValue() public void getRowsPerSegment_parameterSetReturnsCorrectValue()
{ {
Map<String, Object> propertyMap = ImmutableMap.of(CTX_ROWS_PER_SEGMENT, 10); Map<String, Object> propertyMap = ImmutableMap.of(CTX_ROWS_PER_SEGMENT, 10);
Assert.assertEquals(10, MultiStageQueryContext.getRowsPerSegment(new QueryContext(propertyMap), 1000)); Assert.assertEquals(10, MultiStageQueryContext.getRowsPerSegment(QueryContext.of(propertyMap), 1000));
} }
@Test @Test
public void getRowsInMemory_noParameterSetReturnsDefaultValue() public void getRowsInMemory_noParameterSetReturnsDefaultValue()
{ {
Assert.assertEquals(1000, MultiStageQueryContext.getRowsInMemory(new QueryContext(), 1000)); Assert.assertEquals(1000, MultiStageQueryContext.getRowsInMemory(QueryContext.empty(), 1000));
} }
@Test @Test
public void getRowsInMemory_parameterSetReturnsCorrectValue() public void getRowsInMemory_parameterSetReturnsCorrectValue()
{ {
Map<String, Object> propertyMap = ImmutableMap.of(CTX_ROWS_IN_MEMORY, 10); Map<String, Object> propertyMap = ImmutableMap.of(CTX_ROWS_IN_MEMORY, 10);
Assert.assertEquals(10, MultiStageQueryContext.getRowsInMemory(new QueryContext(propertyMap), 1000)); Assert.assertEquals(10, MultiStageQueryContext.getRowsInMemory(QueryContext.of(propertyMap), 1000));
} }
@Test @Test
@ -161,27 +162,27 @@ public class MultiStageQueryContextTest
@Test @Test
public void getSortOrderNoParameterSetReturnsDefaultValue() public void getSortOrderNoParameterSetReturnsDefaultValue()
{ {
Assert.assertNull(MultiStageQueryContext.getSortOrder(new QueryContext())); Assert.assertNull(MultiStageQueryContext.getSortOrder(QueryContext.empty()));
} }
@Test @Test
public void getSortOrderParameterSetReturnsCorrectValue() public void getSortOrderParameterSetReturnsCorrectValue()
{ {
Map<String, Object> propertyMap = ImmutableMap.of(CTX_SORT_ORDER, "a, b,\"c,d\""); Map<String, Object> propertyMap = ImmutableMap.of(CTX_SORT_ORDER, "a, b,\"c,d\"");
Assert.assertEquals("a, b,\"c,d\"", MultiStageQueryContext.getSortOrder(new QueryContext(propertyMap))); Assert.assertEquals("a, b,\"c,d\"", MultiStageQueryContext.getSortOrder(QueryContext.of(propertyMap)));
} }
@Test @Test
public void getMSQModeNoParameterSetReturnsDefaultValue() public void getMSQModeNoParameterSetReturnsDefaultValue()
{ {
Assert.assertEquals("strict", MultiStageQueryContext.getMSQMode(new QueryContext())); Assert.assertEquals("strict", MultiStageQueryContext.getMSQMode(QueryContext.empty()));
} }
@Test @Test
public void getMSQModeParameterSetReturnsCorrectValue() public void getMSQModeParameterSetReturnsCorrectValue()
{ {
Map<String, Object> propertyMap = ImmutableMap.of(CTX_MSQ_MODE, "nonStrict"); Map<String, Object> propertyMap = ImmutableMap.of(CTX_MSQ_MODE, "nonStrict");
Assert.assertEquals("nonStrict", MultiStageQueryContext.getMSQMode(new QueryContext(propertyMap))); Assert.assertEquals("nonStrict", MultiStageQueryContext.getMSQMode(QueryContext.of(propertyMap)));
} }
private static List<String> decodeSortOrder(@Nullable final String input) private static List<String> decodeSortOrder(@Nullable final String input)

View File

@ -34,6 +34,7 @@ import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryCapacityExceededException; import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.QueryProcessingPool;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryRunnerFactory; import org.apache.druid.query.QueryRunnerFactory;
@ -127,7 +128,8 @@ public class ServerManagerForQueryErrorTest extends ServerManager
Optional<byte[]> cacheKeyPrefix Optional<byte[]> cacheKeyPrefix
) )
{ {
if (query.getContextBoolean(QUERY_RETRY_TEST_CONTEXT_KEY, false)) { final QueryContext queryContext = query.context();
if (queryContext.getBoolean(QUERY_RETRY_TEST_CONTEXT_KEY, false)) {
final MutableBoolean isIgnoreSegment = new MutableBoolean(false); final MutableBoolean isIgnoreSegment = new MutableBoolean(false);
queryToIgnoredSegments.compute( queryToIgnoredSegments.compute(
query.getMostSpecificId(), query.getMostSpecificId(),
@ -147,7 +149,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
LOG.info("Pretending I don't have segment [%s]", descriptor); LOG.info("Pretending I don't have segment [%s]", descriptor);
return new ReportTimelineMissingSegmentQueryRunner<>(descriptor); return new ReportTimelineMissingSegmentQueryRunner<>(descriptor);
} }
} else if (query.getContextBoolean(QUERY_TIMEOUT_TEST_CONTEXT_KEY, false)) { } else if (queryContext.getBoolean(QUERY_TIMEOUT_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>() return (queryPlus, responseContext) -> new Sequence<T>()
{ {
@Override @Override
@ -162,7 +164,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw new QueryTimeoutException("query timeout test"); throw new QueryTimeoutException("query timeout test");
} }
}; };
} else if (query.getContextBoolean(QUERY_CAPACITY_EXCEEDED_TEST_CONTEXT_KEY, false)) { } else if (queryContext.getBoolean(QUERY_CAPACITY_EXCEEDED_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>() return (queryPlus, responseContext) -> new Sequence<T>()
{ {
@Override @Override
@ -177,7 +179,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw QueryCapacityExceededException.withErrorMessageAndResolvedHost("query capacity exceeded test"); throw QueryCapacityExceededException.withErrorMessageAndResolvedHost("query capacity exceeded test");
} }
}; };
} else if (query.getContextBoolean(QUERY_UNSUPPORTED_TEST_CONTEXT_KEY, false)) { } else if (queryContext.getBoolean(QUERY_UNSUPPORTED_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>() return (queryPlus, responseContext) -> new Sequence<T>()
{ {
@Override @Override
@ -192,7 +194,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw new QueryUnsupportedException("query unsupported test"); throw new QueryUnsupportedException("query unsupported test");
} }
}; };
} else if (query.getContextBoolean(RESOURCE_LIMIT_EXCEEDED_TEST_CONTEXT_KEY, false)) { } else if (queryContext.getBoolean(RESOURCE_LIMIT_EXCEEDED_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>() return (queryPlus, responseContext) -> new Sequence<T>()
{ {
@Override @Override
@ -207,7 +209,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw new ResourceLimitExceededException("resource limit exceeded test"); throw new ResourceLimitExceededException("resource limit exceeded test");
} }
}; };
} else if (query.getContextBoolean(QUERY_FAILURE_TEST_CONTEXT_KEY, false)) { } else if (queryContext.getBoolean(QUERY_FAILURE_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>() return (queryPlus, responseContext) -> new Sequence<T>()
{ {
@Override @Override

View File

@ -34,6 +34,7 @@ import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryCapacityExceededException; import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.QueryProcessingPool;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryRunnerFactory; import org.apache.druid.query.QueryRunnerFactory;
@ -125,7 +126,8 @@ public class ServerManagerForQueryErrorTest extends ServerManager
Optional<byte[]> cacheKeyPrefix Optional<byte[]> cacheKeyPrefix
) )
{ {
if (query.getContextBoolean(QUERY_RETRY_TEST_CONTEXT_KEY, false)) { final QueryContext queryContext = query.context();
if (queryContext.getBoolean(QUERY_RETRY_TEST_CONTEXT_KEY, false)) {
final MutableBoolean isIgnoreSegment = new MutableBoolean(false); final MutableBoolean isIgnoreSegment = new MutableBoolean(false);
queryToIgnoredSegments.compute( queryToIgnoredSegments.compute(
query.getMostSpecificId(), query.getMostSpecificId(),
@ -145,7 +147,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
LOG.info("Pretending I don't have segment[%s]", descriptor); LOG.info("Pretending I don't have segment[%s]", descriptor);
return new ReportTimelineMissingSegmentQueryRunner<>(descriptor); return new ReportTimelineMissingSegmentQueryRunner<>(descriptor);
} }
} else if (query.getContextBoolean(QUERY_TIMEOUT_TEST_CONTEXT_KEY, false)) { } else if (queryContext.getBoolean(QUERY_TIMEOUT_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>() return (queryPlus, responseContext) -> new Sequence<T>()
{ {
@Override @Override
@ -160,7 +162,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw new QueryTimeoutException("query timeout test"); throw new QueryTimeoutException("query timeout test");
} }
}; };
} else if (query.getContextBoolean(QUERY_CAPACITY_EXCEEDED_TEST_CONTEXT_KEY, false)) { } else if (queryContext.getBoolean(QUERY_CAPACITY_EXCEEDED_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>() return (queryPlus, responseContext) -> new Sequence<T>()
{ {
@Override @Override
@ -175,7 +177,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw QueryCapacityExceededException.withErrorMessageAndResolvedHost("query capacity exceeded test"); throw QueryCapacityExceededException.withErrorMessageAndResolvedHost("query capacity exceeded test");
} }
}; };
} else if (query.getContextBoolean(QUERY_UNSUPPORTED_TEST_CONTEXT_KEY, false)) { } else if (queryContext.getBoolean(QUERY_UNSUPPORTED_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>() return (queryPlus, responseContext) -> new Sequence<T>()
{ {
@Override @Override
@ -190,7 +192,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw new QueryUnsupportedException("query unsupported test"); throw new QueryUnsupportedException("query unsupported test");
} }
}; };
} else if (query.getContextBoolean(RESOURCE_LIMIT_EXCEEDED_TEST_CONTEXT_KEY, false)) { } else if (queryContext.getBoolean(RESOURCE_LIMIT_EXCEEDED_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>() return (queryPlus, responseContext) -> new Sequence<T>()
{ {
@Override @Override
@ -205,7 +207,7 @@ public class ServerManagerForQueryErrorTest extends ServerManager
throw new ResourceLimitExceededException("resource limit exceeded test"); throw new ResourceLimitExceededException("resource limit exceeded test");
} }
}; };
} else if (query.getContextBoolean(QUERY_FAILURE_TEST_CONTEXT_KEY, false)) { } else if (queryContext.getBoolean(QUERY_FAILURE_TEST_CONTEXT_KEY, false)) {
return (queryPlus, responseContext) -> new Sequence<T>() return (queryPlus, responseContext) -> new Sequence<T>()
{ {
@Override @Override

View File

@ -547,7 +547,7 @@ public abstract class AbstractAuthConfigurationTest
public void test_sqlQueryWithContext_datasourceOnlyUser_fail() throws Exception public void test_sqlQueryWithContext_datasourceOnlyUser_fail() throws Exception
{ {
final String query = "select count(*) from auth_test"; final String query = "select count(*) from auth_test";
StatusResponseHolder responseHolder = makeSQLQueryRequest( makeSQLQueryRequest(
getHttpClient(User.DATASOURCE_ONLY_USER), getHttpClient(User.DATASOURCE_ONLY_USER),
query, query,
ImmutableMap.of("auth_test_ctx", "should-be-denied"), ImmutableMap.of("auth_test_ctx", "should-be-denied"),
@ -559,7 +559,7 @@ public abstract class AbstractAuthConfigurationTest
public void test_sqlQueryWithContext_datasourceAndContextParamsUser_succeed() throws Exception public void test_sqlQueryWithContext_datasourceAndContextParamsUser_succeed() throws Exception
{ {
final String query = "select count(*) from auth_test"; final String query = "select count(*) from auth_test";
StatusResponseHolder responseHolder = makeSQLQueryRequest( makeSQLQueryRequest(
getHttpClient(User.DATASOURCE_AND_CONTEXT_PARAMS_USER), getHttpClient(User.DATASOURCE_AND_CONTEXT_PARAMS_USER),
query, query,
ImmutableMap.of("auth_test_ctx", "should-be-allowed"), ImmutableMap.of("auth_test_ctx", "should-be-allowed"),
@ -844,11 +844,6 @@ public abstract class AbstractAuthConfigurationTest
protected void verifyInvalidAuthNameFails(String endpoint) protected void verifyInvalidAuthNameFails(String endpoint)
{ {
HttpClient adminClient = new CredentialedHttpClient(
new BasicCredentials("admin", "priest"),
httpClient
);
HttpUtil.makeRequestWithExpectedStatus( HttpUtil.makeRequestWithExpectedStatus(
getHttpClient(User.ADMIN), getHttpClient(User.ADMIN),
HttpMethod.POST, HttpMethod.POST,

View File

@ -32,6 +32,11 @@ public class BadQueryContextException extends BadQueryException
this(ERROR_CODE, e.getMessage(), ERROR_CLASS); this(ERROR_CODE, e.getMessage(), ERROR_CLASS);
} }
public BadQueryContextException(String msg)
{
this(ERROR_CODE, msg, ERROR_CLASS);
}
@JsonCreator @JsonCreator
private BadQueryContextException( private BadQueryContextException(
@JsonProperty("error") String errorCode, @JsonProperty("error") String errorCode,

View File

@ -27,7 +27,6 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Ordering; import com.google.common.collect.Ordering;
import org.apache.druid.guice.annotations.ExtensionPoint; import org.apache.druid.guice.annotations.ExtensionPoint;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity; import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.java.util.common.granularity.PeriodGranularity; import org.apache.druid.java.util.common.granularity.PeriodGranularity;
@ -38,7 +37,6 @@ import org.joda.time.Duration;
import org.joda.time.Interval; import org.joda.time.Interval;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -89,7 +87,7 @@ public abstract class BaseQuery<T> implements Query<T>
Preconditions.checkNotNull(granularity, "Must specify a granularity"); Preconditions.checkNotNull(granularity, "Must specify a granularity");
this.dataSource = dataSource; this.dataSource = dataSource;
this.context = new QueryContext(context); this.context = QueryContext.of(context);
this.querySegmentSpec = querySegmentSpec; this.querySegmentSpec = querySegmentSpec;
this.descending = descending; this.descending = descending;
this.granularity = granularity; this.granularity = granularity;
@ -172,27 +170,15 @@ public abstract class BaseQuery<T> implements Query<T>
@JsonInclude(Include.NON_DEFAULT) @JsonInclude(Include.NON_DEFAULT)
public Map<String, Object> getContext() public Map<String, Object> getContext()
{ {
return context.getMergedParams(); return context.asMap();
} }
@Override @Override
public QueryContext getQueryContext() public QueryContext context()
{ {
return context; return context;
} }
@Override
public boolean getContextBoolean(String key, boolean defaultValue)
{
return context.getAsBoolean(key, defaultValue);
}
@Override
public HumanReadableBytes getContextAsHumanReadableBytes(String key, HumanReadableBytes defaultValue)
{
return context.getAsHumanReadableBytes(key, defaultValue);
}
/** /**
* @deprecated use {@link #computeOverriddenContext(Map, Map) computeOverriddenContext(getContext(), overrides))} * @deprecated use {@link #computeOverriddenContext(Map, Map) computeOverriddenContext(getContext(), overrides))}
* instead. This method may be removed in the next minor or major version of Druid. * instead. This method may be removed in the next minor or major version of Druid.
@ -228,7 +214,7 @@ public abstract class BaseQuery<T> implements Query<T>
@Override @Override
public String getId() public String getId()
{ {
return context.getAsString(QUERY_ID); return context().getString(QUERY_ID);
} }
@Override @Override
@ -241,7 +227,7 @@ public abstract class BaseQuery<T> implements Query<T>
@Override @Override
public String getSubQueryId() public String getSubQueryId()
{ {
return context.getAsString(SUB_QUERY_ID); return context().getString(SUB_QUERY_ID);
} }
@Override @Override

View File

@ -35,7 +35,7 @@ import java.util.List;
* *
* Note that despite the type parameter "T", this runner may not actually return sequences with type T. They * Note that despite the type parameter "T", this runner may not actually return sequences with type T. They
* may really be of type {@code Result<BySegmentResultValue<T>>}, if "bySegment" is set. Downstream consumers * may really be of type {@code Result<BySegmentResultValue<T>>}, if "bySegment" is set. Downstream consumers
* of the returned sequence must be aware of this, and can use {@link QueryContexts#isBySegment(Query)} to * of the returned sequence must be aware of this, and can use {@link QueryContext#isBySegment()} to
* know what to expect. * know what to expect.
*/ */
public class BySegmentQueryRunner<T> implements QueryRunner<T> public class BySegmentQueryRunner<T> implements QueryRunner<T>
@ -55,7 +55,7 @@ public class BySegmentQueryRunner<T> implements QueryRunner<T>
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public Sequence<T> run(final QueryPlus<T> queryPlus, ResponseContext responseContext) public Sequence<T> run(final QueryPlus<T> queryPlus, ResponseContext responseContext)
{ {
if (QueryContexts.isBySegment(queryPlus.getQuery())) { if (queryPlus.getQuery().context().isBySegment()) {
final Sequence<T> baseSequence = base.run(queryPlus, responseContext); final Sequence<T> baseSequence = base.run(queryPlus, responseContext);
final List<T> results = baseSequence.toList(); final List<T> results = baseSequence.toList();
return Sequences.simple( return Sequences.simple(

View File

@ -39,7 +39,7 @@ public abstract class BySegmentSkippingQueryRunner<T> implements QueryRunner<T>
@Override @Override
public Sequence<T> run(QueryPlus<T> queryPlus, ResponseContext responseContext) public Sequence<T> run(QueryPlus<T> queryPlus, ResponseContext responseContext)
{ {
if (QueryContexts.isBySegment(queryPlus.getQuery())) { if (queryPlus.getQuery().context().isBySegment()) {
return baseRunner.run(queryPlus, responseContext); return baseRunner.run(queryPlus, responseContext);
} }

View File

@ -78,7 +78,7 @@ public class ChainedExecutionQueryRunner<T> implements QueryRunner<T>
public Sequence<T> run(final QueryPlus<T> queryPlus, final ResponseContext responseContext) public Sequence<T> run(final QueryPlus<T> queryPlus, final ResponseContext responseContext)
{ {
Query<T> query = queryPlus.getQuery(); Query<T> query = queryPlus.getQuery();
final int priority = QueryContexts.getPriority(query); final int priority = query.context().getPriority();
final Ordering ordering = query.getResultOrdering(); final Ordering ordering = query.getResultOrdering();
final QueryPlus<T> threadSafeQueryPlus = queryPlus.withoutThreadUnsafeState(); final QueryPlus<T> threadSafeQueryPlus = queryPlus.withoutThreadUnsafeState();
return new BaseSequence<T, Iterator<T>>( return new BaseSequence<T, Iterator<T>>(
@ -137,9 +137,10 @@ public class ChainedExecutionQueryRunner<T> implements QueryRunner<T>
queryWatcher.registerQueryFuture(query, future); queryWatcher.registerQueryFuture(query, future);
try { try {
final QueryContext context = query.context();
return new MergeIterable<>( return new MergeIterable<>(
QueryContexts.hasTimeout(query) ? context.hasTimeout() ?
future.get(QueryContexts.getTimeout(query), TimeUnit.MILLISECONDS) : future.get(context.getTimeout(), TimeUnit.MILLISECONDS) :
future.get(), future.get(),
ordering.nullsFirst() ordering.nullsFirst()
).iterator(); ).iterator();

View File

@ -56,8 +56,9 @@ public class FinalizeResultsQueryRunner<T> implements QueryRunner<T>
public Sequence<T> run(final QueryPlus<T> queryPlus, ResponseContext responseContext) public Sequence<T> run(final QueryPlus<T> queryPlus, ResponseContext responseContext)
{ {
final Query<T> query = queryPlus.getQuery(); final Query<T> query = queryPlus.getQuery();
final boolean isBySegment = QueryContexts.isBySegment(query); final QueryContext queryContext = query.context();
final boolean shouldFinalize = QueryContexts.isFinalize(query, true); final boolean isBySegment = queryContext.isBySegment();
final boolean shouldFinalize = queryContext.isFinalize(true);
final Query<T> queryToRun; final Query<T> queryToRun;
final Function<T, ?> finalizerFn; final Function<T, ?> finalizerFn;

View File

@ -84,8 +84,9 @@ public class GroupByMergedQueryRunner<T> implements QueryRunner<T>
querySpecificConfig querySpecificConfig
); );
final Pair<Queue, Accumulator<Queue, T>> bySegmentAccumulatorPair = GroupByQueryHelper.createBySegmentAccumulatorPair(); final Pair<Queue, Accumulator<Queue, T>> bySegmentAccumulatorPair = GroupByQueryHelper.createBySegmentAccumulatorPair();
final boolean bySegment = QueryContexts.isBySegment(query); final QueryContext queryContext = query.context();
final int priority = QueryContexts.getPriority(query); final boolean bySegment = queryContext.isBySegment();
final int priority = queryContext.getPriority();
final QueryPlus<T> threadSafeQueryPlus = queryPlus.withoutThreadUnsafeState(); final QueryPlus<T> threadSafeQueryPlus = queryPlus.withoutThreadUnsafeState();
final List<ListenableFuture<Void>> futures = final List<ListenableFuture<Void>> futures =
Lists.newArrayList( Lists.newArrayList(
@ -173,8 +174,9 @@ public class GroupByMergedQueryRunner<T> implements QueryRunner<T>
ListenableFuture<List<Void>> future = Futures.allAsList(futures); ListenableFuture<List<Void>> future = Futures.allAsList(futures);
try { try {
queryWatcher.registerQueryFuture(query, future); queryWatcher.registerQueryFuture(query, future);
if (QueryContexts.hasTimeout(query)) { final QueryContext context = query.context();
future.get(QueryContexts.getTimeout(query), TimeUnit.MILLISECONDS); if (context.hasTimeout()) {
future.get(context.getTimeout(), TimeUnit.MILLISECONDS);
} else { } else {
future.get(); future.get();
} }

View File

@ -20,6 +20,7 @@
package org.apache.druid.query; package org.apache.druid.query;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import org.apache.druid.guice.annotations.PublicApi; import org.apache.druid.guice.annotations.PublicApi;
@ -36,6 +37,7 @@ import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnHolder;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
@ -293,4 +295,24 @@ public class Queries
return requiredColumns; return requiredColumns;
} }
public static <T> Query<T> withMaxScatterGatherBytes(Query<T> query, long maxScatterGatherBytesLimit)
{
QueryContext context = query.context();
if (!context.containsKey(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY)) {
return query.withOverriddenContext(ImmutableMap.of(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY, maxScatterGatherBytesLimit));
}
context.verifyMaxScatterGatherBytes(maxScatterGatherBytesLimit);
return query;
}
public static <T> Query<T> withTimeout(Query<T> query, long timeout)
{
return query.withOverriddenContext(ImmutableMap.of(QueryContexts.TIMEOUT_KEY, timeout));
}
public static <T> Query<T> withDefaultTimeout(Query<T> query, long defaultTimeout)
{
return query.withOverriddenContext(ImmutableMap.of(QueryContexts.DEFAULT_TIMEOUT_KEY, defaultTimeout));
}
} }

View File

@ -45,6 +45,7 @@ import org.joda.time.Duration;
import org.joda.time.Interval; import org.joda.time.Interval;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
@ -96,64 +97,53 @@ public interface Query<T>
DateTimeZone getTimezone(); DateTimeZone getTimezone();
/** /**
* Use {@link #getQueryContext()} instead. * Returns the context as an (immutable) map.
*/ */
@Deprecated
Map<String, Object> getContext(); Map<String, Object> getContext();
/** /**
* Returns QueryContext for this query. This type distinguishes between user provided, system default, and system * Returns the query context as a {@link QueryContext}, which provides
* generated query context keys so that authorization may be employed directly against the user supplied context * convenience methods for accessing typed context values. The returned
* values. * instance is a view on top of the context provided by {@link #getContext()}.
* * <p>
* This method is marked @Nullable, but is only so for backwards compatibility with Druid versions older than 0.23. * The default implementation is for backward compatibility. Derived classes should
* Callers should check if the result of this method is null, and if so, they are dealing with a legacy query * store and return the {@link QueryContext} directly.
* implementation, and should fall back to using {@link #getContext()} and {@link #withOverriddenContext(Map)} to
* manipulate the query context.
*
* Note for query context serialization and deserialization.
* Currently, once a query is serialized, its queryContext can be different from the original queryContext
* after the query is deserialized back. If the queryContext has any {@link QueryContext#defaultParams} or
* {@link QueryContext#systemParams} in it, those will be found in {@link QueryContext#userParams}
* after it is deserialized. This is because {@link BaseQuery#getContext()} uses
* {@link QueryContext#getMergedParams()} for serialization, and queries accept a map for deserialization.
*/ */
@Nullable default QueryContext context()
default QueryContext getQueryContext()
{ {
return null; return QueryContext.of(getContext());
} }
/** /**
* Get context value and cast to ContextType in an unsafe way. * Get context value and cast to ContextType in an unsafe way.
* *
* For safe conversion, it's recommended to use following methods instead * For safe conversion, it's recommended to use following methods instead:
* <p>
* {@link QueryContext#getBoolean(String)} <br/>
* {@link QueryContext#getString(String)} <br/>
* {@link QueryContext#getInt(String)} <br/>
* {@link QueryContext#getLong(String)} <br/>
* {@link QueryContext#getFloat(String)} <br/>
* {@link QueryContext#getEnum(String, Class, Enum)} <br/>
* {@link QueryContext#getHumanReadableBytes(String, HumanReadableBytes)}
* *
* {@link QueryContext#getAsBoolean(String)} * @deprecated use {@code queryContext().get<Type>()} instead
* {@link QueryContext#getAsString(String)}
* {@link QueryContext#getAsInt(String)}
* {@link QueryContext#getAsLong(String)}
* {@link QueryContext#getAsFloat(String, float)}
* {@link QueryContext#getAsEnum(String, Class, Enum)}
* {@link QueryContext#getAsHumanReadableBytes(String, HumanReadableBytes)}
*/ */
@Deprecated
@SuppressWarnings("unchecked")
@Nullable @Nullable
default <ContextType> ContextType getContextValue(String key) default <ContextType> ContextType getContextValue(String key)
{ {
if (getQueryContext() == null) { return (ContextType) context().get(key);
return null;
} else {
return (ContextType) getQueryContext().get(key);
}
} }
/**
* @deprecated use {@code queryContext().getBoolean()} instead.
*/
@Deprecated
default boolean getContextBoolean(String key, boolean defaultValue) default boolean getContextBoolean(String key, boolean defaultValue)
{ {
if (getQueryContext() == null) { return context().getBoolean(key, defaultValue);
return defaultValue;
} else {
return getQueryContext().getAsBoolean(key, defaultValue);
}
} }
/** /**
@ -164,14 +154,12 @@ public interface Query<T>
* @param key The context key value being looked up * @param key The context key value being looked up
* @param defaultValue The default to return if the key value doesn't exist or the context is null. * @param defaultValue The default to return if the key value doesn't exist or the context is null.
* @return {@link HumanReadableBytes} * @return {@link HumanReadableBytes}
* @deprecated use {@code queryContext().getContextHumanReadableBytes()} instead.
*/ */
default HumanReadableBytes getContextAsHumanReadableBytes(String key, HumanReadableBytes defaultValue) @Deprecated
default HumanReadableBytes getContextHumanReadableBytes(String key, HumanReadableBytes defaultValue)
{ {
if (getQueryContext() == null) { return context().getHumanReadableBytes(key, defaultValue);
return defaultValue;
} else {
return getQueryContext().getAsHumanReadableBytes(key, defaultValue);
}
} }
boolean isDescending(); boolean isDescending();
@ -230,7 +218,7 @@ public interface Query<T>
@Nullable @Nullable
default String getSqlQueryId() default String getSqlQueryId()
{ {
return getQueryContext().getAsString(BaseQuery.SQL_QUERY_ID); return context().getString(BaseQuery.SQL_QUERY_ID);
} }
/** /**

View File

@ -20,6 +20,10 @@
package org.apache.druid.query; package org.apache.druid.query;
import org.apache.druid.java.util.common.HumanReadableBytes; import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.QueryContexts.Vectorize;
import org.apache.druid.segment.QueryableIndexStorageAdapter;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -29,227 +33,547 @@ import java.util.Objects;
import java.util.TreeMap; import java.util.TreeMap;
/** /**
* Holder for query context parameters. There are 3 ways to set context params today. * Immutable holder for query context parameters with typed access methods.
* * Code builds up a map of context values from serialization or during
* - Default parameters. These are set mostly via {@link DefaultQueryConfig#context}. * planning. Once that map is handed to the {@code QueryContext}, that map
* Auto-generated queryId or sqlQueryId are also set as default parameters. These default parameters can * is effectively immutable.
* be overridden by user or system parameters. * <p>
* - User parameters. These are the params set by the user. User params override default parameters but * The implementation uses a {@link TreeMap} so that the serialized form of a query
* are overridden by system parameters. * lists context values in a deterministic order. Jackson will call
* - System parameters. These are the params set by the Druid query engine for internal use only. * {@code getContext()} on the query, which will call {@link #asMap()} here,
* * which returns the sorted {@code TreeMap}.
* You can use {@code getX} methods or {@link #getMergedParams()} to compute the context params * <p>
* merging 3 types of params above. * The {@code TreeMap} is a mutable class. We'd prefer an immutable class, but
* * we can choose either ordering or immutability. Since the semantics of the context
* Currently, this class is mainly used for query context parameter authorization, * is that it is immutable once it is placed in a query. Code should NEVER get the
* such as HTTP query endpoints or JDBC endpoint. Its usage can be expanded in the future if we * context map from a query and modify it, even if the actual implementation
* want to track user parameters and separate them from others during query processing. * allows it.
*/ */
public class QueryContext public class QueryContext
{ {
private final Map<String, Object> defaultParams; private static final QueryContext EMPTY = new QueryContext(null);
private final Map<String, Object> userParams;
private final Map<String, Object> systemParams;
/** private final Map<String, Object> context;
* Cache of params merged.
*/
@Nullable
private Map<String, Object> mergedParams;
public QueryContext() public QueryContext(Map<String, Object> context)
{ {
this(null); // There is no semantic difference between an empty and a null context.
// Ensure that a context always exists to avoid the need to check for
// a null context. Jackson serialization will omit empty contexts.
this.context = context == null
? Collections.emptyMap()
: Collections.unmodifiableMap(new TreeMap<>(context));
} }
public QueryContext(@Nullable Map<String, Object> userParams) public static QueryContext empty()
{ {
this( return EMPTY;
new TreeMap<>(),
userParams == null ? new TreeMap<>() : new TreeMap<>(userParams),
new TreeMap<>()
);
} }
private QueryContext( public static QueryContext of(Map<String, Object> context)
final Map<String, Object> defaultParams,
final Map<String, Object> userParams,
final Map<String, Object> systemParams
)
{ {
this.defaultParams = defaultParams; return new QueryContext(context);
this.userParams = userParams;
this.systemParams = systemParams;
this.mergedParams = null;
}
private void invalidateMergedParams()
{
this.mergedParams = null;
} }
public boolean isEmpty() public boolean isEmpty()
{ {
return defaultParams.isEmpty() && userParams.isEmpty() && systemParams.isEmpty(); return context.isEmpty();
} }
public void addDefaultParam(String key, Object val) public Map<String, Object> asMap()
{ {
invalidateMergedParams(); return context;
defaultParams.put(key, val);
}
public void addDefaultParams(Map<String, Object> defaultParams)
{
invalidateMergedParams();
this.defaultParams.putAll(defaultParams);
}
public void addSystemParam(String key, Object val)
{
invalidateMergedParams();
this.systemParams.put(key, val);
}
public Object removeUserParam(String key)
{
invalidateMergedParams();
return userParams.remove(key);
} }
/** /**
* Returns only the context parameters the user sets. * Check if the given key is set. If the client will then fetch the value,
* The returned map does not include the parameters that have been removed via {@link #removeUserParam}. * consider using one of the {@code get<Type>(String key)} methods instead:
* * they each return {@code null} if the value is not set.
* Callers should use {@code getX} methods or {@link #getMergedParams()} instead to use the whole context params.
*/ */
public Map<String, Object> getUserParams()
{
return userParams;
}
public boolean isDebug()
{
return getAsBoolean(QueryContexts.ENABLE_DEBUG, QueryContexts.DEFAULT_ENABLE_DEBUG);
}
public boolean isEnableJoinLeftScanDirect()
{
return getAsBoolean(
QueryContexts.SQL_JOIN_LEFT_SCAN_DIRECT,
QueryContexts.DEFAULT_ENABLE_SQL_JOIN_LEFT_SCAN_DIRECT
);
}
@SuppressWarnings("unused")
public boolean containsKey(String key) public boolean containsKey(String key)
{ {
return get(key) != null; return context.containsKey(key);
} }
/**
* Return a value as a generic {@code Object}, returning {@code null} if the
* context value is not set.
*/
@Nullable @Nullable
public Object get(String key) public Object get(String key)
{ {
Object val = systemParams.get(key); return context.get(key);
if (val != null) {
return val;
}
val = userParams.get(key);
return val == null ? defaultParams.get(key) : val;
} }
@SuppressWarnings("unused") /**
public Object getOrDefault(String key, Object defaultValue) * Return a value as a generic {@code Object}, returning the default value if the
* context value is not set.
*/
public Object get(String key, Object defaultValue)
{ {
final Object val = get(key); final Object val = get(key);
return val == null ? defaultValue : val; return val == null ? defaultValue : val;
} }
/**
* Return a value as an {@code String}, returning {@link null} if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
@Nullable @Nullable
public String getAsString(String key) public String getString(String key)
{ {
Object val = get(key); return getString(key, null);
return val == null ? null : val.toString();
} }
public String getAsString(String key, String defaultValue) public String getString(String key, String defaultValue)
{ {
Object val = get(key); return QueryContexts.parseString(context, key, defaultValue);
return val == null ? defaultValue : val.toString();
} }
@Nullable /**
public Boolean getAsBoolean(String key) * Return a value as an {@code Boolean}, returning {@link null} if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
public Boolean getBoolean(final String key)
{ {
return QueryContexts.getAsBoolean(key, get(key)); return QueryContexts.getAsBoolean(key, get(key));
} }
public boolean getAsBoolean( /**
final String key, * Return a value as an {@code boolean}, returning the default value if the
final boolean defaultValue * context value is not set.
) *
* @throws BadQueryContextException for an invalid value
*/
public boolean getBoolean(final String key, final boolean defaultValue)
{ {
return QueryContexts.getAsBoolean(key, get(key), defaultValue); return QueryContexts.parseBoolean(context, key, defaultValue);
} }
public Integer getAsInt(final String key) /**
* Return a value as an {@code Integer}, returning {@link null} if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
public Integer getInt(final String key)
{ {
return QueryContexts.getAsInt(key, get(key)); return QueryContexts.getAsInt(key, get(key));
} }
public int getAsInt( /**
final String key, * Return a value as an {@code int}, returning the default value if the
final int defaultValue * context value is not set.
) *
* @throws BadQueryContextException for an invalid value
*/
public int getInt(final String key, final int defaultValue)
{ {
return QueryContexts.getAsInt(key, get(key), defaultValue); return QueryContexts.parseInt(context, key, defaultValue);
} }
public Long getAsLong(final String key) /**
* Return a value as an {@code Long}, returning {@link null} if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
public Long getLong(final String key)
{ {
return QueryContexts.getAsLong(key, get(key)); return QueryContexts.getAsLong(key, get(key));
} }
public long getAsLong(final String key, final long defaultValue) /**
* Return a value as an {@code long}, returning the default value if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
public long getLong(final String key, final long defaultValue)
{ {
return QueryContexts.getAsLong(key, get(key), defaultValue); return QueryContexts.parseLong(context, key, defaultValue);
} }
public HumanReadableBytes getAsHumanReadableBytes(final String key, final HumanReadableBytes defaultValue) /**
* Return a value as an {@code Float}, returning {@link null} if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
@SuppressWarnings("unused")
public Float getFloat(final String key)
{ {
return QueryContexts.getAsHumanReadableBytes(key, get(key), defaultValue); return QueryContexts.getAsFloat(key, get(key));
} }
public float getAsFloat(final String key, final float defaultValue) /**
* Return a value as an {@code float}, returning the default value if the
* context value is not set.
*
* @throws BadQueryContextException for an invalid value
*/
public float getFloat(final String key, final float defaultValue)
{ {
return QueryContexts.getAsFloat(key, get(key), defaultValue); return QueryContexts.getAsFloat(key, get(key), defaultValue);
} }
public <E extends Enum<E>> E getAsEnum(String key, Class<E> clazz, E defaultValue) public HumanReadableBytes getHumanReadableBytes(final String key, final HumanReadableBytes defaultValue)
{
return QueryContexts.getAsHumanReadableBytes(key, get(key), defaultValue);
}
public <E extends Enum<E>> E getEnum(String key, Class<E> clazz, E defaultValue)
{ {
return QueryContexts.getAsEnum(key, get(key), clazz, defaultValue); return QueryContexts.getAsEnum(key, get(key), clazz, defaultValue);
} }
public Map<String, Object> getMergedParams() public Granularity getGranularity(String key)
{ {
if (mergedParams == null) { final Object value = get(key);
final Map<String, Object> merged = new TreeMap<>(defaultParams); if (value == null) {
merged.putAll(userParams); return null;
merged.putAll(systemParams); }
mergedParams = Collections.unmodifiableMap(merged); if (value instanceof Granularity) {
return (Granularity) value;
} else {
throw QueryContexts.badTypeException(key, "a Granularity", value);
} }
return mergedParams;
} }
public QueryContext copy() public boolean isDebug()
{ {
return new QueryContext( return getBoolean(QueryContexts.ENABLE_DEBUG, QueryContexts.DEFAULT_ENABLE_DEBUG);
new TreeMap<>(defaultParams), }
new TreeMap<>(userParams),
new TreeMap<>(systemParams) public boolean isBySegment()
{
return isBySegment(QueryContexts.DEFAULT_BY_SEGMENT);
}
public boolean isBySegment(boolean defaultValue)
{
return getBoolean(QueryContexts.BY_SEGMENT_KEY, defaultValue);
}
public boolean isPopulateCache()
{
return isPopulateCache(QueryContexts.DEFAULT_POPULATE_CACHE);
}
public boolean isPopulateCache(boolean defaultValue)
{
return getBoolean(QueryContexts.POPULATE_CACHE_KEY, defaultValue);
}
public boolean isUseCache()
{
return isUseCache(QueryContexts.DEFAULT_USE_CACHE);
}
public boolean isUseCache(boolean defaultValue)
{
return getBoolean(QueryContexts.USE_CACHE_KEY, defaultValue);
}
public boolean isPopulateResultLevelCache()
{
return isPopulateResultLevelCache(QueryContexts.DEFAULT_POPULATE_RESULTLEVEL_CACHE);
}
public boolean isPopulateResultLevelCache(boolean defaultValue)
{
return getBoolean(QueryContexts.POPULATE_RESULT_LEVEL_CACHE_KEY, defaultValue);
}
public boolean isUseResultLevelCache()
{
return isUseResultLevelCache(QueryContexts.DEFAULT_USE_RESULTLEVEL_CACHE);
}
public boolean isUseResultLevelCache(boolean defaultValue)
{
return getBoolean(QueryContexts.USE_RESULT_LEVEL_CACHE_KEY, defaultValue);
}
public boolean isFinalize(boolean defaultValue)
{
return getBoolean(QueryContexts.FINALIZE_KEY, defaultValue);
}
public boolean isSerializeDateTimeAsLong(boolean defaultValue)
{
return getBoolean(QueryContexts.SERIALIZE_DATE_TIME_AS_LONG_KEY, defaultValue);
}
public boolean isSerializeDateTimeAsLongInner(boolean defaultValue)
{
return getBoolean(QueryContexts.SERIALIZE_DATE_TIME_AS_LONG_INNER_KEY, defaultValue);
}
public Vectorize getVectorize()
{
return getVectorize(QueryContexts.DEFAULT_VECTORIZE);
}
public Vectorize getVectorize(Vectorize defaultValue)
{
return getEnum(QueryContexts.VECTORIZE_KEY, Vectorize.class, defaultValue);
}
public Vectorize getVectorizeVirtualColumns()
{
return getVectorizeVirtualColumns(QueryContexts.DEFAULT_VECTORIZE_VIRTUAL_COLUMN);
}
public Vectorize getVectorizeVirtualColumns(Vectorize defaultValue)
{
return getEnum(
QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY,
Vectorize.class,
defaultValue
); );
} }
public int getVectorSize()
{
return getVectorSize(QueryableIndexStorageAdapter.DEFAULT_VECTOR_SIZE);
}
public int getVectorSize(int defaultSize)
{
return getInt(QueryContexts.VECTOR_SIZE_KEY, defaultSize);
}
public int getMaxSubqueryRows(int defaultSize)
{
return getInt(QueryContexts.MAX_SUBQUERY_ROWS_KEY, defaultSize);
}
public int getUncoveredIntervalsLimit()
{
return getUncoveredIntervalsLimit(QueryContexts.DEFAULT_UNCOVERED_INTERVALS_LIMIT);
}
public int getUncoveredIntervalsLimit(int defaultValue)
{
return getInt(QueryContexts.UNCOVERED_INTERVALS_LIMIT_KEY, defaultValue);
}
public int getPriority()
{
return getPriority(QueryContexts.DEFAULT_PRIORITY);
}
public int getPriority(int defaultValue)
{
return getInt(QueryContexts.PRIORITY_KEY, defaultValue);
}
public String getLane()
{
return getString(QueryContexts.LANE_KEY);
}
public boolean getEnableParallelMerges()
{
return getBoolean(
QueryContexts.BROKER_PARALLEL_MERGE_KEY,
QueryContexts.DEFAULT_ENABLE_PARALLEL_MERGE
);
}
public int getParallelMergeInitialYieldRows(int defaultValue)
{
return getInt(QueryContexts.BROKER_PARALLEL_MERGE_INITIAL_YIELD_ROWS_KEY, defaultValue);
}
public int getParallelMergeSmallBatchRows(int defaultValue)
{
return getInt(QueryContexts.BROKER_PARALLEL_MERGE_SMALL_BATCH_ROWS_KEY, defaultValue);
}
public int getParallelMergeParallelism(int defaultValue)
{
return getInt(QueryContexts.BROKER_PARALLELISM, defaultValue);
}
public long getJoinFilterRewriteMaxSize()
{
return getLong(
QueryContexts.JOIN_FILTER_REWRITE_MAX_SIZE_KEY,
QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE
);
}
public boolean getEnableJoinFilterPushDown()
{
return getBoolean(
QueryContexts.JOIN_FILTER_PUSH_DOWN_KEY,
QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_PUSH_DOWN
);
}
public boolean getEnableJoinFilterRewrite()
{
return getBoolean(
QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY,
QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE
);
}
public boolean isSecondaryPartitionPruningEnabled()
{
return getBoolean(
QueryContexts.SECONDARY_PARTITION_PRUNING_KEY,
QueryContexts.DEFAULT_SECONDARY_PARTITION_PRUNING
);
}
public long getMaxQueuedBytes(long defaultValue)
{
return getLong(QueryContexts.MAX_QUEUED_BYTES_KEY, defaultValue);
}
public long getMaxScatterGatherBytes()
{
return getLong(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY, Long.MAX_VALUE);
}
public boolean hasTimeout()
{
return getTimeout() != QueryContexts.NO_TIMEOUT;
}
public long getTimeout()
{
return getTimeout(getDefaultTimeout());
}
public long getTimeout(long defaultTimeout)
{
final long timeout = getLong(QueryContexts.TIMEOUT_KEY, defaultTimeout);
if (timeout >= 0) {
return timeout;
}
throw new BadQueryContextException(
StringUtils.format(
"Timeout [%s] must be a non negative value, but was %d",
QueryContexts.TIMEOUT_KEY,
timeout
)
);
}
public long getDefaultTimeout()
{
final long defaultTimeout = getLong(QueryContexts.DEFAULT_TIMEOUT_KEY, QueryContexts.DEFAULT_TIMEOUT_MILLIS);
if (defaultTimeout >= 0) {
return defaultTimeout;
}
throw new BadQueryContextException(
StringUtils.format(
"Timeout [%s] must be a non negative value, but was %d",
QueryContexts.DEFAULT_TIMEOUT_KEY,
defaultTimeout
)
);
}
public void verifyMaxQueryTimeout(long maxQueryTimeout)
{
long timeout = getTimeout();
if (timeout > maxQueryTimeout) {
throw new BadQueryContextException(
StringUtils.format(
"Configured %s = %d is more than enforced limit of %d.",
QueryContexts.TIMEOUT_KEY,
timeout,
maxQueryTimeout
)
);
}
}
public void verifyMaxScatterGatherBytes(long maxScatterGatherBytesLimit)
{
long curr = getLong(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY, 0);
if (curr > maxScatterGatherBytesLimit) {
throw new BadQueryContextException(
StringUtils.format(
"Configured %s = %d is more than enforced limit of %d.",
QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY,
curr,
maxScatterGatherBytesLimit
)
);
}
}
public int getNumRetriesOnMissingSegments(int defaultValue)
{
return getInt(QueryContexts.NUM_RETRIES_ON_MISSING_SEGMENTS_KEY, defaultValue);
}
public boolean allowReturnPartialResults(boolean defaultValue)
{
return getBoolean(QueryContexts.RETURN_PARTIAL_RESULTS_KEY, defaultValue);
}
public boolean getEnableJoinFilterRewriteValueColumnFilters()
{
return getBoolean(
QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY,
QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS
);
}
public boolean getEnableRewriteJoinToFilter()
{
return getBoolean(
QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY,
QueryContexts.DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER
);
}
public boolean getEnableJoinLeftScanDirect()
{
return getBoolean(
QueryContexts.SQL_JOIN_LEFT_SCAN_DIRECT,
QueryContexts.DEFAULT_ENABLE_SQL_JOIN_LEFT_SCAN_DIRECT
);
}
public int getInSubQueryThreshold()
{
return getInSubQueryThreshold(QueryContexts.DEFAULT_IN_SUB_QUERY_THRESHOLD);
}
public int getInSubQueryThreshold(int defaultValue)
{
return getInt(
QueryContexts.IN_SUB_QUERY_THRESHOLD_KEY,
defaultValue
);
}
public boolean isTimeBoundaryPlanningEnabled()
{
return getBoolean(
QueryContexts.TIME_BOUNDARY_PLANNING_KEY,
QueryContexts.DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING
);
}
public String getBrokerServiceName()
{
return getString(QueryContexts.BROKER_SERVICE_NAME);
}
@Override @Override
public boolean equals(Object o) public boolean equals(Object o)
{ {
@ -259,23 +583,21 @@ public class QueryContext
if (o == null || getClass() != o.getClass()) { if (o == null || getClass() != o.getClass()) {
return false; return false;
} }
QueryContext context = (QueryContext) o; QueryContext other = (QueryContext) o;
return getMergedParams().equals(context.getMergedParams()); return context.equals(other.context);
} }
@Override @Override
public int hashCode() public int hashCode()
{ {
return Objects.hash(getMergedParams()); return Objects.hash(context);
} }
@Override @Override
public String toString() public String toString()
{ {
return "QueryContext{" + return "QueryContext{" +
"defaultParams=" + defaultParams + "context=" + context +
", userParams=" + userParams +
", systemParams=" + systemParams +
'}'; '}';
} }
} }

View File

@ -21,19 +21,19 @@ package org.apache.druid.query;
import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue; import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.guice.annotations.PublicApi; import org.apache.druid.guice.annotations.PublicApi;
import org.apache.druid.java.util.common.HumanReadableBytes; import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Numbers; import org.apache.druid.java.util.common.Numbers;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.QueryableIndexStorageAdapter;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.math.BigDecimal;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.TreeMap; import java.util.Map.Entry;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@PublicApi @PublicApi
@ -80,7 +80,13 @@ public class QueryContexts
public static final String SERIALIZE_DATE_TIME_AS_LONG_KEY = "serializeDateTimeAsLong"; public static final String SERIALIZE_DATE_TIME_AS_LONG_KEY = "serializeDateTimeAsLong";
public static final String SERIALIZE_DATE_TIME_AS_LONG_INNER_KEY = "serializeDateTimeAsLongInner"; public static final String SERIALIZE_DATE_TIME_AS_LONG_INNER_KEY = "serializeDateTimeAsLongInner";
public static final String UNCOVERED_INTERVALS_LIMIT_KEY = "uncoveredIntervalsLimit"; public static final String UNCOVERED_INTERVALS_LIMIT_KEY = "uncoveredIntervalsLimit";
public static final String MIN_TOP_N_THRESHOLD = "minTopNThreshold";
// SQL query context keys
public static final String CTX_SQL_QUERY_ID = BaseQuery.SQL_QUERY_ID;
public static final String CTX_SQL_STRINGIFY_ARRAYS = "sqlStringifyArrays";
// Defaults
public static final boolean DEFAULT_BY_SEGMENT = false; public static final boolean DEFAULT_BY_SEGMENT = false;
public static final boolean DEFAULT_POPULATE_CACHE = true; public static final boolean DEFAULT_POPULATE_CACHE = true;
public static final boolean DEFAULT_USE_CACHE = true; public static final boolean DEFAULT_USE_CACHE = true;
@ -150,332 +156,42 @@ public class QueryContexts
} }
} }
public static <T> boolean isBySegment(Query<T> query) private QueryContexts()
{ {
return isBySegment(query, DEFAULT_BY_SEGMENT);
} }
public static <T> boolean isBySegment(Query<T> query, boolean defaultValue) public static long parseLong(Map<String, Object> context, String key, long defaultValue)
{
return query.getContextBoolean(BY_SEGMENT_KEY, defaultValue);
}
public static <T> boolean isPopulateCache(Query<T> query)
{
return isPopulateCache(query, DEFAULT_POPULATE_CACHE);
}
public static <T> boolean isPopulateCache(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(POPULATE_CACHE_KEY, defaultValue);
}
public static <T> boolean isUseCache(Query<T> query)
{
return isUseCache(query, DEFAULT_USE_CACHE);
}
public static <T> boolean isUseCache(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(USE_CACHE_KEY, defaultValue);
}
public static <T> boolean isPopulateResultLevelCache(Query<T> query)
{
return isPopulateResultLevelCache(query, DEFAULT_POPULATE_RESULTLEVEL_CACHE);
}
public static <T> boolean isPopulateResultLevelCache(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(POPULATE_RESULT_LEVEL_CACHE_KEY, defaultValue);
}
public static <T> boolean isUseResultLevelCache(Query<T> query)
{
return isUseResultLevelCache(query, DEFAULT_USE_RESULTLEVEL_CACHE);
}
public static <T> boolean isUseResultLevelCache(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(USE_RESULT_LEVEL_CACHE_KEY, defaultValue);
}
public static <T> boolean isFinalize(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(FINALIZE_KEY, defaultValue);
}
public static <T> boolean isSerializeDateTimeAsLong(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(SERIALIZE_DATE_TIME_AS_LONG_KEY, defaultValue);
}
public static <T> boolean isSerializeDateTimeAsLongInner(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(SERIALIZE_DATE_TIME_AS_LONG_INNER_KEY, defaultValue);
}
public static <T> Vectorize getVectorize(Query<T> query)
{
return getVectorize(query, QueryContexts.DEFAULT_VECTORIZE);
}
public static <T> Vectorize getVectorize(Query<T> query, Vectorize defaultValue)
{
return query.getQueryContext().getAsEnum(VECTORIZE_KEY, Vectorize.class, defaultValue);
}
public static <T> Vectorize getVectorizeVirtualColumns(Query<T> query)
{
return getVectorizeVirtualColumns(query, QueryContexts.DEFAULT_VECTORIZE_VIRTUAL_COLUMN);
}
public static <T> Vectorize getVectorizeVirtualColumns(Query<T> query, Vectorize defaultValue)
{
return query.getQueryContext().getAsEnum(VECTORIZE_VIRTUAL_COLUMNS_KEY, Vectorize.class, defaultValue);
}
public static <T> int getVectorSize(Query<T> query)
{
return getVectorSize(query, QueryableIndexStorageAdapter.DEFAULT_VECTOR_SIZE);
}
public static <T> int getVectorSize(Query<T> query, int defaultSize)
{
return query.getQueryContext().getAsInt(VECTOR_SIZE_KEY, defaultSize);
}
public static <T> int getMaxSubqueryRows(Query<T> query, int defaultSize)
{
return query.getQueryContext().getAsInt(MAX_SUBQUERY_ROWS_KEY, defaultSize);
}
public static <T> int getUncoveredIntervalsLimit(Query<T> query)
{
return getUncoveredIntervalsLimit(query, DEFAULT_UNCOVERED_INTERVALS_LIMIT);
}
public static <T> int getUncoveredIntervalsLimit(Query<T> query, int defaultValue)
{
return query.getQueryContext().getAsInt(UNCOVERED_INTERVALS_LIMIT_KEY, defaultValue);
}
public static <T> int getPriority(Query<T> query)
{
return getPriority(query, DEFAULT_PRIORITY);
}
public static <T> int getPriority(Query<T> query, int defaultValue)
{
return query.getQueryContext().getAsInt(PRIORITY_KEY, defaultValue);
}
public static <T> String getLane(Query<T> query)
{
return query.getQueryContext().getAsString(LANE_KEY);
}
public static <T> boolean getEnableParallelMerges(Query<T> query)
{
return query.getContextBoolean(BROKER_PARALLEL_MERGE_KEY, DEFAULT_ENABLE_PARALLEL_MERGE);
}
public static <T> int getParallelMergeInitialYieldRows(Query<T> query, int defaultValue)
{
return query.getQueryContext().getAsInt(BROKER_PARALLEL_MERGE_INITIAL_YIELD_ROWS_KEY, defaultValue);
}
public static <T> int getParallelMergeSmallBatchRows(Query<T> query, int defaultValue)
{
return query.getQueryContext().getAsInt(BROKER_PARALLEL_MERGE_SMALL_BATCH_ROWS_KEY, defaultValue);
}
public static <T> int getParallelMergeParallelism(Query<T> query, int defaultValue)
{
return query.getQueryContext().getAsInt(BROKER_PARALLELISM, defaultValue);
}
public static <T> boolean getEnableJoinFilterRewriteValueColumnFilters(Query<T> query)
{
return query.getContextBoolean(
JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY,
DEFAULT_ENABLE_JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS
);
}
public static <T> boolean getEnableRewriteJoinToFilter(Query<T> query)
{
return query.getContextBoolean(
REWRITE_JOIN_TO_FILTER_ENABLE_KEY,
DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER
);
}
public static <T> long getJoinFilterRewriteMaxSize(Query<T> query)
{
return query.getQueryContext().getAsLong(JOIN_FILTER_REWRITE_MAX_SIZE_KEY, DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE);
}
public static <T> boolean getEnableJoinFilterPushDown(Query<T> query)
{
return query.getContextBoolean(JOIN_FILTER_PUSH_DOWN_KEY, DEFAULT_ENABLE_JOIN_FILTER_PUSH_DOWN);
}
public static <T> boolean getEnableJoinFilterRewrite(Query<T> query)
{
return query.getContextBoolean(JOIN_FILTER_REWRITE_ENABLE_KEY, DEFAULT_ENABLE_JOIN_FILTER_REWRITE);
}
public static boolean getEnableJoinLeftScanDirect(Map<String, Object> context)
{
return parseBoolean(context, SQL_JOIN_LEFT_SCAN_DIRECT, DEFAULT_ENABLE_SQL_JOIN_LEFT_SCAN_DIRECT);
}
public static <T> boolean isSecondaryPartitionPruningEnabled(Query<T> query)
{
return query.getContextBoolean(SECONDARY_PARTITION_PRUNING_KEY, DEFAULT_SECONDARY_PARTITION_PRUNING);
}
public static <T> boolean isDebug(Query<T> query)
{
return query.getContextBoolean(ENABLE_DEBUG, DEFAULT_ENABLE_DEBUG);
}
public static boolean isDebug(Map<String, Object> queryContext)
{
return parseBoolean(queryContext, ENABLE_DEBUG, DEFAULT_ENABLE_DEBUG);
}
public static int getInSubQueryThreshold(Map<String, Object> context)
{
return getInSubQueryThreshold(context, DEFAULT_IN_SUB_QUERY_THRESHOLD);
}
public static int getInSubQueryThreshold(Map<String, Object> context, int defaultValue)
{
return parseInt(context, IN_SUB_QUERY_THRESHOLD_KEY, defaultValue);
}
public static boolean isTimeBoundaryPlanningEnabled(Map<String, Object> queryContext)
{
return parseBoolean(queryContext, TIME_BOUNDARY_PLANNING_KEY, DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING);
}
public static <T> Query<T> withMaxScatterGatherBytes(Query<T> query, long maxScatterGatherBytesLimit)
{
Long curr = query.getQueryContext().getAsLong(MAX_SCATTER_GATHER_BYTES_KEY);
if (curr == null) {
return query.withOverriddenContext(ImmutableMap.of(MAX_SCATTER_GATHER_BYTES_KEY, maxScatterGatherBytesLimit));
} else {
if (curr > maxScatterGatherBytesLimit) {
throw new IAE(
"configured [%s = %s] is more than enforced limit of [%s].",
MAX_SCATTER_GATHER_BYTES_KEY,
curr,
maxScatterGatherBytesLimit
);
} else {
return query;
}
}
}
public static <T> Query<T> verifyMaxQueryTimeout(Query<T> query, long maxQueryTimeout)
{
long timeout = getTimeout(query);
if (timeout > maxQueryTimeout) {
throw new IAE(
"configured [%s = %s] is more than enforced limit of maxQueryTimeout [%s].",
TIMEOUT_KEY,
timeout,
maxQueryTimeout
);
} else {
return query;
}
}
public static <T> long getMaxQueuedBytes(Query<T> query, long defaultValue)
{
return query.getQueryContext().getAsLong(MAX_QUEUED_BYTES_KEY, defaultValue);
}
public static <T> long getMaxScatterGatherBytes(Query<T> query)
{
return query.getQueryContext().getAsLong(MAX_SCATTER_GATHER_BYTES_KEY, Long.MAX_VALUE);
}
public static <T> boolean hasTimeout(Query<T> query)
{
return getTimeout(query) != NO_TIMEOUT;
}
public static <T> long getTimeout(Query<T> query)
{
return getTimeout(query, getDefaultTimeout(query));
}
public static <T> long getTimeout(Query<T> query, long defaultTimeout)
{
try {
final long timeout = query.getQueryContext().getAsLong(TIMEOUT_KEY, defaultTimeout);
Preconditions.checkState(timeout >= 0, "Timeout must be a non negative value, but was [%s]", timeout);
return timeout;
}
catch (IAE e) {
throw new BadQueryContextException(e);
}
}
public static <T> Query<T> withTimeout(Query<T> query, long timeout)
{
return query.withOverriddenContext(ImmutableMap.of(TIMEOUT_KEY, timeout));
}
public static <T> Query<T> withDefaultTimeout(Query<T> query, long defaultTimeout)
{
return query.withOverriddenContext(ImmutableMap.of(QueryContexts.DEFAULT_TIMEOUT_KEY, defaultTimeout));
}
static <T> long getDefaultTimeout(Query<T> query)
{
final long defaultTimeout = query.getQueryContext().getAsLong(DEFAULT_TIMEOUT_KEY, DEFAULT_TIMEOUT_MILLIS);
Preconditions.checkState(defaultTimeout >= 0, "Timeout must be a non negative value, but was [%s]", defaultTimeout);
return defaultTimeout;
}
public static <T> int getNumRetriesOnMissingSegments(Query<T> query, int defaultValue)
{
return query.getQueryContext().getAsInt(NUM_RETRIES_ON_MISSING_SEGMENTS_KEY, defaultValue);
}
public static <T> boolean allowReturnPartialResults(Query<T> query, boolean defaultValue)
{
return query.getContextBoolean(RETURN_PARTIAL_RESULTS_KEY, defaultValue);
}
public static String getBrokerServiceName(Map<String, Object> queryContext)
{
return queryContext == null ? null : (String) queryContext.get(BROKER_SERVICE_NAME);
}
@SuppressWarnings("unused")
static <T> long parseLong(Map<String, Object> context, String key, long defaultValue)
{ {
return getAsLong(key, context.get(key), defaultValue); return getAsLong(key, context.get(key), defaultValue);
} }
static int parseInt(Map<String, Object> context, String key, int defaultValue) public static int parseInt(Map<String, Object> context, String key, int defaultValue)
{ {
return getAsInt(key, context.get(key), defaultValue); return getAsInt(key, context.get(key), defaultValue);
} }
static boolean parseBoolean(Map<String, Object> context, String key, boolean defaultValue) @Nullable
public static String parseString(Map<String, Object> context, String key)
{
return parseString(context, key, null);
}
public static boolean parseBoolean(Map<String, Object> context, String key, boolean defaultValue)
{ {
return getAsBoolean(key, context.get(key), defaultValue); return getAsBoolean(key, context.get(key), defaultValue);
} }
public static String parseString(Map<String, Object> context, String key, String defaultValue)
{
return getAsString(key, context.get(key), defaultValue);
}
@SuppressWarnings("unused") // To keep IntelliJ inspections happy
public static float parseFloat(Map<String, Object> context, String key, float defaultValue)
{
return getAsFloat(key, context.get(key), defaultValue);
}
public static String getAsString( public static String getAsString(
final String key, final String key,
final Object value, final Object value,
@ -486,14 +202,13 @@ public class QueryContexts
return defaultValue; return defaultValue;
} else if (value instanceof String) { } else if (value instanceof String) {
return (String) value; return (String) value;
} else {
throw new IAE("Expected key [%s] to be a String, but got [%s]", key, value.getClass().getName());
} }
throw badTypeException(key, "a String", value);
} }
@Nullable @Nullable
public static Boolean getAsBoolean( public static Boolean getAsBoolean(
final String parameter, final String key,
final Object value final Object value
) )
{ {
@ -503,13 +218,12 @@ public class QueryContexts
return Boolean.parseBoolean((String) value); return Boolean.parseBoolean((String) value);
} else if (value instanceof Boolean) { } else if (value instanceof Boolean) {
return (Boolean) value; return (Boolean) value;
} else {
throw new IAE("Expected parameter [%s] to be a Boolean, but got [%s]", parameter, value.getClass().getName());
} }
throw badTypeException(key, "a Boolean", value);
} }
/** /**
* Get the value of a parameter as a {@code boolean}. The parameter is expected * Get the value of a context value as a {@code boolean}. The value is expected
* to be {@code null}, a string or a {@code Boolean} object. * to be {@code null}, a string or a {@code Boolean} object.
*/ */
public static boolean getAsBoolean( public static boolean getAsBoolean(
@ -534,24 +248,33 @@ public class QueryContexts
return Numbers.parseInt(value); return Numbers.parseInt(value);
} }
catch (NumberFormatException ignored) { catch (NumberFormatException ignored) {
throw new IAE("Expected key [%s] in integer format, but got [%s]", key, value);
// Attempt to handle trivial decimal values: 12.00, etc.
// This mimics how Jackson will convert "12.00" to a Integer on request.
try {
return new BigDecimal((String) value).intValueExact();
}
catch (Exception nfe) {
// That didn't work either. Give up.
throw badValueException(key, "in integer format", value);
}
} }
} }
throw new IAE("Expected key [%s] to be an Integer, but got [%s]", key, value.getClass().getName()); throw badTypeException(key, "an Integer", value);
} }
/** /**
* Get the value of a parameter as an {@code int}. The parameter is expected * Get the value of a context value as an {@code int}. The value is expected
* to be {@code null}, a string or a {@code Number} object. * to be {@code null}, a string or a {@code Number} object.
*/ */
public static int getAsInt( public static int getAsInt(
final String ke, final String key,
final Object value, final Object value,
final int defaultValue final int defaultValue
) )
{ {
Integer val = getAsInt(ke, value); Integer val = getAsInt(key, value);
return val == null ? defaultValue : val; return val == null ? defaultValue : val;
} }
@ -567,14 +290,23 @@ public class QueryContexts
return Numbers.parseLong(value); return Numbers.parseLong(value);
} }
catch (NumberFormatException ignored) { catch (NumberFormatException ignored) {
throw new IAE("Expected key [%s] in long format, but got [%s]", key, value);
// Attempt to handle trivial decimal values: 12.00, etc.
// This mimics how Jackson will convert "12.00" to a Long on request.
try {
return new BigDecimal((String) value).longValueExact();
}
catch (Exception nfe) {
// That didn't work either. Give up.
throw badValueException(key, "in long format", value);
}
} }
} }
throw new IAE("Expected key [%s] to be a Long, but got [%s]", key, value.getClass().getName()); throw badTypeException(key, "a Long", value);
} }
/** /**
* Get the value of a parameter as an {@code long}. The parameter is expected * Get the value of a context value as an {@code long}. The value is expected
* to be {@code null}, a string or a {@code Number} object. * to be {@code null}, a string or a {@code Number} object.
*/ */
public static long getAsLong( public static long getAsLong(
@ -587,8 +319,39 @@ public class QueryContexts
return val == null ? defaultValue : val; return val == null ? defaultValue : val;
} }
/**
* Get the value of a context value as an {@code Float}. The value is expected
* to be {@code null}, a string or a {@code Number} object.
*/
public static Float getAsFloat(final String key, final Object value)
{
if (value == null) {
return null;
} else if (value instanceof Number) {
return ((Number) value).floatValue();
} else if (value instanceof String) {
try {
return Float.parseFloat((String) value);
}
catch (NumberFormatException ignored) {
throw badValueException(key, "in float format", value);
}
}
throw badTypeException(key, "a Float", value);
}
public static float getAsFloat(
final String key,
final Object value,
final float defaultValue
)
{
Float val = getAsFloat(key, value);
return val == null ? defaultValue : val;
}
public static HumanReadableBytes getAsHumanReadableBytes( public static HumanReadableBytes getAsHumanReadableBytes(
final String parameter, final String key,
final Object value, final Object value,
final HumanReadableBytes defaultValue final HumanReadableBytes defaultValue
) )
@ -602,73 +365,126 @@ public class QueryContexts
return HumanReadableBytes.valueOf(HumanReadableBytes.parse((String) value)); return HumanReadableBytes.valueOf(HumanReadableBytes.parse((String) value));
} }
catch (IAE e) { catch (IAE e) {
throw new IAE("Expected key [%s] in human readable format, but got [%s]", parameter, value); throw badValueException(key, "a human readable number", value);
} }
} }
throw new IAE("Expected key [%s] to be a human readable number, but got [%s]", parameter, value.getClass().getName()); throw badTypeException(key, "a human readable number", value);
} }
public static float getAsFloat(String key, Object value, float defaultValue) /**
* Insert, update or remove a single key to produce an overridden context.
* Leaves the original context unchanged.
*
* @param context context to override
* @param key key to insert, update or remove
* @param value if {@code null}, remove the key. Otherwise, insert or replace
* the key.
* @return a new context map
*/
public static Map<String, Object> override(
final Map<String, Object> context,
final String key,
final Object value
)
{ {
if (null == value) { Map<String, Object> overridden = new HashMap<>(context);
return defaultValue; if (value == null) {
} else if (value instanceof Number) { overridden.remove(key);
return ((Number) value).floatValue(); } else {
} else if (value instanceof String) { overridden.put(key, value);
try {
return Float.parseFloat((String) value);
}
catch (NumberFormatException ignored) {
throw new IAE("Expected key [%s] in float format, but got [%s]", key, value);
}
} }
throw new IAE("Expected key [%s] to be a Float, but got [%s]", key, value.getClass().getName()); return overridden;
} }
/**
* Insert or replace multiple keys to produce an overridden context.
* Leaves the original context unchanged.
*
* @param context context to override
* @param overrides map of values to insert or replace
* @return a new context map
*/
public static Map<String, Object> override( public static Map<String, Object> override(
final Map<String, Object> context, final Map<String, Object> context,
final Map<String, Object> overrides final Map<String, Object> overrides
) )
{ {
Map<String, Object> overridden = new TreeMap<>(); Map<String, Object> overridden = new HashMap<>();
if (context != null) { if (context != null) {
overridden.putAll(context); overridden.putAll(context);
} }
overridden.putAll(overrides); if (overrides != null) {
overridden.putAll(overrides);
}
return overridden; return overridden;
} }
private QueryContexts() public static <E extends Enum<E>> E getAsEnum(String key, Object value, Class<E> clazz, E defaultValue)
{ {
} if (value == null) {
public static <E extends Enum<E>> E getAsEnum(String key, Object val, Class<E> clazz, E defaultValue)
{
if (val == null) {
return defaultValue; return defaultValue;
} }
try { try {
if (val instanceof String) { if (value instanceof String) {
return Enum.valueOf(clazz, StringUtils.toUpperCase((String) val)); return Enum.valueOf(clazz, StringUtils.toUpperCase((String) value));
} else if (val instanceof Boolean) { } else if (value instanceof Boolean) {
return Enum.valueOf(clazz, StringUtils.toUpperCase(String.valueOf(val))); return Enum.valueOf(clazz, StringUtils.toUpperCase(String.valueOf(value)));
} }
} }
catch (IllegalArgumentException e) { catch (IllegalArgumentException e) {
throw new IAE("Expected key [%s] must be value of enum [%s], but got [%s].", throw badValueException(
key, key,
clazz.getName(), StringUtils.format("a value of enum [%s]", clazz.getSimpleName()),
val.toString()); value
);
} }
throw new ISE( throw badTypeException(
"Expected key [%s] must be type of [%s], actual type is [%s].",
key, key,
clazz.getName(), StringUtils.format("of type [%s]", clazz.getSimpleName()),
val.getClass() value
); );
} }
public static BadQueryContextException badValueException(
final String key,
final String expected,
final Object actual
)
{
return new BadQueryContextException(
StringUtils.format(
"Expected key [%s] to be in %s, but got [%s]",
key,
expected,
actual
)
);
}
public static BadQueryContextException badTypeException(
final String key,
final String expected,
final Object actual
)
{
return new BadQueryContextException(
StringUtils.format(
"Expected key [%s] to be %s, but got [%s]",
key,
expected,
actual.getClass().getName()
)
);
}
public static void addDefaults(Map<String, Object> context, Map<String, Object> defaults)
{
for (Entry<String, Object> entry : defaults.entrySet()) {
context.putIfAbsent(entry.getKey(), entry.getValue());
}
}
} }

View File

@ -41,7 +41,7 @@ public class SubqueryQueryRunner<T> implements QueryRunner<T>
{ {
DataSource dataSource = queryPlus.getQuery().getDataSource(); DataSource dataSource = queryPlus.getQuery().getDataSource();
boolean forcePushDownNestedQuery = queryPlus.getQuery() boolean forcePushDownNestedQuery = queryPlus.getQuery()
.getContextBoolean( .context().getBoolean(
GroupByQueryConfig.CTX_KEY_FORCE_PUSH_DOWN_NESTED_QUERY, GroupByQueryConfig.CTX_KEY_FORCE_PUSH_DOWN_NESTED_QUERY,
false false
); );

View File

@ -450,7 +450,7 @@ public class GroupByQuery extends BaseQuery<ResultRow>
@JsonIgnore @JsonIgnore
public boolean getContextSortByDimsFirst() public boolean getContextSortByDimsFirst()
{ {
return getContextBoolean(CTX_KEY_SORT_BY_DIMS_FIRST, false); return context().getBoolean(CTX_KEY_SORT_BY_DIMS_FIRST, false);
} }
@JsonIgnore @JsonIgnore
@ -465,7 +465,7 @@ public class GroupByQuery extends BaseQuery<ResultRow>
@JsonIgnore @JsonIgnore
public boolean getApplyLimitPushDownFromContext() public boolean getApplyLimitPushDownFromContext()
{ {
return getContextBoolean(GroupByQueryConfig.CTX_KEY_APPLY_LIMIT_PUSH_DOWN, true); return context().getBoolean(GroupByQueryConfig.CTX_KEY_APPLY_LIMIT_PUSH_DOWN, true);
} }
@Override @Override
@ -487,7 +487,7 @@ public class GroupByQuery extends BaseQuery<ResultRow>
private boolean validateAndGetForceLimitPushDown() private boolean validateAndGetForceLimitPushDown()
{ {
final boolean forcePushDown = getContextBoolean(GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN, false); final boolean forcePushDown = context().getBoolean(GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN, false);
if (forcePushDown) { if (forcePushDown) {
if (!(limitSpec instanceof DefaultLimitSpec)) { if (!(limitSpec instanceof DefaultLimitSpec)) {
throw new IAE("When forcing limit push down, a limit spec must be provided."); throw new IAE("When forcing limit push down, a limit spec must be provided.");
@ -748,7 +748,7 @@ public class GroupByQuery extends BaseQuery<ResultRow>
@Nullable @Nullable
private DateTime computeUniversalTimestamp() private DateTime computeUniversalTimestamp()
{ {
final String timestampStringFromContext = getQueryContext().getAsString(CTX_KEY_FUDGE_TIMESTAMP, ""); final String timestampStringFromContext = context().getString(CTX_KEY_FUDGE_TIMESTAMP, "");
final Granularity granularity = getGranularity(); final Granularity granularity = getGranularity();
if (!timestampStringFromContext.isEmpty()) { if (!timestampStringFromContext.isEmpty()) {

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.druid.java.util.common.HumanReadableBytes; import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.groupby.strategy.GroupByStrategySelector; import org.apache.druid.query.groupby.strategy.GroupByStrategySelector;
import org.apache.druid.utils.JvmUtils; import org.apache.druid.utils.JvmUtils;
@ -335,25 +336,26 @@ public class GroupByQueryConfig
public GroupByQueryConfig withOverrides(final GroupByQuery query) public GroupByQueryConfig withOverrides(final GroupByQuery query)
{ {
final GroupByQueryConfig newConfig = new GroupByQueryConfig(); final GroupByQueryConfig newConfig = new GroupByQueryConfig();
newConfig.defaultStrategy = query.getQueryContext().getAsString(CTX_KEY_STRATEGY, getDefaultStrategy()); final QueryContext queryContext = query.context();
newConfig.singleThreaded = query.getQueryContext().getAsBoolean(CTX_KEY_IS_SINGLE_THREADED, isSingleThreaded()); newConfig.defaultStrategy = queryContext.getString(CTX_KEY_STRATEGY, getDefaultStrategy());
newConfig.singleThreaded = queryContext.getBoolean(CTX_KEY_IS_SINGLE_THREADED, isSingleThreaded());
newConfig.maxIntermediateRows = Math.min( newConfig.maxIntermediateRows = Math.min(
query.getQueryContext().getAsInt(CTX_KEY_MAX_INTERMEDIATE_ROWS, getMaxIntermediateRows()), queryContext.getInt(CTX_KEY_MAX_INTERMEDIATE_ROWS, getMaxIntermediateRows()),
getMaxIntermediateRows() getMaxIntermediateRows()
); );
newConfig.maxResults = Math.min( newConfig.maxResults = Math.min(
query.getQueryContext().getAsInt(CTX_KEY_MAX_RESULTS, getMaxResults()), queryContext.getInt(CTX_KEY_MAX_RESULTS, getMaxResults()),
getMaxResults() getMaxResults()
); );
newConfig.bufferGrouperMaxSize = Math.min( newConfig.bufferGrouperMaxSize = Math.min(
query.getQueryContext().getAsInt(CTX_KEY_BUFFER_GROUPER_MAX_SIZE, getBufferGrouperMaxSize()), queryContext.getInt(CTX_KEY_BUFFER_GROUPER_MAX_SIZE, getBufferGrouperMaxSize()),
getBufferGrouperMaxSize() getBufferGrouperMaxSize()
); );
newConfig.bufferGrouperMaxLoadFactor = query.getQueryContext().getAsFloat( newConfig.bufferGrouperMaxLoadFactor = queryContext.getFloat(
CTX_KEY_BUFFER_GROUPER_MAX_LOAD_FACTOR, CTX_KEY_BUFFER_GROUPER_MAX_LOAD_FACTOR,
getBufferGrouperMaxLoadFactor() getBufferGrouperMaxLoadFactor()
); );
newConfig.bufferGrouperInitialBuckets = query.getQueryContext().getAsInt( newConfig.bufferGrouperInitialBuckets = queryContext.getInt(
CTX_KEY_BUFFER_GROUPER_INITIAL_BUCKETS, CTX_KEY_BUFFER_GROUPER_INITIAL_BUCKETS,
getBufferGrouperInitialBuckets() getBufferGrouperInitialBuckets()
); );
@ -362,33 +364,33 @@ public class GroupByQueryConfig
// choose a default value lower than the max allowed when the context key is missing in the client query. // choose a default value lower than the max allowed when the context key is missing in the client query.
newConfig.maxOnDiskStorage = HumanReadableBytes.valueOf( newConfig.maxOnDiskStorage = HumanReadableBytes.valueOf(
Math.min( Math.min(
query.getContextAsHumanReadableBytes(CTX_KEY_MAX_ON_DISK_STORAGE, getDefaultOnDiskStorage()).getBytes(), queryContext.getHumanReadableBytes(CTX_KEY_MAX_ON_DISK_STORAGE, getDefaultOnDiskStorage()).getBytes(),
getMaxOnDiskStorage().getBytes() getMaxOnDiskStorage().getBytes()
) )
); );
newConfig.maxSelectorDictionarySize = maxSelectorDictionarySize; // No overrides newConfig.maxSelectorDictionarySize = maxSelectorDictionarySize; // No overrides
newConfig.maxMergingDictionarySize = maxMergingDictionarySize; // No overrides newConfig.maxMergingDictionarySize = maxMergingDictionarySize; // No overrides
newConfig.forcePushDownLimit = query.getContextBoolean(CTX_KEY_FORCE_LIMIT_PUSH_DOWN, isForcePushDownLimit()); newConfig.forcePushDownLimit = queryContext.getBoolean(CTX_KEY_FORCE_LIMIT_PUSH_DOWN, isForcePushDownLimit());
newConfig.applyLimitPushDownToSegment = query.getContextBoolean( newConfig.applyLimitPushDownToSegment = queryContext.getBoolean(
CTX_KEY_APPLY_LIMIT_PUSH_DOWN_TO_SEGMENT, CTX_KEY_APPLY_LIMIT_PUSH_DOWN_TO_SEGMENT,
isApplyLimitPushDownToSegment() isApplyLimitPushDownToSegment()
); );
newConfig.forceHashAggregation = query.getContextBoolean(CTX_KEY_FORCE_HASH_AGGREGATION, isForceHashAggregation()); newConfig.forceHashAggregation = queryContext.getBoolean(CTX_KEY_FORCE_HASH_AGGREGATION, isForceHashAggregation());
newConfig.forcePushDownNestedQuery = query.getContextBoolean( newConfig.forcePushDownNestedQuery = queryContext.getBoolean(
CTX_KEY_FORCE_PUSH_DOWN_NESTED_QUERY, CTX_KEY_FORCE_PUSH_DOWN_NESTED_QUERY,
isForcePushDownNestedQuery() isForcePushDownNestedQuery()
); );
newConfig.intermediateCombineDegree = query.getQueryContext().getAsInt( newConfig.intermediateCombineDegree = queryContext.getInt(
CTX_KEY_INTERMEDIATE_COMBINE_DEGREE, CTX_KEY_INTERMEDIATE_COMBINE_DEGREE,
getIntermediateCombineDegree() getIntermediateCombineDegree()
); );
newConfig.numParallelCombineThreads = query.getQueryContext().getAsInt( newConfig.numParallelCombineThreads = queryContext.getInt(
CTX_KEY_NUM_PARALLEL_COMBINE_THREADS, CTX_KEY_NUM_PARALLEL_COMBINE_THREADS,
getNumParallelCombineThreads() getNumParallelCombineThreads()
); );
newConfig.mergeThreadLocal = query.getContextBoolean(CTX_KEY_MERGE_THREAD_LOCAL, isMergeThreadLocal()); newConfig.mergeThreadLocal = queryContext.getBoolean(CTX_KEY_MERGE_THREAD_LOCAL, isMergeThreadLocal());
newConfig.vectorize = query.getContextBoolean(QueryContexts.VECTORIZE_KEY, isVectorize()); newConfig.vectorize = queryContext.getBoolean(QueryContexts.VECTORIZE_KEY, isVectorize());
newConfig.enableMultiValueUnnesting = query.getContextBoolean( newConfig.enableMultiValueUnnesting = queryContext.getBoolean(
CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING,
isMultiValueUnnestingEnabled() isMultiValueUnnestingEnabled()
); );

View File

@ -96,7 +96,7 @@ public class GroupByQueryEngine
"Null storage adapter found. Probably trying to issue a query against a segment being memory unmapped." "Null storage adapter found. Probably trying to issue a query against a segment being memory unmapped."
); );
} }
if (!query.getContextBoolean(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true)) { if (!query.context().getBoolean(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true)) {
throw new UOE( throw new UOE(
"GroupBy v1 does not support %s as false. Set %s to true or use groupBy v2", "GroupBy v1 does not support %s as false. Set %s to true or use groupBy v2",
GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING,

View File

@ -100,7 +100,7 @@ public class GroupByQueryHelper
); );
final IncrementalIndex index; final IncrementalIndex index;
final boolean sortResults = query.getContextBoolean(CTX_KEY_SORT_RESULTS, true); final boolean sortResults = query.context().getBoolean(CTX_KEY_SORT_RESULTS, true);
// All groupBy dimensions are strings, for now. // All groupBy dimensions are strings, for now.
final List<DimensionSchema> dimensionSchemas = new ArrayList<>(); final List<DimensionSchema> dimensionSchemas = new ArrayList<>();
@ -118,7 +118,7 @@ public class GroupByQueryHelper
final AppendableIndexBuilder indexBuilder; final AppendableIndexBuilder indexBuilder;
if (query.getContextBoolean("useOffheap", false)) { if (query.context().getBoolean("useOffheap", false)) {
throw new UnsupportedOperationException( throw new UnsupportedOperationException(
"The 'useOffheap' option is no longer available for groupBy v1. Please move to the newer groupBy engine, " "The 'useOffheap' option is no longer available for groupBy v1. Please move to the newer groupBy engine, "
+ "which always operates off-heap, by removing any custom 'druid.query.groupBy.defaultStrategy' runtime " + "which always operates off-heap, by removing any custom 'druid.query.groupBy.defaultStrategy' runtime "

View File

@ -45,7 +45,6 @@ import org.apache.druid.java.util.common.jackson.JacksonUtils;
import org.apache.druid.query.CacheStrategy; import org.apache.druid.query.CacheStrategy;
import org.apache.druid.query.DataSource; import org.apache.druid.query.DataSource;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
@ -118,7 +117,7 @@ public class GroupByQueryQueryToolChest extends QueryToolChest<ResultRow, GroupB
public QueryRunner<ResultRow> mergeResults(final QueryRunner<ResultRow> runner) public QueryRunner<ResultRow> mergeResults(final QueryRunner<ResultRow> runner)
{ {
return (queryPlus, responseContext) -> { return (queryPlus, responseContext) -> {
if (QueryContexts.isBySegment(queryPlus.getQuery())) { if (queryPlus.getQuery().context().isBySegment()) {
return runner.run(queryPlus, responseContext); return runner.run(queryPlus, responseContext);
} }
@ -304,7 +303,7 @@ public class GroupByQueryQueryToolChest extends QueryToolChest<ResultRow, GroupB
private Sequence<ResultRow> finalizeSubqueryResults(Sequence<ResultRow> subqueryResult, GroupByQuery subquery) private Sequence<ResultRow> finalizeSubqueryResults(Sequence<ResultRow> subqueryResult, GroupByQuery subquery)
{ {
final Sequence<ResultRow> finalizingResults; final Sequence<ResultRow> finalizingResults;
if (QueryContexts.isFinalize(subquery, false)) { if (subquery.context().isFinalize(false)) {
finalizingResults = new MappedSequence<>( finalizingResults = new MappedSequence<>(
subqueryResult, subqueryResult,
makePreComputeManipulatorFn( makePreComputeManipulatorFn(
@ -321,7 +320,7 @@ public class GroupByQueryQueryToolChest extends QueryToolChest<ResultRow, GroupB
public static boolean isNestedQueryPushDown(GroupByQuery q, GroupByStrategy strategy) public static boolean isNestedQueryPushDown(GroupByQuery q, GroupByStrategy strategy)
{ {
return q.getDataSource() instanceof QueryDataSource return q.getDataSource() instanceof QueryDataSource
&& q.getContextBoolean(GroupByQueryConfig.CTX_KEY_FORCE_PUSH_DOWN_NESTED_QUERY, false) && q.context().getBoolean(GroupByQueryConfig.CTX_KEY_FORCE_PUSH_DOWN_NESTED_QUERY, false)
&& q.getSubtotalsSpec() == null && q.getSubtotalsSpec() == null
&& strategy.supportsNestedQueryPushDown(); && strategy.supportsNestedQueryPushDown();
} }
@ -418,7 +417,7 @@ public class GroupByQueryQueryToolChest extends QueryToolChest<ResultRow, GroupB
@Override @Override
public ObjectMapper decorateObjectMapper(final ObjectMapper objectMapper, final GroupByQuery query) public ObjectMapper decorateObjectMapper(final ObjectMapper objectMapper, final GroupByQuery query)
{ {
final boolean resultAsArray = query.getContextBoolean(GroupByQueryConfig.CTX_KEY_ARRAY_RESULT_ROWS, false); final boolean resultAsArray = query.context().getBoolean(GroupByQueryConfig.CTX_KEY_ARRAY_RESULT_ROWS, false);
if (resultAsArray && !queryConfig.isIntermediateResultAsMapCompat()) { if (resultAsArray && !queryConfig.isIntermediateResultAsMapCompat()) {
// We can assume ResultRow are serialized and deserialized as arrays. No need for special decoration, // We can assume ResultRow are serialized and deserialized as arrays. No need for special decoration,

View File

@ -45,7 +45,7 @@ import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.AbstractPrioritizedQueryRunnerCallable; import org.apache.druid.query.AbstractPrioritizedQueryRunnerCallable;
import org.apache.druid.query.ChainedExecutionQueryRunner; import org.apache.druid.query.ChainedExecutionQueryRunner;
import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.QueryProcessingPool;
@ -134,7 +134,7 @@ public class GroupByMergingQueryRunnerV2 implements QueryRunner<ResultRow>
// merge buffer, otherwise the query will allocate too many merge buffers. This is potentially sub-optimal as it // merge buffer, otherwise the query will allocate too many merge buffers. This is potentially sub-optimal as it
// will involve materializing the results for each sink before starting to feed them into the outer merge buffer. // will involve materializing the results for each sink before starting to feed them into the outer merge buffer.
// I'm not sure of a better way to do this without tweaking how realtime servers do queries. // I'm not sure of a better way to do this without tweaking how realtime servers do queries.
final boolean forceChainedExecution = query.getContextBoolean( final boolean forceChainedExecution = query.context().getBoolean(
CTX_KEY_MERGE_RUNNERS_USING_CHAINED_EXECUTION, CTX_KEY_MERGE_RUNNERS_USING_CHAINED_EXECUTION,
false false
); );
@ -144,7 +144,8 @@ public class GroupByMergingQueryRunnerV2 implements QueryRunner<ResultRow>
) )
.withoutThreadUnsafeState(); .withoutThreadUnsafeState();
if (QueryContexts.isBySegment(query) || forceChainedExecution) { final QueryContext queryContext = query.context();
if (queryContext.isBySegment() || forceChainedExecution) {
ChainedExecutionQueryRunner<ResultRow> runner = new ChainedExecutionQueryRunner<>(queryProcessingPool, queryWatcher, queryables); ChainedExecutionQueryRunner<ResultRow> runner = new ChainedExecutionQueryRunner<>(queryProcessingPool, queryWatcher, queryables);
return runner.run(queryPlusForRunners, responseContext); return runner.run(queryPlusForRunners, responseContext);
} }
@ -156,12 +157,12 @@ public class GroupByMergingQueryRunnerV2 implements QueryRunner<ResultRow>
StringUtils.format("druid-groupBy-%s_%s", UUID.randomUUID(), query.getId()) StringUtils.format("druid-groupBy-%s_%s", UUID.randomUUID(), query.getId())
); );
final int priority = QueryContexts.getPriority(query); final int priority = queryContext.getPriority();
// Figure out timeoutAt time now, so we can apply the timeout to both the mergeBufferPool.take and the actual // Figure out timeoutAt time now, so we can apply the timeout to both the mergeBufferPool.take and the actual
// query processing together. // query processing together.
final long queryTimeout = QueryContexts.getTimeout(query); final long queryTimeout = queryContext.getTimeout();
final boolean hasTimeout = QueryContexts.hasTimeout(query); final boolean hasTimeout = queryContext.hasTimeout();
final long timeoutAt = System.currentTimeMillis() + queryTimeout; final long timeoutAt = System.currentTimeMillis() + queryTimeout;
return new BaseSequence<>( return new BaseSequence<>(

View File

@ -34,7 +34,6 @@ import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.ColumnSelectorPlus; import org.apache.druid.query.ColumnSelectorPlus;
import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.aggregation.AggregatorAdapters; import org.apache.druid.query.aggregation.AggregatorAdapters;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.dimension.ColumnSelectorStrategyFactory; import org.apache.druid.query.dimension.ColumnSelectorStrategyFactory;
@ -77,6 +76,7 @@ import org.joda.time.DateTime;
import org.joda.time.Interval; import org.joda.time.Interval;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.io.Closeable; import java.io.Closeable;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.Iterator; import java.util.Iterator;
@ -141,7 +141,7 @@ public class GroupByQueryEngineV2
try { try {
final String fudgeTimestampString = NullHandling.emptyToNullIfNeeded( final String fudgeTimestampString = NullHandling.emptyToNullIfNeeded(
query.getQueryContext().getAsString(GroupByStrategyV2.CTX_KEY_FUDGE_TIMESTAMP) query.context().getString(GroupByStrategyV2.CTX_KEY_FUDGE_TIMESTAMP)
); );
final DateTime fudgeTimestamp = fudgeTimestampString == null final DateTime fudgeTimestamp = fudgeTimestampString == null
@ -151,7 +151,7 @@ public class GroupByQueryEngineV2
final Filter filter = Filters.convertToCNFFromQueryContext(query, Filters.toFilter(query.getFilter())); final Filter filter = Filters.convertToCNFFromQueryContext(query, Filters.toFilter(query.getFilter()));
final Interval interval = Iterables.getOnlyElement(query.getIntervals()); final Interval interval = Iterables.getOnlyElement(query.getIntervals());
final boolean doVectorize = QueryContexts.getVectorize(query).shouldVectorize( final boolean doVectorize = query.context().getVectorize().shouldVectorize(
VectorGroupByEngine.canVectorize(query, storageAdapter, filter) VectorGroupByEngine.canVectorize(query, storageAdapter, filter)
); );
@ -496,7 +496,7 @@ public class GroupByQueryEngineV2
// Time is the same for every row in the cursor // Time is the same for every row in the cursor
this.timestamp = fudgeTimestamp != null ? fudgeTimestamp : cursor.getTime(); this.timestamp = fudgeTimestamp != null ? fudgeTimestamp : cursor.getTime();
this.allSingleValueDims = allSingleValueDims; this.allSingleValueDims = allSingleValueDims;
this.allowMultiValueGrouping = query.getContextBoolean( this.allowMultiValueGrouping = query.context().getBoolean(
GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING,
true true
); );

View File

@ -28,7 +28,6 @@ import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.parsers.CloseableIterator; import org.apache.druid.java.util.common.parsers.CloseableIterator;
import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.aggregation.AggregatorAdapters; import org.apache.druid.query.aggregation.AggregatorAdapters;
import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.query.filter.Filter; import org.apache.druid.query.filter.Filter;
@ -56,6 +55,7 @@ import org.joda.time.DateTime;
import org.joda.time.Interval; import org.joda.time.Interval;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.Collections; import java.util.Collections;
@ -150,7 +150,7 @@ public class VectorGroupByEngine
interval, interval,
query.getVirtualColumns(), query.getVirtualColumns(),
false, false,
QueryContexts.getVectorSize(query), query.context().getVectorSize(),
groupByQueryMetrics groupByQueryMetrics
); );

View File

@ -37,6 +37,7 @@ import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.guava.TopNSequence; import org.apache.druid.java.util.common.guava.TopNSequence;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator; import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.dimension.DimensionSpec;
@ -232,9 +233,11 @@ public class DefaultLimitSpec implements LimitSpec
} }
if (!sortingNeeded) { if (!sortingNeeded) {
String timestampField = query.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD); final QueryContext queryContext = query.context();
String timestampField = queryContext.getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD);
if (timestampField != null && !timestampField.isEmpty()) { if (timestampField != null && !timestampField.isEmpty()) {
int timestampResultFieldIndex = query.getQueryContext().getAsInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX); // Will NPE if the key is not set
int timestampResultFieldIndex = queryContext.getInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX);
sortingNeeded = query.getContextSortByDimsFirst() sortingNeeded = query.getContextSortByDimsFirst()
? timestampResultFieldIndex != query.getDimensions().size() - 1 ? timestampResultFieldIndex != query.getDimensions().size() - 1
: timestampResultFieldIndex != 0; : timestampResultFieldIndex != 0;

View File

@ -91,7 +91,7 @@ public class GroupByStrategyV1 implements GroupByStrategy
@Override @Override
public boolean doMergeResults(final GroupByQuery query) public boolean doMergeResults(final GroupByQuery query)
{ {
return query.getContextBoolean(GroupByQueryQueryToolChest.GROUP_BY_MERGE_KEY, true); return query.context().getBoolean(GroupByQueryQueryToolChest.GROUP_BY_MERGE_KEY, true);
} }
@Override @Override

View File

@ -44,6 +44,7 @@ import org.apache.druid.query.DataSource;
import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryCapacityExceededException; import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
@ -132,8 +133,9 @@ public class GroupByStrategyV2 implements GroupByStrategy
return new GroupByQueryResource(); return new GroupByQueryResource();
} else { } else {
final List<ReferenceCountingResourceHolder<ByteBuffer>> mergeBufferHolders; final List<ReferenceCountingResourceHolder<ByteBuffer>> mergeBufferHolders;
if (QueryContexts.hasTimeout(query)) { final QueryContext context = query.context();
mergeBufferHolders = mergeBufferPool.takeBatch(requiredMergeBufferNum, QueryContexts.getTimeout(query)); if (context.hasTimeout()) {
mergeBufferHolders = mergeBufferPool.takeBatch(requiredMergeBufferNum, context.getTimeout());
} else { } else {
mergeBufferHolders = mergeBufferPool.takeBatch(requiredMergeBufferNum); mergeBufferHolders = mergeBufferPool.takeBatch(requiredMergeBufferNum);
} }
@ -221,9 +223,10 @@ public class GroupByStrategyV2 implements GroupByStrategy
Granularity granularity = query.getGranularity(); Granularity granularity = query.getGranularity();
List<DimensionSpec> dimensionSpecs = query.getDimensions(); List<DimensionSpec> dimensionSpecs = query.getDimensions();
// the CTX_TIMESTAMP_RESULT_FIELD is set in DruidQuery.java // the CTX_TIMESTAMP_RESULT_FIELD is set in DruidQuery.java
final String timestampResultField = query.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD); final QueryContext queryContext = query.context();
final String timestampResultField = queryContext.getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD);
final boolean hasTimestampResultField = (timestampResultField != null && !timestampResultField.isEmpty()) final boolean hasTimestampResultField = (timestampResultField != null && !timestampResultField.isEmpty())
&& query.getContextBoolean(CTX_KEY_OUTERMOST, true) && queryContext.getBoolean(CTX_KEY_OUTERMOST, true)
&& !query.isApplyLimitPushDown(); && !query.isApplyLimitPushDown();
int timestampResultFieldIndex = 0; int timestampResultFieldIndex = 0;
if (hasTimestampResultField) { if (hasTimestampResultField) {
@ -249,7 +252,7 @@ public class GroupByStrategyV2 implements GroupByStrategy
// the granularity and dimensions are slightly different. // the granularity and dimensions are slightly different.
// now, part of the query plan logic is handled in GroupByStrategyV2, not only in DruidQuery.toGroupByQuery() // now, part of the query plan logic is handled in GroupByStrategyV2, not only in DruidQuery.toGroupByQuery()
final Granularity timestampResultFieldGranularity final Granularity timestampResultFieldGranularity
= query.getContextValue(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_GRANULARITY); = queryContext.getGranularity(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_GRANULARITY);
dimensionSpecs = dimensionSpecs =
query.getDimensions() query.getDimensions()
.stream() .stream()
@ -258,7 +261,7 @@ public class GroupByStrategyV2 implements GroupByStrategy
granularity = timestampResultFieldGranularity; granularity = timestampResultFieldGranularity;
// when timestampResultField is the last dimension, should set sortByDimsFirst=true, // when timestampResultField is the last dimension, should set sortByDimsFirst=true,
// otherwise the downstream is sorted by row's timestamp first which makes the final ordering not as expected // otherwise the downstream is sorted by row's timestamp first which makes the final ordering not as expected
timestampResultFieldIndex = query.getQueryContext().getAsInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX); timestampResultFieldIndex = queryContext.getInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX);
if (!query.getContextSortByDimsFirst() && timestampResultFieldIndex == query.getDimensions().size() - 1) { if (!query.getContextSortByDimsFirst() && timestampResultFieldIndex == query.getDimensions().size() - 1) {
context.put(GroupByQuery.CTX_KEY_SORT_BY_DIMS_FIRST, true); context.put(GroupByQuery.CTX_KEY_SORT_BY_DIMS_FIRST, true);
} }
@ -312,8 +315,8 @@ public class GroupByStrategyV2 implements GroupByStrategy
// Apply postaggregators if this is the outermost mergeResults (CTX_KEY_OUTERMOST) and we are not executing a // Apply postaggregators if this is the outermost mergeResults (CTX_KEY_OUTERMOST) and we are not executing a
// pushed-down subquery (CTX_KEY_EXECUTING_NESTED_QUERY). // pushed-down subquery (CTX_KEY_EXECUTING_NESTED_QUERY).
if (!query.getContextBoolean(CTX_KEY_OUTERMOST, true) if (!queryContext.getBoolean(CTX_KEY_OUTERMOST, true)
|| query.getContextBoolean(GroupByQueryConfig.CTX_KEY_EXECUTING_NESTED_QUERY, false)) { || queryContext.getBoolean(GroupByQueryConfig.CTX_KEY_EXECUTING_NESTED_QUERY, false)) {
return mergedResults; return mergedResults;
} else if (query.getPostAggregatorSpecs().isEmpty()) { } else if (query.getPostAggregatorSpecs().isEmpty()) {
if (!hasTimestampResultField) { if (!hasTimestampResultField) {
@ -405,7 +408,7 @@ public class GroupByStrategyV2 implements GroupByStrategy
public Sequence<ResultRow> applyPostProcessing(Sequence<ResultRow> results, GroupByQuery query) public Sequence<ResultRow> applyPostProcessing(Sequence<ResultRow> results, GroupByQuery query)
{ {
// Don't apply limit here for inner results, that will be pushed down to the BufferHashGrouper // Don't apply limit here for inner results, that will be pushed down to the BufferHashGrouper
if (query.getContextBoolean(CTX_KEY_OUTERMOST, true)) { if (query.context().getBoolean(CTX_KEY_OUTERMOST, true)) {
return query.postProcess(results); return query.postProcess(results);
} else { } else {
return results; return results;

View File

@ -31,7 +31,7 @@ import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.AbstractPrioritizedQueryRunnerCallable; import org.apache.druid.query.AbstractPrioritizedQueryRunnerCallable;
import org.apache.druid.query.ConcatQueryRunner; import org.apache.druid.query.ConcatQueryRunner;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.QueryProcessingPool;
@ -205,7 +205,7 @@ public class SegmentMetadataQueryRunnerFactory implements QueryRunnerFactory<Seg
) )
{ {
final Query<SegmentAnalysis> query = queryPlus.getQuery(); final Query<SegmentAnalysis> query = queryPlus.getQuery();
final int priority = QueryContexts.getPriority(query); final int priority = query.context().getPriority();
final QueryPlus<SegmentAnalysis> threadSafeQueryPlus = queryPlus.withoutThreadUnsafeState(); final QueryPlus<SegmentAnalysis> threadSafeQueryPlus = queryPlus.withoutThreadUnsafeState();
ListenableFuture<Sequence<SegmentAnalysis>> future = queryProcessingPool.submitRunnerTask( ListenableFuture<Sequence<SegmentAnalysis>> future = queryProcessingPool.submitRunnerTask(
new AbstractPrioritizedQueryRunnerCallable<Sequence<SegmentAnalysis>, SegmentAnalysis>(priority, input) new AbstractPrioritizedQueryRunnerCallable<Sequence<SegmentAnalysis>, SegmentAnalysis>(priority, input)
@ -219,8 +219,9 @@ public class SegmentMetadataQueryRunnerFactory implements QueryRunnerFactory<Seg
); );
try { try {
queryWatcher.registerQueryFuture(query, future); queryWatcher.registerQueryFuture(query, future);
if (QueryContexts.hasTimeout(query)) { final QueryContext context = query.context();
return future.get(QueryContexts.getTimeout(query), TimeUnit.MILLISECONDS); if (context.hasTimeout()) {
return future.get(context.getTimeout(), TimeUnit.MILLISECONDS);
} else { } else {
return future.get(); return future.get();
} }

View File

@ -264,7 +264,7 @@ public class ScanQuery extends BaseQuery<ScanResultValue>
private Integer validateAndGetMaxRowsQueuedForOrdering() private Integer validateAndGetMaxRowsQueuedForOrdering()
{ {
final Integer maxRowsQueuedForOrdering = final Integer maxRowsQueuedForOrdering =
getQueryContext().getAsInt(ScanQueryConfig.CTX_KEY_MAX_ROWS_QUEUED_FOR_ORDERING); context().getInt(ScanQueryConfig.CTX_KEY_MAX_ROWS_QUEUED_FOR_ORDERING);
Preconditions.checkArgument( Preconditions.checkArgument(
maxRowsQueuedForOrdering == null || maxRowsQueuedForOrdering > 0, maxRowsQueuedForOrdering == null || maxRowsQueuedForOrdering > 0,
"maxRowsQueuedForOrdering must be greater than 0" "maxRowsQueuedForOrdering must be greater than 0"
@ -275,7 +275,7 @@ public class ScanQuery extends BaseQuery<ScanResultValue>
private Integer validateAndGetMaxSegmentPartitionsOrderedInMemory() private Integer validateAndGetMaxSegmentPartitionsOrderedInMemory()
{ {
final Integer maxSegmentPartitionsOrderedInMemory = final Integer maxSegmentPartitionsOrderedInMemory =
getQueryContext().getAsInt(ScanQueryConfig.CTX_KEY_MAX_SEGMENT_PARTITIONS_FOR_ORDERING); context().getInt(ScanQueryConfig.CTX_KEY_MAX_SEGMENT_PARTITIONS_FOR_ORDERING);
Preconditions.checkArgument( Preconditions.checkArgument(
maxSegmentPartitionsOrderedInMemory == null || maxSegmentPartitionsOrderedInMemory > 0, maxSegmentPartitionsOrderedInMemory == null || maxSegmentPartitionsOrderedInMemory > 0,
"maxRowsQueuedForOrdering must be greater than 0" "maxRowsQueuedForOrdering must be greater than 0"

View File

@ -31,7 +31,6 @@ import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.BaseSequence; import org.apache.druid.java.util.common.guava.BaseSequence;
import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryMetrics; import org.apache.druid.query.QueryMetrics;
import org.apache.druid.query.QueryTimeoutException; import org.apache.druid.query.QueryTimeoutException;
import org.apache.druid.query.context.ResponseContext; import org.apache.druid.query.context.ResponseContext;
@ -78,7 +77,7 @@ public class ScanQueryEngine
if (numScannedRows != null && numScannedRows >= query.getScanRowsLimit() && query.getTimeOrder().equals(ScanQuery.Order.NONE)) { if (numScannedRows != null && numScannedRows >= query.getScanRowsLimit() && query.getTimeOrder().equals(ScanQuery.Order.NONE)) {
return Sequences.empty(); return Sequences.empty();
} }
final boolean hasTimeout = QueryContexts.hasTimeout(query); final boolean hasTimeout = query.context().hasTimeout();
final Long timeoutAt = responseContext.getTimeoutTime(); final Long timeoutAt = responseContext.getTimeoutTime();
final StorageAdapter adapter = segment.asStorageAdapter(); final StorageAdapter adapter = segment.asStorageAdapter();

View File

@ -99,7 +99,7 @@ public class ScanQueryLimitRowIterator implements CloseableIterator<ScanResultVa
// We want to perform multi-event ScanResultValue limiting if we are not time-ordering or are at the // We want to perform multi-event ScanResultValue limiting if we are not time-ordering or are at the
// inner-level if we are time-ordering // inner-level if we are time-ordering
if (query.getTimeOrder() == ScanQuery.Order.NONE || if (query.getTimeOrder() == ScanQuery.Order.NONE ||
!query.getContextBoolean(ScanQuery.CTX_KEY_OUTERMOST, true)) { !query.context().getBoolean(ScanQuery.CTX_KEY_OUTERMOST, true)) {
ScanResultValue batch = yielder.get(); ScanResultValue batch = yielder.get();
List events = (List) batch.getEvents(); List events = (List) batch.getEvents();
if (events.size() <= limit - count) { if (events.size() <= limit - count) {

View File

@ -33,7 +33,6 @@ import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.java.util.common.guava.Yielder;
import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.java.util.common.guava.Yielders;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.QueryProcessingPool;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
@ -94,7 +93,7 @@ public class ScanQueryRunnerFactory implements QueryRunnerFactory<ScanResultValu
// Note: this variable is effective only when queryContext has a timeout. // Note: this variable is effective only when queryContext has a timeout.
// See the comment of ResponseContext.Key.TIMEOUT_AT. // See the comment of ResponseContext.Key.TIMEOUT_AT.
final long timeoutAt = System.currentTimeMillis() + QueryContexts.getTimeout(queryPlus.getQuery()); final long timeoutAt = System.currentTimeMillis() + queryPlus.getQuery().context().getTimeout();
responseContext.putTimeoutTime(timeoutAt); responseContext.putTimeoutTime(timeoutAt);
if (query.getTimeOrder().equals(ScanQuery.Order.NONE)) { if (query.getTimeOrder().equals(ScanQuery.Order.NONE)) {

View File

@ -55,7 +55,7 @@ public class SearchQueryConfig
{ {
final SearchQueryConfig newConfig = new SearchQueryConfig(); final SearchQueryConfig newConfig = new SearchQueryConfig();
newConfig.maxSearchLimit = query.getLimit(); newConfig.maxSearchLimit = query.getLimit();
newConfig.searchStrategy = query.getQueryContext().getAsString(CTX_KEY_STRATEGY, searchStrategy); newConfig.searchStrategy = query.context().getString(CTX_KEY_STRATEGY, searchStrategy);
return newConfig; return newConfig;
} }
} }

View File

@ -34,7 +34,6 @@ import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.query.CacheStrategy; import org.apache.druid.query.CacheStrategy;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryToolChest; import org.apache.druid.query.QueryToolChest;
@ -329,7 +328,7 @@ public class SearchQueryQueryToolChest extends QueryToolChest<Result<SearchResul
return runner.run(queryPlus, responseContext); return runner.run(queryPlus, responseContext);
} }
final boolean isBySegment = QueryContexts.isBySegment(query); final boolean isBySegment = query.context().isBySegment();
return Sequences.map( return Sequences.map(
runner.run(queryPlus.withQuery(query.withLimit(config.getMaxSearchLimit())), responseContext), runner.run(queryPlus.withQuery(query.withLimit(config.getMaxSearchLimit())), responseContext),

View File

@ -24,7 +24,6 @@ import com.google.common.collect.Ordering;
import org.apache.druid.java.util.common.granularity.Granularity; import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.DataSource; import org.apache.druid.query.DataSource;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QuerySegmentWalker; import org.apache.druid.query.QuerySegmentWalker;
import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.filter.DimFilter;
@ -34,6 +33,7 @@ import org.joda.time.Duration;
import org.joda.time.Interval; import org.joda.time.Interval;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -110,12 +110,6 @@ public class SelectQuery implements Query<Object>
throw new RuntimeException(REMOVED_ERROR_MESSAGE); throw new RuntimeException(REMOVED_ERROR_MESSAGE);
} }
@Override
public QueryContext getQueryContext()
{
throw new RuntimeException(REMOVED_ERROR_MESSAGE);
}
@Override @Override
public boolean isDescending() public boolean isDescending()
{ {

View File

@ -68,7 +68,7 @@ public class SpecificSegmentQueryRunner<T> implements QueryRunner<T>
) )
); );
final boolean setName = input.getQuery().getContextBoolean(CTX_SET_THREAD_NAME, true); final boolean setName = input.getQuery().context().getBoolean(CTX_SET_THREAD_NAME, true);
final Query<T> query = queryPlus.getQuery(); final Query<T> query = queryPlus.getQuery();

View File

@ -35,6 +35,7 @@ import org.apache.druid.query.CacheStrategy;
import org.apache.druid.query.DefaultGenericQueryMetricsFactory; import org.apache.druid.query.DefaultGenericQueryMetricsFactory;
import org.apache.druid.query.GenericQueryMetricsFactory; import org.apache.druid.query.GenericQueryMetricsFactory;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryMetrics; import org.apache.druid.query.QueryMetrics;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
@ -232,9 +233,10 @@ public class TimeBoundaryQueryQueryToolChest
{ {
if (query.isMinTime() || query.isMaxTime()) { if (query.isMinTime() || query.isMaxTime()) {
RowSignature.Builder builder = RowSignature.builder(); RowSignature.Builder builder = RowSignature.builder();
final QueryContext queryContext = query.context();
String outputName = query.isMinTime() ? String outputName = query.isMinTime() ?
query.getQueryContext().getAsString(TimeBoundaryQuery.MIN_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MIN_TIME) : queryContext.getString(TimeBoundaryQuery.MIN_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MIN_TIME) :
query.getQueryContext().getAsString(TimeBoundaryQuery.MAX_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MAX_TIME); queryContext.getString(TimeBoundaryQuery.MAX_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MAX_TIME);
return builder.add(outputName, ColumnType.LONG).build(); return builder.add(outputName, ColumnType.LONG).build();
} }
return super.resultArraySignature(query); return super.resultArraySignature(query);

View File

@ -154,17 +154,17 @@ public class TimeseriesQuery extends BaseQuery<Result<TimeseriesResultValue>>
public boolean isGrandTotal() public boolean isGrandTotal()
{ {
return getContextBoolean(CTX_GRAND_TOTAL, false); return context().getBoolean(CTX_GRAND_TOTAL, false);
} }
public String getTimestampResultField() public String getTimestampResultField()
{ {
return getQueryContext().getAsString(CTX_TIMESTAMP_RESULT_FIELD); return context().getString(CTX_TIMESTAMP_RESULT_FIELD);
} }
public boolean isSkipEmptyBuckets() public boolean isSkipEmptyBuckets()
{ {
return getContextBoolean(SKIP_EMPTY_BUCKETS, false); return context().getBoolean(SKIP_EMPTY_BUCKETS, false);
} }
@Nullable @Nullable

View File

@ -31,7 +31,6 @@ import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryRunnerHelper; import org.apache.druid.query.QueryRunnerHelper;
import org.apache.druid.query.Result; import org.apache.druid.query.Result;
import org.apache.druid.query.aggregation.Aggregator; import org.apache.druid.query.aggregation.Aggregator;
@ -101,7 +100,7 @@ public class TimeseriesQueryEngine
final ColumnInspector inspector = query.getVirtualColumns().wrapInspector(adapter); final ColumnInspector inspector = query.getVirtualColumns().wrapInspector(adapter);
final boolean doVectorize = QueryContexts.getVectorize(query).shouldVectorize( final boolean doVectorize = query.context().getVectorize().shouldVectorize(
adapter.canVectorize(filter, query.getVirtualColumns(), descending) adapter.canVectorize(filter, query.getVirtualColumns(), descending)
&& VirtualColumns.shouldVectorize(query, query.getVirtualColumns(), adapter) && VirtualColumns.shouldVectorize(query, query.getVirtualColumns(), adapter)
&& query.getAggregatorSpecs().stream().allMatch(aggregatorFactory -> aggregatorFactory.canVectorize(inspector)) && query.getAggregatorSpecs().stream().allMatch(aggregatorFactory -> aggregatorFactory.canVectorize(inspector))
@ -141,7 +140,7 @@ public class TimeseriesQueryEngine
queryInterval, queryInterval,
query.getVirtualColumns(), query.getVirtualColumns(),
descending, descending,
QueryContexts.getVectorSize(query), query.context().getVectorSize(),
timeseriesQueryMetrics timeseriesQueryMetrics
); );

View File

@ -37,7 +37,6 @@ import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.query.CacheStrategy; import org.apache.druid.query.CacheStrategy;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryToolChest; import org.apache.druid.query.QueryToolChest;
@ -147,7 +146,7 @@ public class TimeseriesQueryQueryToolChest extends QueryToolChest<Result<Timeser
!query.isSkipEmptyBuckets() && !query.isSkipEmptyBuckets() &&
// Returns empty sequence if bySegment is set because bySegment results are mostly used for // Returns empty sequence if bySegment is set because bySegment results are mostly used for
// caching in historicals or debugging where the exact results are preferred. // caching in historicals or debugging where the exact results are preferred.
!QueryContexts.isBySegment(query)) { !query.context().isBySegment()) {
// Usally it is NOT Okay to materialize results via toList(), but Granularity is ALL thus // Usally it is NOT Okay to materialize results via toList(), but Granularity is ALL thus
// we have only one record. // we have only one record.
final List<Result<TimeseriesResultValue>> val = baseResults.toList(); final List<Result<TimeseriesResultValue>> val = baseResults.toList();

View File

@ -138,7 +138,7 @@ public class TopNQueryEngine
// if sorted by dimension we should aggregate all metrics in a single pass, use the regular pooled algorithm for // if sorted by dimension we should aggregate all metrics in a single pass, use the regular pooled algorithm for
// this // this
topNAlgorithm = new PooledTopNAlgorithm(adapter, query, bufferPool); topNAlgorithm = new PooledTopNAlgorithm(adapter, query, bufferPool);
} else if (selector.isAggregateTopNMetricFirst() || query.getContextBoolean("doAggregateTopNMetricFirst", false)) { } else if (selector.isAggregateTopNMetricFirst() || query.context().getBoolean("doAggregateTopNMetricFirst", false)) {
// for high cardinality dimensions with larger result sets we aggregate with only the ordering aggregation to // for high cardinality dimensions with larger result sets we aggregate with only the ordering aggregation to
// compute the first 'n' values, and then for the rest of the metrics but for only the 'n' values // compute the first 'n' values, and then for the rest of the metrics but for only the 'n' values
topNAlgorithm = new AggregateTopNMetricFirstAlgorithm(adapter, query, bufferPool); topNAlgorithm = new AggregateTopNMetricFirstAlgorithm(adapter, query, bufferPool);

View File

@ -574,12 +574,12 @@ public class TopNQueryQueryToolChest extends QueryToolChest<Result<TopNResultVal
} }
final TopNQuery query = (TopNQuery) input; final TopNQuery query = (TopNQuery) input;
final int minTopNThreshold = query.getQueryContext().getAsInt("minTopNThreshold", config.getMinTopNThreshold()); final int minTopNThreshold = query.context().getInt(QueryContexts.MIN_TOP_N_THRESHOLD, config.getMinTopNThreshold());
if (query.getThreshold() > minTopNThreshold) { if (query.getThreshold() > minTopNThreshold) {
return runner.run(queryPlus, responseContext); return runner.run(queryPlus, responseContext);
} }
final boolean isBySegment = QueryContexts.isBySegment(query); final boolean isBySegment = query.context().isBySegment();
return Sequences.map( return Sequences.map(
runner.run(queryPlus.withQuery(query.withThreshold(minTopNThreshold)), responseContext), runner.run(queryPlus.withQuery(query.withThreshold(minTopNThreshold)), responseContext),

View File

@ -31,7 +31,6 @@ import org.apache.druid.java.util.common.Cacheable;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.Pair;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnCapabilities;
@ -48,6 +47,7 @@ import org.apache.druid.segment.virtual.VirtualizedColumnInspector;
import org.apache.druid.segment.virtual.VirtualizedColumnSelectorFactory; import org.apache.druid.segment.virtual.VirtualizedColumnSelectorFactory;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -120,7 +120,7 @@ public class VirtualColumns implements Cacheable
public static boolean shouldVectorize(Query<?> query, VirtualColumns virtualColumns, ColumnInspector inspector) public static boolean shouldVectorize(Query<?> query, VirtualColumns virtualColumns, ColumnInspector inspector)
{ {
if (virtualColumns.getVirtualColumns().length > 0) { if (virtualColumns.getVirtualColumns().length > 0) {
return QueryContexts.getVectorizeVirtualColumns(query).shouldVectorize(virtualColumns.canVectorize(inspector)); return query.context().getVectorizeVirtualColumns().shouldVectorize(virtualColumns.canVectorize(inspector));
} else { } else {
return true; return true;
} }

View File

@ -215,7 +215,7 @@ public class Filters
if (filter == null) { if (filter == null) {
return null; return null;
} }
boolean useCNF = query.getContextBoolean(QueryContexts.USE_FILTER_CNF_KEY, QueryContexts.DEFAULT_USE_FILTER_CNF); boolean useCNF = query.context().getBoolean(QueryContexts.USE_FILTER_CNF_KEY, QueryContexts.DEFAULT_USE_FILTER_CNF);
try { try {
return useCNF ? Filters.toCnf(filter) : filter; return useCNF ? Filters.toCnf(filter) : filter;
} }

View File

@ -20,7 +20,7 @@
package org.apache.druid.segment.join.filter.rewrite; package org.apache.druid.segment.join.filter.rewrite;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContext;
import java.util.Objects; import java.util.Objects;
@ -76,12 +76,13 @@ public class JoinFilterRewriteConfig
public static JoinFilterRewriteConfig forQuery(final Query<?> query) public static JoinFilterRewriteConfig forQuery(final Query<?> query)
{ {
QueryContext context = query.context();
return new JoinFilterRewriteConfig( return new JoinFilterRewriteConfig(
QueryContexts.getEnableJoinFilterPushDown(query), context.getEnableJoinFilterPushDown(),
QueryContexts.getEnableJoinFilterRewrite(query), context.getEnableJoinFilterRewrite(),
QueryContexts.getEnableJoinFilterRewriteValueColumnFilters(query), context.getEnableJoinFilterRewriteValueColumnFilters(),
QueryContexts.getEnableRewriteJoinToFilter(query), context.getEnableRewriteJoinToFilter(),
QueryContexts.getJoinFilterRewriteMaxSize(query) context.getJoinFilterRewriteMaxSize()
); );
} }

View File

@ -19,31 +19,45 @@
package org.apache.druid.query; package org.apache.druid.query;
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.exc.MismatchedInputException;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Ordering; import com.google.common.collect.Ordering;
import nl.jqno.equalsverifier.EqualsVerifier; import nl.jqno.equalsverifier.EqualsVerifier;
import nl.jqno.equalsverifier.Warning; import nl.jqno.equalsverifier.Warning;
import org.apache.druid.java.util.common.HumanReadableBytes; import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity; import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.spec.QuerySegmentSpec; import org.apache.druid.query.spec.QuerySegmentSpec;
import org.apache.druid.segment.DimensionHandlerUtils;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.joda.time.Duration; import org.joda.time.Duration;
import org.joda.time.Interval; import org.joda.time.Interval;
import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
public class QueryContextTest public class QueryContextTest
{ {
private static final ObjectMapper JSON_MAPPER = new ObjectMapper();
@Test @Test
public void testEquals() public void testEquals()
{ {
@ -51,63 +65,83 @@ public class QueryContextTest
.suppress(Warning.NONFINAL_FIELDS, Warning.ALL_FIELDS_SHOULD_BE_USED) .suppress(Warning.NONFINAL_FIELDS, Warning.ALL_FIELDS_SHOULD_BE_USED)
.usingGetClass() .usingGetClass()
.forClass(QueryContext.class) .forClass(QueryContext.class)
.withNonnullFields("defaultParams", "userParams", "systemParams") .withNonnullFields("context")
.verify(); .verify();
} }
/**
* Verify that a context with an null map is the same as a context with
* an empty map.
*/
@Test @Test
public void testEmptyParam() public void testEmptyContext()
{ {
final QueryContext context = new QueryContext(); {
Assert.assertEquals(ImmutableMap.of(), context.getMergedParams()); final QueryContext context = new QueryContext(null);
assertEquals(ImmutableMap.of(), context.asMap());
}
{
final QueryContext context = new QueryContext(new HashMap<>());
assertEquals(ImmutableMap.of(), context.asMap());
}
{
final QueryContext context = QueryContext.of(null);
assertEquals(ImmutableMap.of(), context.asMap());
}
{
final QueryContext context = QueryContext.of(new HashMap<>());
assertEquals(ImmutableMap.of(), context.asMap());
}
{
final QueryContext context = QueryContext.empty();
assertEquals(ImmutableMap.of(), context.asMap());
}
} }
@Test @Test
public void testIsEmpty() public void testIsEmpty()
{ {
Assert.assertTrue(new QueryContext().isEmpty()); assertTrue(QueryContext.empty().isEmpty());
Assert.assertFalse(new QueryContext(ImmutableMap.of("k", "v")).isEmpty()); assertFalse(QueryContext.of(ImmutableMap.of("k", "v")).isEmpty());
QueryContext context = new QueryContext();
context.addDefaultParam("k", "v");
Assert.assertFalse(context.isEmpty());
context = new QueryContext();
context.addSystemParam("k", "v");
Assert.assertFalse(context.isEmpty());
} }
@Test @Test
public void testGetString() public void testGetString()
{ {
final QueryContext context = new QueryContext( final QueryContext context = QueryContext.of(
ImmutableMap.of("key", "val", ImmutableMap.of("key", "val",
"key2", 2) "key2", 2)
); );
Assert.assertEquals("val", context.get("key")); assertEquals("val", context.get("key"));
Assert.assertEquals("val", context.getAsString("key")); assertEquals("val", context.getString("key"));
Assert.assertEquals("2", context.getAsString("key2")); assertNull(context.getString("non-exist"));
Assert.assertNull(context.getAsString("non-exist")); assertEquals("foo", context.getString("non-exist", "foo"));
assertThrows(BadQueryContextException.class, () -> context.getString("key2"));
} }
@Test @Test
public void testGetBoolean() public void testGetBoolean()
{ {
final QueryContext context = new QueryContext( final QueryContext context = QueryContext.of(
ImmutableMap.of( ImmutableMap.of(
"key1", "true", "key1", "true",
"key2", true "key2", true
) )
); );
Assert.assertTrue(context.getAsBoolean("key1", false)); assertTrue(context.getBoolean("key1", false));
Assert.assertTrue(context.getAsBoolean("key2", false)); assertTrue(context.getBoolean("key2", false));
Assert.assertFalse(context.getAsBoolean("non-exist", false)); assertTrue(context.getBoolean("key1"));
assertFalse(context.getBoolean("non-exist", false));
assertNull(context.getBoolean("non-exist"));
} }
@Test @Test
public void testGetInt() public void testGetInt()
{ {
final QueryContext context = new QueryContext( final QueryContext context = QueryContext.of(
ImmutableMap.of( ImmutableMap.of(
"key1", "100", "key1", "100",
"key2", 100, "key2", 100,
@ -115,17 +149,17 @@ public class QueryContextTest
) )
); );
Assert.assertEquals(100, context.getAsInt("key1", 0)); assertEquals(100, context.getInt("key1", 0));
Assert.assertEquals(100, context.getAsInt("key2", 0)); assertEquals(100, context.getInt("key2", 0));
Assert.assertEquals(0, context.getAsInt("non-exist", 0)); assertEquals(0, context.getInt("non-exist", 0));
Assert.assertThrows(IAE.class, () -> context.getAsInt("key3", 5)); assertThrows(BadQueryContextException.class, () -> context.getInt("key3", 5));
} }
@Test @Test
public void testGetLong() public void testGetLong()
{ {
final QueryContext context = new QueryContext( final QueryContext context = QueryContext.of(
ImmutableMap.of( ImmutableMap.of(
"key1", "100", "key1", "100",
"key2", 100, "key2", 100,
@ -133,17 +167,127 @@ public class QueryContextTest
) )
); );
Assert.assertEquals(100L, context.getAsLong("key1", 0)); assertEquals(100L, context.getLong("key1", 0));
Assert.assertEquals(100L, context.getAsLong("key2", 0)); assertEquals(100L, context.getLong("key2", 0));
Assert.assertEquals(0L, context.getAsLong("non-exist", 0)); assertEquals(0L, context.getLong("non-exist", 0));
Assert.assertThrows(IAE.class, () -> context.getAsLong("key3", 5)); assertThrows(BadQueryContextException.class, () -> context.getLong("key3", 5));
}
/**
* Tests the several ways that Druid code parses context strings into Long
* values. The desired behavior is that "x" is parsed exactly the same as Jackson
* would parse x (where x is a valid number.) The context methods must emulate
* Jackson. The dimension utility method is included because some code used that
* for long parsing, and we must maintain backward compatibility.
* <p>
* The exceptions in the {@code assertThrows} are not critical: the key thing is
* that we're documenting what works and what doesn't. If an exception changes,
* just update the tests. If something no longer throws an exception, we'll want
* to verify that we support the new use case consistently in all three paths.
*/
@Test
public void testGetLongCompatibility() throws JsonProcessingException
{
{
String value = null;
// Only the context methods allow {"foo": null} to be parsed as a null Long.
assertNull(getContextLong(value));
// Nulls not legal on this path.
assertThrows(NullPointerException.class, () -> getDimensionLong(value));
// Nulls not legal on this path.
assertThrows(IllegalArgumentException.class, () -> getJsonLong(value));
}
{
String value = "";
// Blank string not legal on this path.
assertThrows(BadQueryContextException.class, () -> getContextLong(value));
assertNull(getDimensionLong(value));
// Blank string not allowed where a value is expected.
assertThrows(MismatchedInputException.class, () -> getJsonLong(value));
}
{
String value = "0";
assertEquals(0L, (long) getContextLong(value));
assertEquals(0L, (long) getDimensionLong(value));
assertEquals(0L, (long) getJsonLong(value));
}
{
String value = "+1";
assertEquals(1L, (long) getContextLong(value));
assertEquals(1L, (long) getDimensionLong(value));
assertThrows(JsonParseException.class, () -> getJsonLong(value));
}
{
String value = "-1";
assertEquals(-1L, (long) getContextLong(value));
assertEquals(-1L, (long) getDimensionLong(value));
assertEquals(-1L, (long) getJsonLong(value));
}
{
// Hexadecimal numbers are not supported in JSON. Druid also does not support
// them in strings.
String value = "0xabcd";
assertThrows(BadQueryContextException.class, () -> getContextLong(value));
// The dimension utils have a funny way of handling hex: they return null
assertNull(getDimensionLong(value));
assertThrows(JsonParseException.class, () -> getJsonLong(value));
}
{
// Leading zeros supported by Druid parsing, but not by JSON.
String value = "05";
assertEquals(5L, (long) getContextLong(value));
assertEquals(5L, (long) getDimensionLong(value));
assertThrows(JsonParseException.class, () -> getJsonLong(value));
}
{
// The dimension utils allow a float where a long is expected.
// Jackson can do this conversion. This test verifies that the context
// functions can handle the same conversion.
String value = "10.00";
assertEquals(10L, (long) getContextLong(value));
assertEquals(10L, (long) getDimensionLong(value));
assertEquals(10L, (long) getJsonLong(value));
}
{
// None of the conversion methods allow a (thousands) separator. The comma
// would be ambiguous in JSON. Java allows the underscore, but JSON does
// not support this syntax, and neither does Druid's string-to-long conversion.
String value = "1_234";
assertThrows(BadQueryContextException.class, () -> getContextLong(value));
assertNull(getDimensionLong(value));
assertThrows(JsonParseException.class, () -> getJsonLong(value));
}
}
private static Long getContextLong(String value)
{
return QueryContexts.getAsLong("dummy", value);
}
private static Long getJsonLong(String value) throws JsonProcessingException
{
return JSON_MAPPER.readValue(value, Long.class);
}
private static Long getDimensionLong(String value)
{
return DimensionHandlerUtils.getExactLongFromDecimalString(value);
} }
@Test @Test
public void testGetFloat() public void testGetFloat()
{ {
final QueryContext context = new QueryContext( final QueryContext context = QueryContext.of(
ImmutableMap.of( ImmutableMap.of(
"f1", "500", "f1", "500",
"f2", 500, "f2", 500,
@ -152,11 +296,11 @@ public class QueryContextTest
) )
); );
Assert.assertEquals(0, Float.compare(500, context.getAsFloat("f1", 100))); assertEquals(0, Float.compare(500, context.getFloat("f1", 100)));
Assert.assertEquals(0, Float.compare(500, context.getAsFloat("f2", 100))); assertEquals(0, Float.compare(500, context.getFloat("f2", 100)));
Assert.assertEquals(0, Float.compare(500.1f, context.getAsFloat("f3", 100))); assertEquals(0, Float.compare(500.1f, context.getFloat("f3", 100)));
Assert.assertThrows(IAE.class, () -> context.getAsLong("f4", 5)); assertThrows(BadQueryContextException.class, () -> context.getFloat("f4", 5));
} }
@Test @Test
@ -172,167 +316,30 @@ public class QueryContextTest
.put("m6", "abc") .put("m6", "abc")
.build() .build()
); );
Assert.assertEquals(500_000_000, context.getAsHumanReadableBytes("m1", HumanReadableBytes.ZERO).getBytes()); assertEquals(500_000_000, context.getHumanReadableBytes("m1", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(500_000_000, context.getAsHumanReadableBytes("m2", HumanReadableBytes.ZERO).getBytes()); assertEquals(500_000_000, context.getHumanReadableBytes("m2", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(500 * 1024 * 1024L, context.getAsHumanReadableBytes("m3", HumanReadableBytes.ZERO).getBytes()); assertEquals(500 * 1024 * 1024L, context.getHumanReadableBytes("m3", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(500 * 1024 * 1024L, context.getAsHumanReadableBytes("m4", HumanReadableBytes.ZERO).getBytes()); assertEquals(500 * 1024 * 1024L, context.getHumanReadableBytes("m4", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(500_000_000, context.getAsHumanReadableBytes("m5", HumanReadableBytes.ZERO).getBytes()); assertEquals(500_000_000, context.getHumanReadableBytes("m5", HumanReadableBytes.ZERO).getBytes());
Assert.assertThrows(IAE.class, () -> context.getAsHumanReadableBytes("m6", HumanReadableBytes.ZERO)); assertThrows(BadQueryContextException.class, () -> context.getHumanReadableBytes("m6", HumanReadableBytes.ZERO));
} }
@Test @Test
public void testAddSystemParamOverrideUserParam() public void testDefaultEnableQueryDebugging()
{ {
final QueryContext context = new QueryContext( assertFalse(QueryContext.empty().isDebug());
ImmutableMap.of( assertTrue(QueryContext.of(ImmutableMap.of(QueryContexts.ENABLE_DEBUG, true)).isDebug());
"user1", "userVal1",
"conflict", "userVal2"
)
);
context.addSystemParam("sys1", "sysVal1");
context.addSystemParam("conflict", "sysVal2");
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
),
context.getUserParams()
);
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"sys1", "sysVal1",
"conflict", "sysVal2"
),
context.getMergedParams()
);
}
@Test
public void testUserParamOverrideDefaultParam()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
)
);
context.addDefaultParams(
ImmutableMap.of(
"default1", "defaultVal1"
)
);
context.addDefaultParam("conflict", "defaultVal2");
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
),
context.getUserParams()
);
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"default1", "defaultVal1",
"conflict", "userVal2"
),
context.getMergedParams()
);
}
@Test
public void testRemoveUserParam()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
)
);
context.addDefaultParams(
ImmutableMap.of(
"default1", "defaultVal1",
"conflict", "defaultVal2"
)
);
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"default1", "defaultVal1",
"conflict", "userVal2"
),
context.getMergedParams()
);
Assert.assertEquals("userVal2", context.removeUserParam("conflict"));
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"default1", "defaultVal1",
"conflict", "defaultVal2"
),
context.getMergedParams()
);
}
@Test
public void testGetMergedParams()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
)
);
context.addDefaultParams(
ImmutableMap.of(
"default1", "defaultVal1",
"conflict", "defaultVal2"
)
);
Assert.assertSame(context.getMergedParams(), context.getMergedParams());
}
@Test
public void testCopy()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
)
);
context.addDefaultParams(
ImmutableMap.of(
"default1", "defaultVal1",
"conflict", "defaultVal2"
)
);
context.addSystemParam("sys1", "val1");
final Map<String, Object> merged = ImmutableMap.copyOf(context.getMergedParams());
final QueryContext context2 = context.copy();
context2.removeUserParam("conflict");
context2.addSystemParam("sys2", "val2");
context2.addDefaultParam("default3", "defaultVal3");
Assert.assertEquals(merged, context.getMergedParams());
} }
// This test is a bit silly. It is retained because another test uses the
// LegacyContextQuery test.
@Test @Test
public void testLegacyReturnsLegacy() public void testLegacyReturnsLegacy()
{ {
Query<?> legacy = new LegacyContextQuery(ImmutableMap.of("foo", "bar")); Map<String, Object> context = ImmutableMap.of("foo", "bar");
Assert.assertNull(legacy.getQueryContext()); Query<?> legacy = new LegacyContextQuery(context);
assertEquals(context, legacy.getContext());
} }
@Test @Test
@ -345,10 +352,10 @@ public class QueryContextTest
.aggregators(Collections.singletonList(new CountAggregatorFactory("theCount"))) .aggregators(Collections.singletonList(new CountAggregatorFactory("theCount")))
.context(ImmutableMap.of("foo", "bar")) .context(ImmutableMap.of("foo", "bar"))
.build(); .build();
Assert.assertNotNull(timeseries.getQueryContext()); assertNotNull(timeseries.getContext());
} }
public static class LegacyContextQuery implements Query public static class LegacyContextQuery implements Query<Integer>
{ {
private final Map<String, Object> context; private final Map<String, Object> context;
@ -382,9 +389,9 @@ public class QueryContextTest
} }
@Override @Override
public QueryRunner getRunner(QuerySegmentWalker walker) public QueryRunner<Integer> getRunner(QuerySegmentWalker walker)
{ {
return new NoopQueryRunner(); return new NoopQueryRunner<>();
} }
@Override @Override
@ -417,31 +424,6 @@ public class QueryContextTest
return context; return context;
} }
@Override
public boolean getContextBoolean(String key, boolean defaultValue)
{
if (context == null || !context.containsKey(key)) {
return defaultValue;
}
return (boolean) context.get(key);
}
@Override
public HumanReadableBytes getContextAsHumanReadableBytes(String key, HumanReadableBytes defaultValue)
{
if (null == context || !context.containsKey(key)) {
return defaultValue;
}
Object value = context.get(key);
if (value instanceof Number) {
return HumanReadableBytes.valueOf(((Number) value).longValue());
} else if (value instanceof String) {
return new HumanReadableBytes((String) value);
} else {
throw new IAE("Expected parameter [%s] to be in human readable format", key);
}
}
@Override @Override
public boolean isDescending() public boolean isDescending()
{ {
@ -449,19 +431,19 @@ public class QueryContextTest
} }
@Override @Override
public Ordering getResultOrdering() public Ordering<Integer> getResultOrdering()
{ {
return Ordering.natural(); return Ordering.natural();
} }
@Override @Override
public Query withQuerySegmentSpec(QuerySegmentSpec spec) public Query<Integer> withQuerySegmentSpec(QuerySegmentSpec spec)
{ {
return new LegacyContextQuery(context); return new LegacyContextQuery(context);
} }
@Override @Override
public Query withId(String id) public Query<Integer> withId(String id)
{ {
context.put(BaseQuery.QUERY_ID, id); context.put(BaseQuery.QUERY_ID, id);
return this; return this;
@ -475,7 +457,7 @@ public class QueryContextTest
} }
@Override @Override
public Query withSubQueryId(String subQueryId) public Query<Integer> withSubQueryId(String subQueryId)
{ {
context.put(BaseQuery.SUB_QUERY_ID, subQueryId); context.put(BaseQuery.SUB_QUERY_ID, subQueryId);
return this; return this;
@ -489,21 +471,15 @@ public class QueryContextTest
} }
@Override @Override
public Query withDataSource(DataSource dataSource) public Query<Integer> withDataSource(DataSource dataSource)
{ {
return this; return this;
} }
@Override @Override
public Query withOverriddenContext(Map contextOverride) public Query<Integer> withOverriddenContext(Map<String, Object> contextOverride)
{ {
return new LegacyContextQuery(contextOverride); return new LegacyContextQuery(contextOverride);
} }
@Override
public Object getContextValue(String key)
{
return context.get(key);
}
} }
} }

View File

@ -22,7 +22,6 @@ package org.apache.druid.query;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import org.apache.druid.java.util.common.HumanReadableBytes; import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.junit.Assert; import org.junit.Assert;
@ -47,7 +46,7 @@ public class QueryContextsTest
false, false,
new HashMap<>() new HashMap<>()
); );
Assert.assertEquals(300_000, QueryContexts.getDefaultTimeout(query)); Assert.assertEquals(300_000, query.context().getDefaultTimeout());
} }
@Test @Test
@ -59,10 +58,10 @@ public class QueryContextsTest
false, false,
new HashMap<>() new HashMap<>()
); );
Assert.assertEquals(300_000, QueryContexts.getTimeout(query)); Assert.assertEquals(300_000, query.context().getTimeout());
query = QueryContexts.withDefaultTimeout(query, 60_000); query = Queries.withDefaultTimeout(query, 60_000);
Assert.assertEquals(60_000, QueryContexts.getTimeout(query)); Assert.assertEquals(60_000, query.context().getTimeout());
} }
@Test @Test
@ -74,17 +73,17 @@ public class QueryContextsTest
false, false,
ImmutableMap.of(QueryContexts.TIMEOUT_KEY, 1000) ImmutableMap.of(QueryContexts.TIMEOUT_KEY, 1000)
); );
Assert.assertEquals(1000, QueryContexts.getTimeout(query)); Assert.assertEquals(1000, query.context().getTimeout());
query = QueryContexts.withDefaultTimeout(query, 1_000_000); query = Queries.withDefaultTimeout(query, 1_000_000);
Assert.assertEquals(1000, QueryContexts.getTimeout(query)); Assert.assertEquals(1000, query.context().getTimeout());
} }
@Test @Test
public void testQueryMaxTimeout() public void testQueryMaxTimeout()
{ {
exception.expect(IAE.class); exception.expect(BadQueryContextException.class);
exception.expectMessage("configured [timeout = 1000] is more than enforced limit of maxQueryTimeout [100]."); exception.expectMessage("Configured timeout = 1000 is more than enforced limit of 100.");
Query<?> query = new TestQuery( Query<?> query = new TestQuery(
new TableDataSource("test"), new TableDataSource("test"),
new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))), new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))),
@ -92,14 +91,14 @@ public class QueryContextsTest
ImmutableMap.of(QueryContexts.TIMEOUT_KEY, 1000) ImmutableMap.of(QueryContexts.TIMEOUT_KEY, 1000)
); );
QueryContexts.verifyMaxQueryTimeout(query, 100); query.context().verifyMaxQueryTimeout(100);
} }
@Test @Test
public void testMaxScatterGatherBytes() public void testMaxScatterGatherBytes()
{ {
exception.expect(IAE.class); exception.expect(BadQueryContextException.class);
exception.expectMessage("configured [maxScatterGatherBytes = 1000] is more than enforced limit of [100]."); exception.expectMessage("Configured maxScatterGatherBytes = 1000 is more than enforced limit of 100.");
Query<?> query = new TestQuery( Query<?> query = new TestQuery(
new TableDataSource("test"), new TableDataSource("test"),
new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))), new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))),
@ -107,7 +106,7 @@ public class QueryContextsTest
ImmutableMap.of(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY, 1000) ImmutableMap.of(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY, 1000)
); );
QueryContexts.withMaxScatterGatherBytes(query, 100); Queries.withMaxScatterGatherBytes(query, 100);
} }
@Test @Test
@ -119,7 +118,7 @@ public class QueryContextsTest
false, false,
ImmutableMap.of(QueryContexts.SECONDARY_PARTITION_PRUNING_KEY, false) ImmutableMap.of(QueryContexts.SECONDARY_PARTITION_PRUNING_KEY, false)
); );
Assert.assertFalse(QueryContexts.isSecondaryPartitionPruningEnabled(query)); Assert.assertFalse(query.context().isSecondaryPartitionPruningEnabled());
} }
@Test @Test
@ -131,7 +130,7 @@ public class QueryContextsTest
false, false,
ImmutableMap.of() ImmutableMap.of()
); );
Assert.assertTrue(QueryContexts.isSecondaryPartitionPruningEnabled(query)); Assert.assertTrue(query.context().isSecondaryPartitionPruningEnabled());
} }
@Test @Test
@ -139,7 +138,7 @@ public class QueryContextsTest
{ {
Assert.assertEquals( Assert.assertEquals(
QueryContexts.DEFAULT_IN_SUB_QUERY_THRESHOLD, QueryContexts.DEFAULT_IN_SUB_QUERY_THRESHOLD,
QueryContexts.getInSubQueryThreshold(ImmutableMap.of()) QueryContext.empty().getInSubQueryThreshold()
); );
} }
@ -148,32 +147,32 @@ public class QueryContextsTest
{ {
Assert.assertEquals( Assert.assertEquals(
QueryContexts.DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING, QueryContexts.DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING,
QueryContexts.isTimeBoundaryPlanningEnabled(ImmutableMap.of()) QueryContext.empty().isTimeBoundaryPlanningEnabled()
); );
} }
@Test @Test
public void testGetEnableJoinLeftScanDirect() public void testGetEnableJoinLeftScanDirect()
{ {
Assert.assertFalse(QueryContexts.getEnableJoinLeftScanDirect(ImmutableMap.of())); Assert.assertFalse(QueryContext.empty().getEnableJoinLeftScanDirect());
Assert.assertTrue(QueryContexts.getEnableJoinLeftScanDirect(ImmutableMap.of( Assert.assertTrue(QueryContext.of(ImmutableMap.of(
QueryContexts.SQL_JOIN_LEFT_SCAN_DIRECT, QueryContexts.SQL_JOIN_LEFT_SCAN_DIRECT,
true true
))); )).getEnableJoinLeftScanDirect());
Assert.assertFalse(QueryContexts.getEnableJoinLeftScanDirect(ImmutableMap.of( Assert.assertFalse(QueryContext.of(ImmutableMap.of(
QueryContexts.SQL_JOIN_LEFT_SCAN_DIRECT, QueryContexts.SQL_JOIN_LEFT_SCAN_DIRECT,
false false
))); )).getEnableJoinLeftScanDirect());
} }
@Test @Test
public void testGetBrokerServiceName() public void testGetBrokerServiceName()
{ {
Map<String, Object> queryContext = new HashMap<>(); Map<String, Object> queryContext = new HashMap<>();
Assert.assertNull(QueryContexts.getBrokerServiceName(queryContext)); Assert.assertNull(QueryContext.of(queryContext).getBrokerServiceName());
queryContext.put(QueryContexts.BROKER_SERVICE_NAME, "hotBroker"); queryContext.put(QueryContexts.BROKER_SERVICE_NAME, "hotBroker");
Assert.assertEquals("hotBroker", QueryContexts.getBrokerServiceName(queryContext)); Assert.assertEquals("hotBroker", QueryContext.of(queryContext).getBrokerServiceName());
} }
@Test @Test
@ -182,8 +181,8 @@ public class QueryContextsTest
Map<String, Object> queryContext = new HashMap<>(); Map<String, Object> queryContext = new HashMap<>();
queryContext.put(QueryContexts.BROKER_SERVICE_NAME, 100); queryContext.put(QueryContexts.BROKER_SERVICE_NAME, 100);
exception.expect(ClassCastException.class); exception.expect(BadQueryContextException.class);
QueryContexts.getBrokerServiceName(queryContext); QueryContext.of(queryContext).getBrokerServiceName();
} }
@Test @Test
@ -193,38 +192,12 @@ public class QueryContextsTest
queryContext.put(QueryContexts.TIMEOUT_KEY, "2000'"); queryContext.put(QueryContexts.TIMEOUT_KEY, "2000'");
exception.expect(BadQueryContextException.class); exception.expect(BadQueryContextException.class);
QueryContexts.getTimeout(new TestQuery( new TestQuery(
new TableDataSource("test"), new TableDataSource("test"),
new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))), new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))),
false, false,
queryContext queryContext
)); ).context().getTimeout();
}
@Test
public void testDefaultEnableQueryDebugging()
{
Query<?> query = new TestQuery(
new TableDataSource("test"),
new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))),
false,
ImmutableMap.of()
);
Assert.assertFalse(QueryContexts.isDebug(query));
Assert.assertFalse(QueryContexts.isDebug(query.getContext()));
}
@Test
public void testEnableQueryDebuggingSetToTrue()
{
Query<?> query = new TestQuery(
new TableDataSource("test"),
new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))),
false,
ImmutableMap.of(QueryContexts.ENABLE_DEBUG, true)
);
Assert.assertTrue(QueryContexts.isDebug(query));
Assert.assertTrue(QueryContexts.isDebug(query.getContext()));
} }
@Test @Test
@ -237,7 +210,7 @@ public class QueryContextsTest
QueryContexts.getAsString("foo", 10, null); QueryContexts.getAsString("foo", 10, null);
Assert.fail(); Assert.fail();
} }
catch (IAE e) { catch (BadQueryContextException e) {
// Expected // Expected
} }
@ -249,7 +222,7 @@ public class QueryContextsTest
QueryContexts.getAsBoolean("foo", 10, false); QueryContexts.getAsBoolean("foo", 10, false);
Assert.fail(); Assert.fail();
} }
catch (IAE e) { catch (BadQueryContextException e) {
// Expected // Expected
} }
@ -262,7 +235,7 @@ public class QueryContextsTest
QueryContexts.getAsInt("foo", true, 20); QueryContexts.getAsInt("foo", true, 20);
Assert.fail(); Assert.fail();
} }
catch (IAE e) { catch (BadQueryContextException e) {
// Expected // Expected
} }
@ -275,7 +248,7 @@ public class QueryContextsTest
QueryContexts.getAsLong("foo", true, 20); QueryContexts.getAsLong("foo", true, 20);
Assert.fail(); Assert.fail();
} }
catch (IAE e) { catch (BadQueryContextException e) {
// Expected // Expected
} }
} }
@ -314,12 +287,12 @@ public class QueryContextsTest
Assert.assertEquals( Assert.assertEquals(
QueryContexts.Vectorize.FORCE, QueryContexts.Vectorize.FORCE,
query.getQueryContext().getAsEnum("e1", QueryContexts.Vectorize.class, QueryContexts.Vectorize.FALSE) query.context().getEnum("e1", QueryContexts.Vectorize.class, QueryContexts.Vectorize.FALSE)
); );
Assert.assertThrows( Assert.assertThrows(
IAE.class, BadQueryContextException.class,
() -> query.getQueryContext().getAsEnum("e2", QueryContexts.Vectorize.class, QueryContexts.Vectorize.FALSE) () -> query.context().getEnum("e2", QueryContexts.Vectorize.class, QueryContexts.Vectorize.FALSE)
); );
} }
} }

View File

@ -31,6 +31,7 @@ import org.apache.druid.query.DefaultGenericQueryMetricsFactory;
import org.apache.druid.query.Druids; import org.apache.druid.query.Druids;
import org.apache.druid.query.GenericQueryMetricsFactory; import org.apache.druid.query.GenericQueryMetricsFactory;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
@ -102,13 +103,14 @@ public class DataSourceMetadataQueryTest
), Query.class ), Query.class
); );
Assert.assertEquals((Integer) 1, serdeQuery.getQueryContext().getAsInt(QueryContexts.PRIORITY_KEY)); final QueryContext queryContext = serdeQuery.context();
Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.USE_CACHE_KEY)); Assert.assertEquals(1, (int) queryContext.getInt(QueryContexts.PRIORITY_KEY));
Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.POPULATE_CACHE_KEY)); Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.USE_CACHE_KEY));
Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.FINALIZE_KEY)); Assert.assertEquals("true", queryContext.getString(QueryContexts.POPULATE_CACHE_KEY));
Assert.assertEquals(true, serdeQuery.getContextBoolean(QueryContexts.USE_CACHE_KEY, false)); Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.FINALIZE_KEY));
Assert.assertEquals(true, serdeQuery.getContextBoolean(QueryContexts.POPULATE_CACHE_KEY, false)); Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.USE_CACHE_KEY, false));
Assert.assertEquals(true, serdeQuery.getContextBoolean(QueryContexts.FINALIZE_KEY, false)); Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.POPULATE_CACHE_KEY, false));
Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.FINALIZE_KEY, false));
} }
@Test @Test

View File

@ -20,7 +20,6 @@
package org.apache.druid.query.groupby.epinephelinae.vector; package org.apache.druid.query.groupby.epinephelinae.vector;
import org.apache.commons.lang3.mutable.MutableObject; import org.apache.commons.lang3.mutable.MutableObject;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryRunnerTestHelper; import org.apache.druid.query.QueryRunnerTestHelper;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
@ -68,7 +67,7 @@ public class VectorGroupByEngineIteratorTest extends InitializedNullHandlingTest
interval, interval,
query.getVirtualColumns(), query.getVirtualColumns(),
false, false,
QueryContexts.getVectorSize(query), query.context().getVectorSize(),
null null
); );
final List<GroupByVectorColumnSelector> dimensions = query.getDimensions().stream().map( final List<GroupByVectorColumnSelector> dimensions = query.getDimensions().stream().map(

View File

@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableMap;
import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.query.Druids; import org.apache.druid.query.Druids;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContexts;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
@ -78,10 +79,11 @@ public class TimeBoundaryQueryTest
), TimeBoundaryQuery.class ), TimeBoundaryQuery.class
); );
Assert.assertEquals(new Integer(1), serdeQuery.getQueryContext().getAsInt(QueryContexts.PRIORITY_KEY)); final QueryContext queryContext = query.context();
Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.USE_CACHE_KEY)); Assert.assertEquals(1, (int) queryContext.getInt(QueryContexts.PRIORITY_KEY));
Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.POPULATE_CACHE_KEY)); Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.USE_CACHE_KEY));
Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.FINALIZE_KEY)); Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.POPULATE_CACHE_KEY));
Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.FINALIZE_KEY));
} }
@Test @Test
@ -116,9 +118,10 @@ public class TimeBoundaryQueryTest
); );
Assert.assertEquals("1", serdeQuery.getQueryContext().getAsString(QueryContexts.PRIORITY_KEY)); final QueryContext queryContext = query.context();
Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.USE_CACHE_KEY)); Assert.assertEquals("1", queryContext.get(QueryContexts.PRIORITY_KEY));
Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.POPULATE_CACHE_KEY)); Assert.assertEquals("true", queryContext.get(QueryContexts.USE_CACHE_KEY));
Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.FINALIZE_KEY)); Assert.assertEquals("true", queryContext.get(QueryContexts.POPULATE_CACHE_KEY));
Assert.assertEquals("true", queryContext.get(QueryContexts.FINALIZE_KEY));
} }
} }

View File

@ -24,12 +24,12 @@ import org.apache.druid.client.cache.CacheConfig;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.CacheStrategy; import org.apache.druid.query.CacheStrategy;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryToolChest; import org.apache.druid.query.QueryToolChest;
import org.apache.druid.query.SegmentDescriptor; import org.apache.druid.query.SegmentDescriptor;
import org.joda.time.Interval; import org.joda.time.Interval;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
public class CacheUtil public class CacheUtil
@ -109,7 +109,7 @@ public class CacheUtil
) )
{ {
return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType) return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType)
&& QueryContexts.isUseCache(query) && query.context().isUseCache()
&& cacheConfig.isUseCache(); && cacheConfig.isUseCache();
} }
@ -129,7 +129,7 @@ public class CacheUtil
) )
{ {
return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType) return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType)
&& QueryContexts.isPopulateCache(query) && query.context().isPopulateCache()
&& cacheConfig.isPopulateCache(); && cacheConfig.isPopulateCache();
} }
@ -149,7 +149,7 @@ public class CacheUtil
) )
{ {
return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType) return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType)
&& QueryContexts.isUseResultLevelCache(query) && query.context().isUseResultLevelCache()
&& cacheConfig.isUseResultLevelCache(); && cacheConfig.isUseResultLevelCache();
} }
@ -169,7 +169,7 @@ public class CacheUtil
) )
{ {
return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType) return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType)
&& QueryContexts.isPopulateResultLevelCache(query) && query.context().isPopulateResultLevelCache()
&& cacheConfig.isPopulateResultLevelCache(); && cacheConfig.isPopulateResultLevelCache();
} }

View File

@ -60,6 +60,7 @@ import org.apache.druid.query.CacheStrategy;
import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.Queries; import org.apache.druid.query.Queries;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryMetrics; import org.apache.druid.query.QueryMetrics;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
@ -282,10 +283,11 @@ public class CachingClusteredClient implements QuerySegmentWalker
this.useCache = CacheUtil.isUseSegmentCache(query, strategy, cacheConfig, CacheUtil.ServerType.BROKER); this.useCache = CacheUtil.isUseSegmentCache(query, strategy, cacheConfig, CacheUtil.ServerType.BROKER);
this.populateCache = CacheUtil.isPopulateSegmentCache(query, strategy, cacheConfig, CacheUtil.ServerType.BROKER); this.populateCache = CacheUtil.isPopulateSegmentCache(query, strategy, cacheConfig, CacheUtil.ServerType.BROKER);
this.isBySegment = QueryContexts.isBySegment(query); final QueryContext queryContext = query.context();
this.isBySegment = queryContext.isBySegment();
// Note that enabling this leads to putting uncovered intervals information in the response headers // Note that enabling this leads to putting uncovered intervals information in the response headers
// and might blow up in some cases https://github.com/apache/druid/issues/2108 // and might blow up in some cases https://github.com/apache/druid/issues/2108
this.uncoveredIntervalsLimit = QueryContexts.getUncoveredIntervalsLimit(query); this.uncoveredIntervalsLimit = queryContext.getUncoveredIntervalsLimit();
// For nested queries, we need to look at the intervals of the inner most query. // For nested queries, we need to look at the intervals of the inner most query.
this.intervals = dataSourceAnalysis.getBaseQuerySegmentSpec() this.intervals = dataSourceAnalysis.getBaseQuerySegmentSpec()
.map(QuerySegmentSpec::getIntervals) .map(QuerySegmentSpec::getIntervals)
@ -304,9 +306,10 @@ public class CachingClusteredClient implements QuerySegmentWalker
{ {
final ImmutableMap.Builder<String, Object> contextBuilder = new ImmutableMap.Builder<>(); final ImmutableMap.Builder<String, Object> contextBuilder = new ImmutableMap.Builder<>();
final int priority = QueryContexts.getPriority(query); final QueryContext queryContext = query.context();
final int priority = queryContext.getPriority();
contextBuilder.put(QueryContexts.PRIORITY_KEY, priority); contextBuilder.put(QueryContexts.PRIORITY_KEY, priority);
final String lane = QueryContexts.getLane(query); final String lane = queryContext.getLane();
if (lane != null) { if (lane != null) {
contextBuilder.put(QueryContexts.LANE_KEY, lane); contextBuilder.put(QueryContexts.LANE_KEY, lane);
} }
@ -384,18 +387,19 @@ public class CachingClusteredClient implements QuerySegmentWalker
private Sequence<T> merge(List<Sequence<T>> sequencesByInterval) private Sequence<T> merge(List<Sequence<T>> sequencesByInterval)
{ {
BinaryOperator<T> mergeFn = toolChest.createMergeFn(query); BinaryOperator<T> mergeFn = toolChest.createMergeFn(query);
if (processingConfig.useParallelMergePool() && QueryContexts.getEnableParallelMerges(query) && mergeFn != null) { final QueryContext queryContext = query.context();
if (processingConfig.useParallelMergePool() && queryContext.getEnableParallelMerges() && mergeFn != null) {
return new ParallelMergeCombiningSequence<>( return new ParallelMergeCombiningSequence<>(
pool, pool,
sequencesByInterval, sequencesByInterval,
query.getResultOrdering(), query.getResultOrdering(),
mergeFn, mergeFn,
QueryContexts.hasTimeout(query), queryContext.hasTimeout(),
QueryContexts.getTimeout(query), queryContext.getTimeout(),
QueryContexts.getPriority(query), queryContext.getPriority(),
QueryContexts.getParallelMergeParallelism(query, processingConfig.getMergePoolDefaultMaxQueryParallelism()), queryContext.getParallelMergeParallelism(processingConfig.getMergePoolDefaultMaxQueryParallelism()),
QueryContexts.getParallelMergeInitialYieldRows(query, processingConfig.getMergePoolTaskInitialYieldRows()), queryContext.getParallelMergeInitialYieldRows(processingConfig.getMergePoolTaskInitialYieldRows()),
QueryContexts.getParallelMergeSmallBatchRows(query, processingConfig.getMergePoolSmallBatchRows()), queryContext.getParallelMergeSmallBatchRows(processingConfig.getMergePoolSmallBatchRows()),
processingConfig.getMergePoolTargetTaskRunTimeMillis(), processingConfig.getMergePoolTargetTaskRunTimeMillis(),
reportMetrics -> { reportMetrics -> {
QueryMetrics<?> queryMetrics = queryPlus.getQueryMetrics(); QueryMetrics<?> queryMetrics = queryPlus.getQueryMetrics();
@ -437,7 +441,7 @@ public class CachingClusteredClient implements QuerySegmentWalker
// Filter unneeded chunks based on partition dimension // Filter unneeded chunks based on partition dimension
for (TimelineObjectHolder<String, ServerSelector> holder : serversLookup) { for (TimelineObjectHolder<String, ServerSelector> holder : serversLookup) {
final Set<PartitionChunk<ServerSelector>> filteredChunks; final Set<PartitionChunk<ServerSelector>> filteredChunks;
if (QueryContexts.isSecondaryPartitionPruningEnabled(query)) { if (query.context().isSecondaryPartitionPruningEnabled()) {
filteredChunks = DimFilterUtils.filterShards( filteredChunks = DimFilterUtils.filterShards(
query.getFilter(), query.getFilter(),
holder.getObject(), holder.getObject(),
@ -652,12 +656,12 @@ public class CachingClusteredClient implements QuerySegmentWalker
final QueryRunner serverRunner = serverView.getQueryRunner(server); final QueryRunner serverRunner = serverView.getQueryRunner(server);
if (serverRunner == null) { if (serverRunner == null) {
log.error("Server[%s] doesn't have a query runner", server.getName()); log.error("Server [%s] doesn't have a query runner", server.getName());
return; return;
} }
// Divide user-provided maxQueuedBytes by the number of servers, and limit each server to that much. // Divide user-provided maxQueuedBytes by the number of servers, and limit each server to that much.
final long maxQueuedBytes = QueryContexts.getMaxQueuedBytes(query, httpClientConfig.getMaxQueuedBytes()); final long maxQueuedBytes = query.context().getMaxQueuedBytes(httpClientConfig.getMaxQueuedBytes());
final long maxQueuedBytesPerServer = maxQueuedBytes / segmentsByServer.size(); final long maxQueuedBytesPerServer = maxQueuedBytes / segmentsByServer.size();
final Sequence<T> serverResults; final Sequence<T> serverResults;
@ -776,7 +780,7 @@ public class CachingClusteredClient implements QuerySegmentWalker
this.dataSourceAnalysis = dataSourceAnalysis; this.dataSourceAnalysis = dataSourceAnalysis;
this.joinableFactoryWrapper = joinableFactoryWrapper; this.joinableFactoryWrapper = joinableFactoryWrapper;
this.isSegmentLevelCachingEnable = ((populateCache || useCache) this.isSegmentLevelCachingEnable = ((populateCache || useCache)
&& !QueryContexts.isBySegment(query)); // explicit bySegment queries are never cached && !query.context().isBySegment()); // explicit bySegment queries are never cached
} }

View File

@ -41,8 +41,9 @@ import org.apache.druid.java.util.http.client.response.ClientResponse;
import org.apache.druid.java.util.http.client.response.HttpResponseHandler; import org.apache.druid.java.util.http.client.response.HttpResponseHandler;
import org.apache.druid.java.util.http.client.response.StatusResponseHandler; import org.apache.druid.java.util.http.client.response.StatusResponseHandler;
import org.apache.druid.java.util.http.client.response.StatusResponseHolder; import org.apache.druid.java.util.http.client.response.StatusResponseHolder;
import org.apache.druid.query.Queries;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryMetrics; import org.apache.druid.query.QueryMetrics;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
@ -152,7 +153,7 @@ public class DirectDruidClient<T> implements QueryRunner<T>
{ {
final Query<T> query = queryPlus.getQuery(); final Query<T> query = queryPlus.getQuery();
QueryToolChest<T, Query<T>> toolChest = warehouse.getToolChest(query); QueryToolChest<T, Query<T>> toolChest = warehouse.getToolChest(query);
boolean isBySegment = QueryContexts.isBySegment(query); boolean isBySegment = query.context().isBySegment();
final JavaType queryResultType = isBySegment ? toolChest.getBySegmentResultType() : toolChest.getBaseResultType(); final JavaType queryResultType = isBySegment ? toolChest.getBySegmentResultType() : toolChest.getBaseResultType();
final ListenableFuture<InputStream> future; final ListenableFuture<InputStream> future;
@ -160,13 +161,15 @@ public class DirectDruidClient<T> implements QueryRunner<T>
final String cancelUrl = url + query.getId(); final String cancelUrl = url + query.getId();
try { try {
log.debug("Querying queryId[%s] url[%s]", query.getId(), url); log.debug("Querying queryId [%s] url [%s]", query.getId(), url);
final long requestStartTimeNs = System.nanoTime(); final long requestStartTimeNs = System.nanoTime();
final long timeoutAt = query.getQueryContext().getAsLong(QUERY_FAIL_TIME); final QueryContext queryContext = query.context();
final long maxScatterGatherBytes = QueryContexts.getMaxScatterGatherBytes(query); // Will NPE if the value is not set.
final long timeoutAt = queryContext.getLong(QUERY_FAIL_TIME);
final long maxScatterGatherBytes = queryContext.getMaxScatterGatherBytes();
final AtomicLong totalBytesGathered = context.getTotalBytes(); final AtomicLong totalBytesGathered = context.getTotalBytes();
final long maxQueuedBytes = QueryContexts.getMaxQueuedBytes(query, 0); final long maxQueuedBytes = queryContext.getMaxQueuedBytes(0);
final boolean usingBackpressure = maxQueuedBytes > 0; final boolean usingBackpressure = maxQueuedBytes > 0;
final HttpResponseHandler<InputStream, InputStream> responseHandler = new HttpResponseHandler<InputStream, InputStream>() final HttpResponseHandler<InputStream, InputStream> responseHandler = new HttpResponseHandler<InputStream, InputStream>()
@ -454,7 +457,7 @@ public class DirectDruidClient<T> implements QueryRunner<T>
new Request( new Request(
HttpMethod.POST, HttpMethod.POST,
new URL(url) new URL(url)
).setContent(objectMapper.writeValueAsBytes(QueryContexts.withTimeout(query, timeLeft))) ).setContent(objectMapper.writeValueAsBytes(Queries.withTimeout(query, timeLeft)))
.setHeader( .setHeader(
HttpHeaders.Names.CONTENT_TYPE, HttpHeaders.Names.CONTENT_TYPE,
isSmile ? SmileMediaTypes.APPLICATION_JACKSON_SMILE : MediaType.APPLICATION_JSON isSmile ? SmileMediaTypes.APPLICATION_JACKSON_SMILE : MediaType.APPLICATION_JSON

View File

@ -75,7 +75,7 @@ public class JsonParserIterator<T> implements Iterator<T>, Closeable
this.future = future; this.future = future;
this.url = url; this.url = url;
if (query != null) { if (query != null) {
this.timeoutAt = query.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME, -1L); this.timeoutAt = query.context().getLong(DirectDruidClient.QUERY_FAIL_TIME, -1L);
this.queryId = query.getId(); this.queryId = query.getId();
} else { } else {
this.timeoutAt = -1; this.timeoutAt = -1;

View File

@ -215,15 +215,15 @@ public class RetryQueryRunner<T> implements QueryRunner<T>
if (sequence != null) { if (sequence != null) {
return true; return true;
} else { } else {
final QueryContext queryContext = queryPlus.getQuery().context();
final List<SegmentDescriptor> missingSegments = getMissingSegments(queryPlus, context); final List<SegmentDescriptor> missingSegments = getMissingSegments(queryPlus, context);
final int maxNumRetries = QueryContexts.getNumRetriesOnMissingSegments( final int maxNumRetries = queryContext.getNumRetriesOnMissingSegments(
queryPlus.getQuery(),
config.getNumTries() config.getNumTries()
); );
if (missingSegments.isEmpty()) { if (missingSegments.isEmpty()) {
return false; return false;
} else if (retryCount >= maxNumRetries) { } else if (retryCount >= maxNumRetries) {
if (!QueryContexts.allowReturnPartialResults(queryPlus.getQuery(), config.isReturnPartialResults())) { if (!queryContext.allowReturnPartialResults(config.isReturnPartialResults())) {
throw new SegmentMissingException("No results found for segments[%s]", missingSegments); throw new SegmentMissingException("No results found for segments[%s]", missingSegments);
} else { } else {
return false; return false;

View File

@ -161,7 +161,7 @@ public class SinkQuerySegmentWalker implements QuerySegmentWalker
} }
final QueryToolChest<T, Query<T>> toolChest = factory.getToolchest(); final QueryToolChest<T, Query<T>> toolChest = factory.getToolchest();
final boolean skipIncrementalSegment = query.getContextBoolean(CONTEXT_SKIP_INCREMENTAL_SEGMENT, false); final boolean skipIncrementalSegment = query.context().getBoolean(CONTEXT_SKIP_INCREMENTAL_SEGMENT, false);
final AtomicLong cpuTimeAccumulator = new AtomicLong(0L); final AtomicLong cpuTimeAccumulator = new AtomicLong(0L);
// Make sure this query type can handle the subquery, if present. // Make sure this query type can handle the subquery, if present.

View File

@ -39,7 +39,6 @@ import org.apache.druid.query.GlobalTableDataSource;
import org.apache.druid.query.InlineDataSource; import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.PostProcessingOperator; import org.apache.druid.query.PostProcessingOperator;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
@ -163,7 +162,7 @@ public class ClientQuerySegmentWalker implements QuerySegmentWalker
final DataSource freeTradeDataSource = globalizeIfPossible(newQuery.getDataSource()); final DataSource freeTradeDataSource = globalizeIfPossible(newQuery.getDataSource());
// do an inlining dry run to see if any inlining is necessary, without actually running the queries. // do an inlining dry run to see if any inlining is necessary, without actually running the queries.
final int maxSubqueryRows = QueryContexts.getMaxSubqueryRows(query, serverConfig.getMaxSubqueryRows()); final int maxSubqueryRows = query.context().getMaxSubqueryRows(serverConfig.getMaxSubqueryRows());
final DataSource inlineDryRun = inlineIfNecessary( final DataSource inlineDryRun = inlineIfNecessary(
freeTradeDataSource, freeTradeDataSource,
@ -431,7 +430,7 @@ public class ClientQuerySegmentWalker implements QuerySegmentWalker
.emitCPUTimeMetric(emitter) .emitCPUTimeMetric(emitter)
.postProcess( .postProcess(
objectMapper.convertValue( objectMapper.convertValue(
query.getQueryContext().getAsString("postProcessing"), query.context().getString("postProcessing"),
new TypeReference<PostProcessingOperator<T>>() {} new TypeReference<PostProcessingOperator<T>>() {}
) )
) )

View File

@ -21,6 +21,7 @@ package org.apache.druid.server;
import com.fasterxml.jackson.databind.ObjectWriter; import com.fasterxml.jackson.databind.ObjectWriter;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.Iterables; import com.google.common.collect.Iterables;
import org.apache.druid.client.DirectDruidClient; import org.apache.druid.client.DirectDruidClient;
import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.DateTimes;
@ -61,7 +62,8 @@ import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
@ -102,6 +104,8 @@ public class QueryLifecycle
@MonotonicNonNull @MonotonicNonNull
private Query<?> baseQuery; private Query<?> baseQuery;
@MonotonicNonNull
private Set<String> userContextKeys;
public QueryLifecycle( public QueryLifecycle(
final QueryToolChestWarehouse warehouse, final QueryToolChestWarehouse warehouse,
@ -195,17 +199,15 @@ public class QueryLifecycle
{ {
transition(State.NEW, State.INITIALIZED); transition(State.NEW, State.INITIALIZED);
if (baseQuery.getQueryContext() == null) { userContextKeys = new HashSet<>(baseQuery.getContext().keySet());
QueryContext context = new QueryContext(baseQuery.getContext()); String queryId = baseQuery.getId();
context.addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString()); if (Strings.isNullOrEmpty(queryId)) {
context.addDefaultParams(defaultQueryConfig.getContext()); queryId = UUID.randomUUID().toString();
this.baseQuery = baseQuery.withOverriddenContext(context.getMergedParams());
} else {
baseQuery.getQueryContext().addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString());
baseQuery.getQueryContext().addDefaultParams(defaultQueryConfig.getContext());
this.baseQuery = baseQuery;
} }
Map<String, Object> mergedUserAndConfigContext = QueryContexts.override(defaultQueryConfig.getContext(), baseQuery.getContext());
mergedUserAndConfigContext.put(BaseQuery.QUERY_ID, queryId);
this.baseQuery = baseQuery.withOverriddenContext(mergedUserAndConfigContext);
this.toolChest = warehouse.getToolChest(this.baseQuery); this.toolChest = warehouse.getToolChest(this.baseQuery);
} }
@ -220,23 +222,15 @@ public class QueryLifecycle
public Access authorize(HttpServletRequest req) public Access authorize(HttpServletRequest req)
{ {
transition(State.INITIALIZED, State.AUTHORIZING); transition(State.INITIALIZED, State.AUTHORIZING);
final Set<String> contextKeys;
if (baseQuery.getQueryContext() == null) {
contextKeys = baseQuery.getContext().keySet();
} else {
contextKeys = baseQuery.getQueryContext().getUserParams().keySet();
}
final Iterable<ResourceAction> resourcesToAuthorize = Iterables.concat( final Iterable<ResourceAction> resourcesToAuthorize = Iterables.concat(
Iterables.transform( Iterables.transform(
baseQuery.getDataSource().getTableNames(), baseQuery.getDataSource().getTableNames(),
AuthorizationUtils.DATASOURCE_READ_RA_GENERATOR AuthorizationUtils.DATASOURCE_READ_RA_GENERATOR
), ),
authConfig.authorizeQueryContextParams() Iterables.transform(
? Iterables.transform( authConfig.contextKeysToAuthorize(userContextKeys),
contextKeys,
contextParam -> new ResourceAction(new Resource(contextParam, ResourceType.QUERY_CONTEXT), Action.WRITE) contextParam -> new ResourceAction(new Resource(contextParam, ResourceType.QUERY_CONTEXT), Action.WRITE)
) )
: Collections.emptyList()
); );
return doAuthorize( return doAuthorize(
AuthorizationUtils.authenticationResultFromRequest(req), AuthorizationUtils.authenticationResultFromRequest(req),
@ -353,7 +347,7 @@ public class QueryLifecycle
if (e != null) { if (e != null) {
statsMap.put("exception", e.toString()); statsMap.put("exception", e.toString());
if (QueryContexts.isDebug(baseQuery)) { if (baseQuery.context().isDebug()) {
log.warn(e, "Exception while processing queryId [%s]", baseQuery.getId()); log.warn(e, "Exception while processing queryId [%s]", baseQuery.getId());
} else { } else {
log.noStackTrace().warn(e, "Exception while processing queryId [%s]", baseQuery.getId()); log.noStackTrace().warn(e, "Exception while processing queryId [%s]", baseQuery.getId());
@ -403,9 +397,10 @@ public class QueryLifecycle
private boolean isSerializeDateTimeAsLong() private boolean isSerializeDateTimeAsLong()
{ {
final boolean shouldFinalize = QueryContexts.isFinalize(baseQuery, true); final QueryContext queryContext = baseQuery.context();
return QueryContexts.isSerializeDateTimeAsLong(baseQuery, false) final boolean shouldFinalize = queryContext.isFinalize(true);
|| (!shouldFinalize && QueryContexts.isSerializeDateTimeAsLongInner(baseQuery, false)); return queryContext.isSerializeDateTimeAsLong(false)
|| (!shouldFinalize && queryContext.isSerializeDateTimeAsLongInner(false));
} }
public ObjectWriter newOutputWriter(ResourceIOReaderWriter ioReaderWriter) public ObjectWriter newOutputWriter(ResourceIOReaderWriter ioReaderWriter)

View File

@ -46,7 +46,7 @@ import org.apache.druid.query.BadJsonQueryException;
import org.apache.druid.query.BadQueryException; import org.apache.druid.query.BadQueryException;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryCapacityExceededException; import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryException; import org.apache.druid.query.QueryException;
import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryTimeoutException; import org.apache.druid.query.QueryTimeoutException;
@ -383,20 +383,19 @@ public class QueryResource implements QueryCountStatsProvider
catch (JsonParseException e) { catch (JsonParseException e) {
throw new BadJsonQueryException(e); throw new BadJsonQueryException(e);
} }
String prevEtag = getPreviousEtag(req); String prevEtag = getPreviousEtag(req);
if (prevEtag == null) {
if (prevEtag != null) { return baseQuery;
if (baseQuery.getQueryContext() == null) {
QueryContext context = new QueryContext(baseQuery.getContext());
context.addSystemParam(HEADER_IF_NONE_MATCH, prevEtag);
return baseQuery.withOverriddenContext(context.getMergedParams());
} else {
baseQuery.getQueryContext().addSystemParam(HEADER_IF_NONE_MATCH, prevEtag);
}
} }
return baseQuery; return baseQuery.withOverriddenContext(
QueryContexts.override(
baseQuery.getContext(),
HEADER_IF_NONE_MATCH,
prevEtag
)
);
} }
private static String getPreviousEtag(final HttpServletRequest req) private static String getPreviousEtag(final HttpServletRequest req)

View File

@ -38,7 +38,6 @@ import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.java.util.emitter.service.ServiceMetricEvent; import org.apache.druid.java.util.emitter.service.ServiceMetricEvent;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryCapacityExceededException; import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryWatcher; import org.apache.druid.query.QueryWatcher;
@ -254,7 +253,7 @@ public class QueryScheduler implements QueryWatcher
@VisibleForTesting @VisibleForTesting
List<Bulkhead> acquireLanes(Query<?> query) List<Bulkhead> acquireLanes(Query<?> query)
{ {
final String lane = QueryContexts.getLane(query); final String lane = query.context().getLane();
final Optional<BulkheadConfig> laneConfig = lane == null ? Optional.empty() : laneRegistry.getConfiguration(lane); final Optional<BulkheadConfig> laneConfig = lane == null ? Optional.empty() : laneRegistry.getConfiguration(lane);
final Optional<BulkheadConfig> totalConfig = laneRegistry.getConfiguration(TOTAL); final Optional<BulkheadConfig> totalConfig = laneRegistry.getConfiguration(TOTAL);
List<Bulkhead> hallPasses = new ArrayList<>(2); List<Bulkhead> hallPasses = new ArrayList<>(2);

View File

@ -22,8 +22,9 @@ package org.apache.druid.server;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import org.apache.druid.client.DirectDruidClient; import org.apache.druid.client.DirectDruidClient;
import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.query.Queries;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.context.ResponseContext; import org.apache.druid.query.context.ResponseContext;
@ -56,21 +57,23 @@ public class SetAndVerifyContextQueryRunner<T> implements QueryRunner<T>
public Query<T> withTimeoutAndMaxScatterGatherBytes(Query<T> query, ServerConfig serverConfig) public Query<T> withTimeoutAndMaxScatterGatherBytes(Query<T> query, ServerConfig serverConfig)
{ {
Query<T> newQuery = QueryContexts.verifyMaxQueryTimeout( Query<T> newQuery =
QueryContexts.withMaxScatterGatherBytes( Queries.withMaxScatterGatherBytes(
QueryContexts.withDefaultTimeout( Queries.withDefaultTimeout(
query, query,
Math.min(serverConfig.getDefaultQueryTimeout(), serverConfig.getMaxQueryTimeout()) Math.min(serverConfig.getDefaultQueryTimeout(), serverConfig.getMaxQueryTimeout())
), ),
serverConfig.getMaxScatterGatherBytes() serverConfig.getMaxScatterGatherBytes()
), );
newQuery.context().verifyMaxQueryTimeout(
serverConfig.getMaxQueryTimeout() serverConfig.getMaxQueryTimeout()
); );
// DirectDruidClient.QUERY_FAIL_TIME is used by DirectDruidClient and JsonParserIterator to determine when to // DirectDruidClient.QUERY_FAIL_TIME is used by DirectDruidClient and JsonParserIterator to determine when to
// fail with a timeout exception // fail with a timeout exception
final long failTime; final long failTime;
if (QueryContexts.hasTimeout(newQuery)) { final QueryContext context = newQuery.context();
failTime = this.startTimeMillis + QueryContexts.getTimeout(newQuery); if (context.hasTimeout()) {
failTime = this.startTimeMillis + context.getTimeout();
} else { } else {
failTime = this.startTimeMillis + serverConfig.getMaxQueryTimeout(); failTime = this.startTimeMillis + serverConfig.getMaxQueryTimeout();
} }

View File

@ -26,6 +26,7 @@ import it.unimi.dsi.fastutil.objects.Object2IntArrayMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap; import it.unimi.dsi.fastutil.objects.Object2IntMap;
import org.apache.druid.client.SegmentServerSelector; import org.apache.druid.client.SegmentServerSelector;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.server.QueryLaningStrategy; import org.apache.druid.server.QueryLaningStrategy;
@ -70,10 +71,11 @@ public class HiLoQueryLaningStrategy implements QueryLaningStrategy
// QueryContexts.getPriority gives a default, but it can parse the value to integer. Before calling QueryContexts.getPriority // QueryContexts.getPriority gives a default, but it can parse the value to integer. Before calling QueryContexts.getPriority
// we make sure that priority has been set. // we make sure that priority has been set.
Integer priority = null; Integer priority = null;
if (theQuery.getContextValue(QueryContexts.PRIORITY_KEY) != null) { final QueryContext queryContext = theQuery.context();
priority = QueryContexts.getPriority(theQuery); if (null != queryContext.get(QueryContexts.PRIORITY_KEY)) {
priority = queryContext.getPriority();
} }
final String lane = theQuery.getQueryContext().getAsString(QueryContexts.LANE_KEY); final String lane = queryContext.getLane();
if (lane == null && priority != null && priority < 0) { if (lane == null && priority != null && priority < 0) {
return Optional.of(LOW); return Optional.of(LOW);
} }

View File

@ -25,12 +25,12 @@ import com.google.common.base.Preconditions;
import it.unimi.dsi.fastutil.objects.Object2IntArrayMap; import it.unimi.dsi.fastutil.objects.Object2IntArrayMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap; import it.unimi.dsi.fastutil.objects.Object2IntMap;
import org.apache.druid.client.SegmentServerSelector; import org.apache.druid.client.SegmentServerSelector;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.server.QueryLaningStrategy; import org.apache.druid.server.QueryLaningStrategy;
import org.apache.druid.server.QueryScheduler; import org.apache.druid.server.QueryScheduler;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
@ -84,6 +84,6 @@ public class ManualQueryLaningStrategy implements QueryLaningStrategy
@Override @Override
public <T> Optional<String> computeLane(QueryPlus<T> query, Set<SegmentServerSelector> segments) public <T> Optional<String> computeLane(QueryPlus<T> query, Set<SegmentServerSelector> segments)
{ {
return Optional.ofNullable(QueryContexts.getLane(query.getQuery())); return Optional.ofNullable(query.getQuery().context().getLane());
} }
} }

View File

@ -22,7 +22,6 @@ package org.apache.druid.server.scheduling;
import it.unimi.dsi.fastutil.objects.Object2IntArrayMap; import it.unimi.dsi.fastutil.objects.Object2IntArrayMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap; import it.unimi.dsi.fastutil.objects.Object2IntMap;
import org.apache.druid.client.SegmentServerSelector; import org.apache.druid.client.SegmentServerSelector;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.server.QueryLaningStrategy; import org.apache.druid.server.QueryLaningStrategy;
@ -47,6 +46,6 @@ public class NoQueryLaningStrategy implements QueryLaningStrategy
@Override @Override
public <T> Optional<String> computeLane(QueryPlus<T> query, Set<SegmentServerSelector> segments) public <T> Optional<String> computeLane(QueryPlus<T> query, Set<SegmentServerSelector> segments)
{ {
return Optional.ofNullable(QueryContexts.getLane(query.getQuery())); return Optional.ofNullable(query.getQuery().context().getLane());
} }
} }

View File

@ -25,7 +25,6 @@ import com.google.common.base.Preconditions;
import org.apache.druid.client.SegmentServerSelector; import org.apache.druid.client.SegmentServerSelector;
import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.server.QueryPrioritizationStrategy; import org.apache.druid.server.QueryPrioritizationStrategy;
import org.joda.time.DateTime; import org.joda.time.DateTime;
@ -33,6 +32,7 @@ import org.joda.time.Duration;
import org.joda.time.Period; import org.joda.time.Period;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
@ -87,7 +87,7 @@ public class ThresholdBasedQueryPrioritizationStrategy implements QueryPrioritiz
boolean violatesSegmentThreshold = segments.size() > segmentCountThreshold; boolean violatesSegmentThreshold = segments.size() > segmentCountThreshold;
if (violatesPeriodThreshold || violatesDurationThreshold || violatesSegmentThreshold) { if (violatesPeriodThreshold || violatesDurationThreshold || violatesSegmentThreshold) {
final int adjustedPriority = QueryContexts.getPriority(theQuery) - adjustment; final int adjustedPriority = theQuery.context().getPriority() - adjustment;
return Optional.of(adjustedPriority); return Optional.of(adjustedPriority);
} }
return Optional.empty(); return Optional.empty();

View File

@ -27,6 +27,7 @@ public class Access
static final String DEFAULT_ERROR_MESSAGE = "Unauthorized"; static final String DEFAULT_ERROR_MESSAGE = "Unauthorized";
public static final Access OK = new Access(true); public static final Access OK = new Access(true);
public static final Access DENIED = new Access(false);
private final boolean allowed; private final boolean allowed;
private final String message; private final String message;

View File

@ -21,10 +21,14 @@ package org.apache.druid.server.security;
import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.utils.CollectionUtils;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Set;
public class AuthConfig public class AuthConfig
{ {
@ -46,25 +50,20 @@ public class AuthConfig
public static final String TRUSTED_DOMAIN_NAME = "trustedDomain"; public static final String TRUSTED_DOMAIN_NAME = "trustedDomain";
/**
* Set of context keys which are always permissible because something in the Druid
* code itself sets the key before the security check.
*/
public static final Set<String> ALLOWED_CONTEXT_KEYS = ImmutableSet.of(
// Set in the Avatica server path
QueryContexts.CTX_SQL_STRINGIFY_ARRAYS,
// Set by the Router
QueryContexts.CTX_SQL_QUERY_ID
);
public AuthConfig() public AuthConfig()
{ {
this(null, null, null, false, false); this(null, null, null, false, false, null, null);
}
@JsonCreator
public AuthConfig(
@JsonProperty("authenticatorChain") List<String> authenticatorChain,
@JsonProperty("authorizers") List<String> authorizers,
@JsonProperty("unsecuredPaths") List<String> unsecuredPaths,
@JsonProperty("allowUnauthenticatedHttpOptions") boolean allowUnauthenticatedHttpOptions,
@JsonProperty("authorizeQueryContextParams") boolean authorizeQueryContextParams
)
{
this.authenticatorChain = authenticatorChain;
this.authorizers = authorizers;
this.unsecuredPaths = unsecuredPaths == null ? Collections.emptyList() : unsecuredPaths;
this.allowUnauthenticatedHttpOptions = allowUnauthenticatedHttpOptions;
this.authorizeQueryContextParams = authorizeQueryContextParams;
} }
@JsonProperty @JsonProperty
@ -82,6 +81,44 @@ public class AuthConfig
@JsonProperty @JsonProperty
private final boolean authorizeQueryContextParams; private final boolean authorizeQueryContextParams;
/**
* The set of query context keys that are allowed, even when security is
* enabled. A null value is the same as an empty set.
*/
@JsonProperty
private final Set<String> unsecuredContextKeys;
/**
* The set of query context keys to secure, when context security is
* enabled. Null has a special meaning: it means to ignore this set.
* Else, only the keys in this set are subject to security. If set,
* the unsecured list is ignored.
*/
@JsonProperty
private final Set<String> securedContextKeys;
@JsonCreator
public AuthConfig(
@JsonProperty("authenticatorChain") List<String> authenticatorChain,
@JsonProperty("authorizers") List<String> authorizers,
@JsonProperty("unsecuredPaths") List<String> unsecuredPaths,
@JsonProperty("allowUnauthenticatedHttpOptions") boolean allowUnauthenticatedHttpOptions,
@JsonProperty("authorizeQueryContextParams") boolean authorizeQueryContextParams,
@JsonProperty("unsecuredContextKeys") Set<String> unsecuredContextKeys,
@JsonProperty("securedContextKeys") Set<String> securedContextKeys
)
{
this.authenticatorChain = authenticatorChain;
this.authorizers = authorizers;
this.unsecuredPaths = unsecuredPaths == null ? Collections.emptyList() : unsecuredPaths;
this.allowUnauthenticatedHttpOptions = allowUnauthenticatedHttpOptions;
this.authorizeQueryContextParams = authorizeQueryContextParams;
this.unsecuredContextKeys = unsecuredContextKeys == null
? Collections.emptySet()
: unsecuredContextKeys;
this.securedContextKeys = securedContextKeys;
}
public List<String> getAuthenticatorChain() public List<String> getAuthenticatorChain()
{ {
return authenticatorChain; return authenticatorChain;
@ -107,6 +144,36 @@ public class AuthConfig
return authorizeQueryContextParams; return authorizeQueryContextParams;
} }
/**
* Filter the user-supplied context keys based on the context key security
* rules. If context key security is disabled, then allow all keys. Else,
* apply the three key lists defined here.
* <ul>
* <li>Allow Druid-defined keys.</li>
* <li>Allow anything not in the secured context key list.</li>
* <li>Allow anything in the config-defined unsecured key list.</li>
* </ul>
* In the typical case, a site defines either the secured key list
* (to handle a few keys that are <i>are not</i> allowed) or the unsecured key
* list (to enumerate a few that <i>are</i> allowed.) If both lists
* are given, think of the secured list as exceptions to the unsecured
* key list.
*
* @return the list of secured keys to check via authentication
*/
public Set<String> contextKeysToAuthorize(final Set<String> userKeys)
{
if (!authorizeQueryContextParams) {
return ImmutableSet.of();
}
Set<String> keysToCheck = CollectionUtils.subtract(userKeys, ALLOWED_CONTEXT_KEYS);
keysToCheck = CollectionUtils.subtract(keysToCheck, unsecuredContextKeys);
if (securedContextKeys != null) {
keysToCheck = CollectionUtils.intersect(keysToCheck, securedContextKeys);
}
return keysToCheck;
}
@Override @Override
public boolean equals(Object o) public boolean equals(Object o)
{ {
@ -121,7 +188,9 @@ public class AuthConfig
&& authorizeQueryContextParams == that.authorizeQueryContextParams && authorizeQueryContextParams == that.authorizeQueryContextParams
&& Objects.equals(authenticatorChain, that.authenticatorChain) && Objects.equals(authenticatorChain, that.authenticatorChain)
&& Objects.equals(authorizers, that.authorizers) && Objects.equals(authorizers, that.authorizers)
&& Objects.equals(unsecuredPaths, that.unsecuredPaths); && Objects.equals(unsecuredPaths, that.unsecuredPaths)
&& Objects.equals(unsecuredContextKeys, that.unsecuredContextKeys)
&& Objects.equals(securedContextKeys, that.securedContextKeys);
} }
@Override @Override
@ -132,7 +201,9 @@ public class AuthConfig
authorizers, authorizers,
unsecuredPaths, unsecuredPaths,
allowUnauthenticatedHttpOptions, allowUnauthenticatedHttpOptions,
authorizeQueryContextParams authorizeQueryContextParams,
unsecuredContextKeys,
securedContextKeys
); );
} }
@ -145,6 +216,8 @@ public class AuthConfig
", unsecuredPaths=" + unsecuredPaths + ", unsecuredPaths=" + unsecuredPaths +
", allowUnauthenticatedHttpOptions=" + allowUnauthenticatedHttpOptions + ", allowUnauthenticatedHttpOptions=" + allowUnauthenticatedHttpOptions +
", enableQueryContextAuthorization=" + authorizeQueryContextParams + ", enableQueryContextAuthorization=" + authorizeQueryContextParams +
", unsecuredContextKeys=" + unsecuredContextKeys +
", securedContextKeys=" + securedContextKeys +
'}'; '}';
} }
@ -163,6 +236,8 @@ public class AuthConfig
private List<String> unsecuredPaths; private List<String> unsecuredPaths;
private boolean allowUnauthenticatedHttpOptions; private boolean allowUnauthenticatedHttpOptions;
private boolean authorizeQueryContextParams; private boolean authorizeQueryContextParams;
private Set<String> unsecuredContextKeys;
private Set<String> securedContextKeys;
public Builder setAuthenticatorChain(List<String> authenticatorChain) public Builder setAuthenticatorChain(List<String> authenticatorChain)
{ {
@ -194,6 +269,18 @@ public class AuthConfig
return this; return this;
} }
public Builder setUnsecuredContextKeys(Set<String> unsecuredContextKeys)
{
this.unsecuredContextKeys = unsecuredContextKeys;
return this;
}
public Builder setSecuredContextKeys(Set<String> securedContextKeys)
{
this.securedContextKeys = securedContextKeys;
return this;
}
public AuthConfig build() public AuthConfig build()
{ {
return new AuthConfig( return new AuthConfig(
@ -201,7 +288,9 @@ public class AuthConfig
authorizers, authorizers,
unsecuredPaths, unsecuredPaths,
allowUnauthenticatedHttpOptions, allowUnauthenticatedHttpOptions,
authorizeQueryContextParams authorizeQueryContextParams,
unsecuredContextKeys,
securedContextKeys
); );
} }
} }

View File

@ -19,12 +19,14 @@
package org.apache.druid.client; package org.apache.druid.client;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.Bytes; import com.google.common.primitives.Bytes;
import org.apache.druid.client.selector.QueryableDruidServer; import org.apache.druid.client.selector.QueryableDruidServer;
import org.apache.druid.client.selector.ServerSelector; import org.apache.druid.client.selector.ServerSelector;
import org.apache.druid.query.CacheStrategy; import org.apache.druid.query.CacheStrategy;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.planning.DataSourceAnalysis; import org.apache.druid.query.planning.DataSourceAnalysis;
import org.apache.druid.segment.join.JoinableFactoryWrapper; import org.apache.druid.segment.join.JoinableFactoryWrapper;
@ -43,7 +45,6 @@ import org.junit.runner.RunWith;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import static org.apache.druid.query.QueryContexts.DEFAULT_BY_SEGMENT;
import static org.easymock.EasyMock.expect; import static org.easymock.EasyMock.expect;
import static org.easymock.EasyMock.replay; import static org.easymock.EasyMock.replay;
import static org.easymock.EasyMock.reset; import static org.easymock.EasyMock.reset;
@ -67,7 +68,7 @@ public class CachingClusteredClientCacheKeyManagerTest extends EasyMockSupport
public void setup() public void setup()
{ {
expect(strategy.computeCacheKey(query)).andReturn(QUERY_CACHE_KEY).anyTimes(); expect(strategy.computeCacheKey(query)).andReturn(QUERY_CACHE_KEY).anyTimes();
expect(query.getContextBoolean(QueryContexts.BY_SEGMENT_KEY, DEFAULT_BY_SEGMENT)).andReturn(false).anyTimes(); expect(query.context()).andReturn(QueryContext.of(ImmutableMap.of(QueryContexts.BY_SEGMENT_KEY, false))).anyTimes();
} }
@After @After
@ -203,7 +204,7 @@ public class CachingClusteredClientCacheKeyManagerTest extends EasyMockSupport
{ {
expect(dataSourceAnalysis.isJoin()).andReturn(false); expect(dataSourceAnalysis.isJoin()).andReturn(false);
reset(query); reset(query);
expect(query.getContextBoolean(QueryContexts.BY_SEGMENT_KEY, DEFAULT_BY_SEGMENT)).andReturn(true).anyTimes(); expect(query.context()).andReturn(QueryContext.of(ImmutableMap.of(QueryContexts.BY_SEGMENT_KEY, true))).anyTimes();
replayAll(); replayAll();
CachingClusteredClient.CacheKeyManager<Object> keyManager = makeKeyManager(); CachingClusteredClient.CacheKeyManager<Object> keyManager = makeKeyManager();
Set<SegmentServerSelector> selectors = ImmutableSet.of( Set<SegmentServerSelector> selectors = ImmutableSet.of(
@ -272,7 +273,7 @@ public class CachingClusteredClientCacheKeyManagerTest extends EasyMockSupport
public void testSegmentQueryCacheKey_noCachingIfBySegment() public void testSegmentQueryCacheKey_noCachingIfBySegment()
{ {
reset(query); reset(query);
expect(query.getContextBoolean(QueryContexts.BY_SEGMENT_KEY, DEFAULT_BY_SEGMENT)).andReturn(true).anyTimes(); expect(query.context()).andReturn(QueryContext.of(ImmutableMap.of(QueryContexts.BY_SEGMENT_KEY, true))).anyTimes();
replayAll(); replayAll();
byte[] cacheKey = makeKeyManager().computeSegmentLevelQueryCacheKey(); byte[] cacheKey = makeKeyManager().computeSegmentLevelQueryCacheKey();
Assert.assertNull(cacheKey); Assert.assertNull(cacheKey);

View File

@ -72,6 +72,7 @@ import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.Druids; import org.apache.druid.query.Druids;
import org.apache.druid.query.FinalizeResultsQueryRunner; import org.apache.druid.query.FinalizeResultsQueryRunner;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunner;
@ -2297,12 +2298,13 @@ public class CachingClusteredClientTest
for (Capture queryCapture : queryCaptures) { for (Capture queryCapture : queryCaptures) {
QueryPlus capturedQueryPlus = (QueryPlus) queryCapture.getValue(); QueryPlus capturedQueryPlus = (QueryPlus) queryCapture.getValue();
Query capturedQuery = capturedQueryPlus.getQuery(); Query capturedQuery = capturedQueryPlus.getQuery();
final QueryContext queryContext = capturedQuery.context();
if (expectBySegment) { if (expectBySegment) {
Assert.assertEquals(true, capturedQuery.getQueryContext().getAsBoolean(QueryContexts.BY_SEGMENT_KEY)); Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.BY_SEGMENT_KEY));
} else { } else {
Assert.assertTrue( Assert.assertTrue(
capturedQuery.getContextValue(QueryContexts.BY_SEGMENT_KEY) == null || queryContext.get(QueryContexts.BY_SEGMENT_KEY) == null ||
capturedQuery.getQueryContext().getAsBoolean(QueryContexts.BY_SEGMENT_KEY).equals(false) !queryContext.getBoolean(QueryContexts.BY_SEGMENT_KEY)
); );
} }
} }

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JavaType; import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.AbstractFuture; import com.google.common.util.concurrent.AbstractFuture;
import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.Futures;
import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.jackson.DefaultObjectMapper;
@ -309,13 +310,8 @@ public class JsonParserIteratorTest
Query<?> query = Mockito.mock(Query.class); Query<?> query = Mockito.mock(Query.class);
QueryContext context = Mockito.mock(QueryContext.class); QueryContext context = Mockito.mock(QueryContext.class);
Mockito.when(query.getId()).thenReturn(queryId); Mockito.when(query.getId()).thenReturn(queryId);
Mockito.when(query.getQueryContext()).thenReturn(context); Mockito.when(query.context()).thenReturn(
Mockito.when( QueryContext.of(ImmutableMap.of(DirectDruidClient.QUERY_FAIL_TIME, timeoutAt)));
context.getAsLong(
ArgumentMatchers.eq(DirectDruidClient.QUERY_FAIL_TIME),
ArgumentMatchers.eq(-1L)
)
).thenReturn(timeoutAt);
return query; return query;
} }
} }

View File

@ -119,8 +119,6 @@ public class UnifiedIndexerAppenderatorsManagerTest extends InitializedNullHandl
@Test @Test
public void test_getBundle_knownDataSource() public void test_getBundle_knownDataSource()
{ {
final UnifiedIndexerAppenderatorsManager.DatasourceBundle bundle = manager.getBundle( final UnifiedIndexerAppenderatorsManager.DatasourceBundle bundle = manager.getBundle(
Druids.newScanQueryBuilder() Druids.newScanQueryBuilder()
.dataSource(appenderator.getDataSource()) .dataSource(appenderator.getDataSource())

View File

@ -21,6 +21,7 @@ package org.apache.druid.server;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.guava.Sequences;
@ -55,6 +56,9 @@ import org.junit.rules.ExpectedException;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import java.util.HashMap;
import java.util.Map;
public class QueryLifecycleTest public class QueryLifecycleTest
{ {
private static final String DATASOURCE = "some_datasource"; private static final String DATASOURCE = "some_datasource";
@ -73,9 +77,6 @@ public class QueryLifecycleTest
RequestLogger requestLogger; RequestLogger requestLogger;
AuthorizerMapper authzMapper; AuthorizerMapper authzMapper;
DefaultQueryConfig queryConfig; DefaultQueryConfig queryConfig;
AuthConfig authConfig;
QueryLifecycle lifecycle;
QueryToolChest toolChest; QueryToolChest toolChest;
QueryRunner runner; QueryRunner runner;
@ -97,11 +98,18 @@ public class QueryLifecycleTest
authorizer = EasyMock.createMock(Authorizer.class); authorizer = EasyMock.createMock(Authorizer.class);
authzMapper = new AuthorizerMapper(ImmutableMap.of(AUTHORIZER, authorizer)); authzMapper = new AuthorizerMapper(ImmutableMap.of(AUTHORIZER, authorizer));
queryConfig = EasyMock.createMock(DefaultQueryConfig.class); queryConfig = EasyMock.createMock(DefaultQueryConfig.class);
authConfig = EasyMock.createMock(AuthConfig.class);
toolChest = EasyMock.createMock(QueryToolChest.class);
runner = EasyMock.createMock(QueryRunner.class);
metrics = EasyMock.createNiceMock(QueryMetrics.class);
authenticationResult = EasyMock.createMock(AuthenticationResult.class);
}
private QueryLifecycle createLifecycle(AuthConfig authConfig)
{
long nanos = System.nanoTime(); long nanos = System.nanoTime();
long millis = System.currentTimeMillis(); long millis = System.currentTimeMillis();
lifecycle = new QueryLifecycle( return new QueryLifecycle(
toolChestWarehouse, toolChestWarehouse,
texasRanger, texasRanger,
metricsFactory, metricsFactory,
@ -113,11 +121,6 @@ public class QueryLifecycleTest
millis, millis,
nanos nanos
); );
toolChest = EasyMock.createMock(QueryToolChest.class);
runner = EasyMock.createMock(QueryRunner.class);
metrics = EasyMock.createNiceMock(QueryMetrics.class);
authenticationResult = EasyMock.createMock(AuthenticationResult.class);
} }
@After @After
@ -151,9 +154,9 @@ public class QueryLifecycleTest
.once(); .once();
EasyMock.expect(runner.run(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(Sequences.empty()).once(); EasyMock.expect(runner.run(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(Sequences.empty()).once();
replayAll(); replayAll();
QueryLifecycle lifecycle = createLifecycle(new AuthConfig());
lifecycle.runSimple(query, authenticationResult, Access.OK); lifecycle.runSimple(query, authenticationResult, Access.OK);
} }
@ -174,6 +177,7 @@ public class QueryLifecycleTest
replayAll(); replayAll();
QueryLifecycle lifecycle = createLifecycle(new AuthConfig());
lifecycle.runSimple(query, authenticationResult, new Access(false)); lifecycle.runSimple(query, authenticationResult, new Access(false));
} }
@ -181,7 +185,6 @@ public class QueryLifecycleTest
public void testAuthorizeQueryContext_authorized() public void testAuthorizeQueryContext_authorized()
{ {
EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes(); EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
EasyMock.expect(authConfig.authorizeQueryContextParams()).andReturn(true).anyTimes();
EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes(); EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes(); EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ)) EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ))
@ -197,21 +200,27 @@ public class QueryLifecycleTest
replayAll(); replayAll();
final Map<String, Object> userContext = ImmutableMap.of("foo", "bar", "baz", "qux");
final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder() final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource(DATASOURCE) .dataSource(DATASOURCE)
.intervals(ImmutableList.of(Intervals.ETERNITY)) .intervals(ImmutableList.of(Intervals.ETERNITY))
.aggregators(new CountAggregatorFactory("chocula")) .aggregators(new CountAggregatorFactory("chocula"))
.context(ImmutableMap.of("foo", "bar", "baz", "qux")) .context(userContext)
.build(); .build();
AuthConfig authConfig = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.build();
QueryLifecycle lifecycle = createLifecycle(authConfig);
lifecycle.initialize(query); lifecycle.initialize(query);
Assert.assertEquals( final Map<String, Object> revisedContext = new HashMap<>(lifecycle.getQuery().getContext());
ImmutableMap.of("foo", "bar", "baz", "qux"),
lifecycle.getQuery().getQueryContext().getUserParams()
);
Assert.assertTrue(lifecycle.getQuery().getQueryContext().getMergedParams().containsKey("queryId"));
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId")); Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId"));
revisedContext.remove("queryId");
Assert.assertEquals(
userContext,
revisedContext
);
Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed()); Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed());
} }
@ -220,13 +229,12 @@ public class QueryLifecycleTest
public void testAuthorizeQueryContext_notAuthorized() public void testAuthorizeQueryContext_notAuthorized()
{ {
EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes(); EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
EasyMock.expect(authConfig.authorizeQueryContextParams()).andReturn(true).anyTimes();
EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes(); EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes(); EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ)) EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ))
.andReturn(Access.OK); .andReturn(Access.OK);
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE)) EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE))
.andReturn(new Access(false)); .andReturn(Access.DENIED);
EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject())) EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
.andReturn(toolChest) .andReturn(toolChest)
@ -241,6 +249,128 @@ public class QueryLifecycleTest
.context(ImmutableMap.of("foo", "bar")) .context(ImmutableMap.of("foo", "bar"))
.build(); .build();
AuthConfig authConfig = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.build();
QueryLifecycle lifecycle = createLifecycle(authConfig);
lifecycle.initialize(query);
Assert.assertFalse(lifecycle.authorize(mockRequest()).isAllowed());
}
@Test
public void testAuthorizeQueryContext_unsecuredKeys()
{
EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ))
.andReturn(Access.OK);
EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
.andReturn(toolChest)
.once();
replayAll();
final Map<String, Object> userContext = ImmutableMap.of("foo", "bar", "baz", "qux");
final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource(DATASOURCE)
.intervals(ImmutableList.of(Intervals.ETERNITY))
.aggregators(new CountAggregatorFactory("chocula"))
.context(userContext)
.build();
AuthConfig authConfig = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.setUnsecuredContextKeys(ImmutableSet.of("foo", "baz"))
.build();
QueryLifecycle lifecycle = createLifecycle(authConfig);
lifecycle.initialize(query);
final Map<String, Object> revisedContext = new HashMap<>(lifecycle.getQuery().getContext());
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId"));
revisedContext.remove("queryId");
Assert.assertEquals(
userContext,
revisedContext
);
Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed());
}
@Test
public void testAuthorizeQueryContext_securedKeys()
{
EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ))
.andReturn(Access.OK);
EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
.andReturn(toolChest)
.once();
replayAll();
final Map<String, Object> userContext = ImmutableMap.of("foo", "bar", "baz", "qux");
final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource(DATASOURCE)
.intervals(ImmutableList.of(Intervals.ETERNITY))
.aggregators(new CountAggregatorFactory("chocula"))
.context(userContext)
.build();
AuthConfig authConfig = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
// We have secured keys, just not what the user gave.
.setSecuredContextKeys(ImmutableSet.of("foo2", "baz2"))
.build();
QueryLifecycle lifecycle = createLifecycle(authConfig);
lifecycle.initialize(query);
final Map<String, Object> revisedContext = new HashMap<>(lifecycle.getQuery().getContext());
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId"));
revisedContext.remove("queryId");
Assert.assertEquals(
userContext,
revisedContext
);
Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed());
}
@Test
public void testAuthorizeQueryContext_securedKeysNotAuthorized()
{
EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ))
.andReturn(Access.OK);
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE))
.andReturn(Access.DENIED);
EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
.andReturn(toolChest)
.once();
replayAll();
final Map<String, Object> userContext = ImmutableMap.of("foo", "bar", "baz", "qux");
final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource(DATASOURCE)
.intervals(ImmutableList.of(Intervals.ETERNITY))
.aggregators(new CountAggregatorFactory("chocula"))
.context(userContext)
.build();
AuthConfig authConfig = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
// We have secured keys. User used one of them.
.setSecuredContextKeys(ImmutableSet.of("foo", "baz2"))
.build();
QueryLifecycle lifecycle = createLifecycle(authConfig);
lifecycle.initialize(query); lifecycle.initialize(query);
Assert.assertFalse(lifecycle.authorize(mockRequest()).isAllowed()); Assert.assertFalse(lifecycle.authorize(mockRequest()).isAllowed());
} }
@ -249,7 +379,6 @@ public class QueryLifecycleTest
public void testAuthorizeLegacyQueryContext_authorized() public void testAuthorizeLegacyQueryContext_authorized()
{ {
EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes(); EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
EasyMock.expect(authConfig.authorizeQueryContextParams()).andReturn(true).anyTimes();
EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes(); EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes(); EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("fake", ResourceType.DATASOURCE), Action.READ)) EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("fake", ResourceType.DATASOURCE), Action.READ))
@ -257,9 +386,6 @@ public class QueryLifecycleTest
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE)) EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE))
.andReturn(Access.OK); .andReturn(Access.OK);
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("baz", ResourceType.QUERY_CONTEXT), Action.WRITE)).andReturn(Access.OK); EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("baz", ResourceType.QUERY_CONTEXT), Action.WRITE)).andReturn(Access.OK);
// to use legacy query context with context authorization, even system generated things like queryId need to be explicitly added
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("queryId", ResourceType.QUERY_CONTEXT), Action.WRITE))
.andReturn(Access.OK);
EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject())) EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
.andReturn(toolChest) .andReturn(toolChest)
@ -269,12 +395,17 @@ public class QueryLifecycleTest
final QueryContextTest.LegacyContextQuery query = new QueryContextTest.LegacyContextQuery(ImmutableMap.of("foo", "bar", "baz", "qux")); final QueryContextTest.LegacyContextQuery query = new QueryContextTest.LegacyContextQuery(ImmutableMap.of("foo", "bar", "baz", "qux"));
AuthConfig authConfig = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.build();
QueryLifecycle lifecycle = createLifecycle(authConfig);
lifecycle.initialize(query); lifecycle.initialize(query);
Assert.assertNull(lifecycle.getQuery().getQueryContext()); final Map<String, Object> revisedContext = lifecycle.getQuery().getContext();
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("foo")); Assert.assertNotNull(revisedContext);
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("baz")); Assert.assertTrue(revisedContext.containsKey("foo"));
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId")); Assert.assertTrue(revisedContext.containsKey("baz"));
Assert.assertTrue(revisedContext.containsKey("queryId"));
Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed()); Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed());
} }
@ -301,7 +432,6 @@ public class QueryLifecycleTest
emitter, emitter,
requestLogger, requestLogger,
queryConfig, queryConfig,
authConfig,
toolChest, toolChest,
runner, runner,
metrics, metrics,

View File

@ -48,7 +48,6 @@ import org.apache.druid.java.util.emitter.core.NoopEmitter;
import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
import org.apache.druid.query.QueryCapacityExceededException; import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.topn.TopNQuery; import org.apache.druid.query.topn.TopNQuery;
@ -150,7 +149,7 @@ public class QuerySchedulerTest
try { try {
Query<?> scheduledReport = scheduler.prioritizeAndLaneQuery(QueryPlus.wrap(report), ImmutableSet.of()); Query<?> scheduledReport = scheduler.prioritizeAndLaneQuery(QueryPlus.wrap(report), ImmutableSet.of());
Assert.assertNotNull(scheduledReport); Assert.assertNotNull(scheduledReport);
Assert.assertEquals(HiLoQueryLaningStrategy.LOW, QueryContexts.getLane(scheduledReport)); Assert.assertEquals(HiLoQueryLaningStrategy.LOW, scheduledReport.context().getLane());
Sequence<Integer> underlyingSequence = makeSequence(10); Sequence<Integer> underlyingSequence = makeSequence(10);
underlyingSequence = Sequences.wrap(underlyingSequence, new SequenceWrapper() underlyingSequence = Sequences.wrap(underlyingSequence, new SequenceWrapper()
@ -412,8 +411,8 @@ public class QuerySchedulerTest
EasyMock.createMock(SegmentServerSelector.class) EasyMock.createMock(SegmentServerSelector.class)
) )
); );
Assert.assertEquals(-5, QueryContexts.getPriority(query)); Assert.assertEquals(-5, query.context().getPriority());
Assert.assertEquals(HiLoQueryLaningStrategy.LOW, QueryContexts.getLane(query)); Assert.assertEquals(HiLoQueryLaningStrategy.LOW, query.context().getLane());
} }
@Test @Test

View File

@ -36,7 +36,6 @@ import org.junit.Test;
public class SetAndVerifyContextQueryRunnerTest public class SetAndVerifyContextQueryRunnerTest
{ {
@Test @Test
public void testTimeoutIsUsedIfTimeoutIsNonZero() throws InterruptedException public void testTimeoutIsUsedIfTimeoutIsNonZero() throws InterruptedException
{ {
@ -58,7 +57,7 @@ public class SetAndVerifyContextQueryRunnerTest
// time + 1 at the time the method was called // time + 1 at the time the method was called
// this means that after sleeping for 1 millis, the fail time should be less than the current time when checking // this means that after sleeping for 1 millis, the fail time should be less than the current time when checking
Assert.assertTrue( Assert.assertTrue(
System.currentTimeMillis() > transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME) System.currentTimeMillis() > transformed.context().getLong(DirectDruidClient.QUERY_FAIL_TIME)
); );
} }
@ -85,7 +84,7 @@ public class SetAndVerifyContextQueryRunnerTest
Query<ScanResultValue> transformed = queryRunner.withTimeoutAndMaxScatterGatherBytes(query, defaultConfig); Query<ScanResultValue> transformed = queryRunner.withTimeoutAndMaxScatterGatherBytes(query, defaultConfig);
// timeout is not set, default timeout has been set to long.max, make sure timeout is still in the future // timeout is not set, default timeout has been set to long.max, make sure timeout is still in the future
Assert.assertEquals((Long) Long.MAX_VALUE, transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME)); Assert.assertEquals(Long.MAX_VALUE, (long) transformed.context().getLong(DirectDruidClient.QUERY_FAIL_TIME));
} }
@Test @Test
@ -107,7 +106,7 @@ public class SetAndVerifyContextQueryRunnerTest
// timeout is set to 0, so withTimeoutAndMaxScatterGatherBytes should set QUERY_FAIL_TIME to be the current // timeout is set to 0, so withTimeoutAndMaxScatterGatherBytes should set QUERY_FAIL_TIME to be the current
// time + max query timeout at the time the method was called // time + max query timeout at the time the method was called
// since default is long max, expect long max since current time would overflow // since default is long max, expect long max since current time would overflow
Assert.assertEquals((Long) Long.MAX_VALUE, transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME)); Assert.assertEquals(Long.MAX_VALUE, (long) transformed.context().getLong(DirectDruidClient.QUERY_FAIL_TIME));
} }
@Test @Test
@ -137,7 +136,7 @@ public class SetAndVerifyContextQueryRunnerTest
// time + max query timeout at the time the method was called // time + max query timeout at the time the method was called
// this means that the fail time should be greater than the current time when checking // this means that the fail time should be greater than the current time when checking
Assert.assertTrue( Assert.assertTrue(
System.currentTimeMillis() < (Long) transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME) System.currentTimeMillis() < transformed.context().getLong(DirectDruidClient.QUERY_FAIL_TIME)
); );
} }
} }

View File

@ -19,9 +19,16 @@
package org.apache.druid.server.security; package org.apache.druid.server.security;
import com.google.common.collect.ImmutableSet;
import nl.jqno.equalsverifier.EqualsVerifier; import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.query.QueryContexts;
import org.junit.Test; import org.junit.Test;
import java.util.Set;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class AuthConfigTest public class AuthConfigTest
{ {
@Test @Test
@ -29,4 +36,55 @@ public class AuthConfigTest
{ {
EqualsVerifier.configure().usingGetClass().forClass(AuthConfig.class).verify(); EqualsVerifier.configure().usingGetClass().forClass(AuthConfig.class).verify();
} }
@Test
public void testContextSecurity()
{
// No security
{
AuthConfig config = new AuthConfig();
Set<String> keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID);
assertTrue(config.contextKeysToAuthorize(keys).isEmpty());
}
// Default security
{
AuthConfig config = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.build();
Set<String> keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID);
assertEquals(ImmutableSet.of("a", "b"), config.contextKeysToAuthorize(keys));
}
// Specify unsecured keys (white-list)
{
AuthConfig config = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.setUnsecuredContextKeys(ImmutableSet.of("a"))
.build();
Set<String> keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID);
assertEquals(ImmutableSet.of("b"), config.contextKeysToAuthorize(keys));
}
// Specify secured keys (black-list)
{
AuthConfig config = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.setSecuredContextKeys(ImmutableSet.of("a"))
.build();
Set<String> keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID);
assertEquals(ImmutableSet.of("a"), config.contextKeysToAuthorize(keys));
}
// Specify both
{
AuthConfig config = AuthConfig.newBuilder()
.setAuthorizeQueryContextParams(true)
.setUnsecuredContextKeys(ImmutableSet.of("a", "b"))
.setSecuredContextKeys(ImmutableSet.of("b", "c"))
.build();
Set<String> keys = ImmutableSet.of("a", "b", "c", "d", QueryContexts.CTX_SQL_QUERY_ID);
assertEquals(ImmutableSet.of("c"), config.contextKeysToAuthorize(keys));
}
}
} }

Some files were not shown because too many files have changed in this diff Show More