Avoid ClassCastException when getting values from `QueryContext` (#13022)

* Use safe conversion methods

* Rename method

* Add getContextAsBoolean

* Update test case

* Remove generic from getContextValue

* Update catch-handler

* Add test

* Resolve comments

* Replace 'getContextXXX' to 'getQueryContext().getAsXXXX'
This commit is contained in:
Frank Chen 2022-09-13 18:00:09 +08:00 committed by GitHub
parent 08d6aca528
commit fd6c05eee8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 398 additions and 281 deletions

View File

@ -24,7 +24,6 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.collect.Ordering;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.BaseQuery;
import org.apache.druid.query.DataSource;
@ -153,30 +152,6 @@ public class MaterializedViewQuery<T> implements Query<T>
return query.getQueryContext();
}
@Override
public <ContextType> ContextType getContextValue(String key)
{
return (ContextType) query.getContextValue(key);
}
@Override
public <ContextType> ContextType getContextValue(String key, ContextType defaultValue)
{
return (ContextType) query.getContextValue(key, defaultValue);
}
@Override
public boolean getContextBoolean(String key, boolean defaultValue)
{
return query.getContextBoolean(key, defaultValue);
}
@Override
public HumanReadableBytes getContextHumanReadableBytes(String key, HumanReadableBytes defaultValue)
{
return query.getContextHumanReadableBytes(key, defaultValue);
}
@Override
public boolean isDescending()
{

View File

@ -121,7 +121,7 @@ public class MaterializedViewQueryTest
.postAggregators(QueryRunnerTestHelper.ADD_ROWS_INDEX_CONSTANT)
.build();
MaterializedViewQuery query = new MaterializedViewQuery(topNQuery, optimizer);
Assert.assertEquals(20_000_000, query.getContextHumanReadableBytes("maxOnDiskStorage", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(20_000_000, query.getContextAsHumanReadableBytes("maxOnDiskStorage", HumanReadableBytes.ZERO).getBytes());
}
}

View File

@ -1454,7 +1454,7 @@ public class ControllerImpl implements Controller
)
{
if (isRollupQuery) {
final String queryGranularity = query.getContextValue(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_GRANULARITY, "");
final String queryGranularity = query.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_GRANULARITY, "");
if (timeIsGroupByDimension((GroupByQuery) query, columnMappings) && !queryGranularity.isEmpty()) {
return new ArbitraryGranularitySpec(
@ -1483,7 +1483,7 @@ public class ControllerImpl implements Controller
{
if (columnMappings.hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) {
final String queryTimeColumn = columnMappings.getQueryColumnForOutputColumn(ColumnHolder.TIME_COLUMN_NAME);
return queryTimeColumn.equals(groupByQuery.getContextValue(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD));
return queryTimeColumn.equals(groupByQuery.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD));
} else {
return false;
}

View File

@ -191,7 +191,7 @@ public class QueryKitUtils
public static VirtualColumn makeSegmentGranularityVirtualColumn(final Query<?> query)
{
final Granularity segmentGranularity = QueryKitUtils.getSegmentGranularityFromContext(query.getContext());
final String timeColumnName = query.getContextValue(QueryKitUtils.CTX_TIME_COLUMN_NAME);
final String timeColumnName = query.getQueryContext().getAsString(QueryKitUtils.CTX_TIME_COLUMN_NAME);
if (timeColumnName == null || Granularities.ALL.equals(segmentGranularity)) {
return null;

View File

@ -57,7 +57,7 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
{
RowSignature scanSignature;
try {
final String s = scanQuery.getContextValue(DruidQuery.CTX_SCAN_SIGNATURE);
final String s = scanQuery.getQueryContext().getAsString(DruidQuery.CTX_SCAN_SIGNATURE);
scanSignature = jsonMapper.readValue(s, RowSignature.class);
}
catch (JsonProcessingException e) {

View File

@ -181,19 +181,6 @@ public abstract class BaseQuery<T> implements Query<T>
return context;
}
@Override
public <ContextType> ContextType getContextValue(String key)
{
return (ContextType) context.get(key);
}
@Override
public <ContextType> ContextType getContextValue(String key, ContextType defaultValue)
{
ContextType retVal = getContextValue(key);
return retVal == null ? defaultValue : retVal;
}
@Override
public boolean getContextBoolean(String key, boolean defaultValue)
{
@ -201,7 +188,7 @@ public abstract class BaseQuery<T> implements Query<T>
}
@Override
public HumanReadableBytes getContextHumanReadableBytes(String key, HumanReadableBytes defaultValue)
public HumanReadableBytes getContextAsHumanReadableBytes(String key, HumanReadableBytes defaultValue)
{
return context.getAsHumanReadableBytes(key, defaultValue);
}

View File

@ -124,11 +124,37 @@ public interface Query<T>
return null;
}
<ContextType> ContextType getContextValue(String key);
/**
* Get context value and cast to ContextType in an unsafe way.
*
* For safe conversion, it's recommended to use following methods instead
*
* {@link QueryContext#getAsBoolean(String)}
* {@link QueryContext#getAsString(String)}
* {@link QueryContext#getAsInt(String)}
* {@link QueryContext#getAsLong(String)}
* {@link QueryContext#getAsFloat(String, float)}
* {@link QueryContext#getAsEnum(String, Class, Enum)}
* {@link QueryContext#getAsHumanReadableBytes(String, HumanReadableBytes)}
*/
@Nullable
default <ContextType> ContextType getContextValue(String key)
{
if (getQueryContext() == null) {
return null;
} else {
return (ContextType) getQueryContext().get(key);
}
}
<ContextType> ContextType getContextValue(String key, ContextType defaultValue);
boolean getContextBoolean(String key, boolean defaultValue);
default boolean getContextBoolean(String key, boolean defaultValue)
{
if (getQueryContext() == null) {
return defaultValue;
} else {
return getQueryContext().getAsBoolean(key, defaultValue);
}
}
/**
* Returns {@link HumanReadableBytes} for a specified context key. If the context is null or the key doesn't exist
@ -139,12 +165,12 @@ public interface Query<T>
* @param defaultValue The default to return if the key value doesn't exist or the context is null.
* @return {@link HumanReadableBytes}
*/
default HumanReadableBytes getContextHumanReadableBytes(String key, HumanReadableBytes defaultValue)
default HumanReadableBytes getContextAsHumanReadableBytes(String key, HumanReadableBytes defaultValue)
{
if (null != getQueryContext()) {
return getQueryContext().getAsHumanReadableBytes(key, defaultValue);
} else {
if (getQueryContext() == null) {
return defaultValue;
} else {
return getQueryContext().getAsHumanReadableBytes(key, defaultValue);
}
}
@ -204,7 +230,7 @@ public interface Query<T>
@Nullable
default String getSqlQueryId()
{
return getContextValue(BaseQuery.SQL_QUERY_ID);
return getQueryContext().getAsString(BaseQuery.SQL_QUERY_ID);
}
/**

View File

@ -168,33 +168,66 @@ public class QueryContext
@Nullable
public String getAsString(String key)
{
return (String) get(key);
Object val = get(key);
return val == null ? null : val.toString();
}
public String getAsString(String key, String defaultValue)
{
Object val = get(key);
return val == null ? defaultValue : val.toString();
}
@Nullable
public Boolean getAsBoolean(String key)
{
return QueryContexts.getAsBoolean(key, get(key));
}
public boolean getAsBoolean(
final String parameter,
final String key,
final boolean defaultValue
)
{
return QueryContexts.getAsBoolean(parameter, get(parameter), defaultValue);
return QueryContexts.getAsBoolean(key, get(key), defaultValue);
}
public Integer getAsInt(final String key)
{
return QueryContexts.getAsInt(key, get(key));
}
public int getAsInt(
final String parameter,
final String key,
final int defaultValue
)
{
return QueryContexts.getAsInt(parameter, get(parameter), defaultValue);
return QueryContexts.getAsInt(key, get(key), defaultValue);
}
public long getAsLong(final String parameter, final long defaultValue)
public Long getAsLong(final String key)
{
return QueryContexts.getAsLong(parameter, get(parameter), defaultValue);
return QueryContexts.getAsLong(key, get(key));
}
public HumanReadableBytes getAsHumanReadableBytes(final String parameter, final HumanReadableBytes defaultValue)
public long getAsLong(final String key, final long defaultValue)
{
return QueryContexts.getAsHumanReadableBytes(parameter, get(parameter), defaultValue);
return QueryContexts.getAsLong(key, get(key), defaultValue);
}
public HumanReadableBytes getAsHumanReadableBytes(final String key, final HumanReadableBytes defaultValue)
{
return QueryContexts.getAsHumanReadableBytes(key, get(key), defaultValue);
}
public float getAsFloat(final String key, final float defaultValue)
{
return QueryContexts.getAsFloat(key, get(key), defaultValue);
}
public <E extends Enum<E>> E getAsEnum(String key, Class<E> clazz, E defaultValue)
{
return QueryContexts.getAsEnum(key, get(key), clazz, defaultValue);
}
public Map<String, Object> getMergedParams()

View File

@ -31,6 +31,7 @@ import org.apache.druid.java.util.common.Numbers;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.QueryableIndexStorageAdapter;
import javax.annotation.Nullable;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.TimeUnit;
@ -156,7 +157,7 @@ public class QueryContexts
public static <T> boolean isBySegment(Query<T> query, boolean defaultValue)
{
return parseBoolean(query, BY_SEGMENT_KEY, defaultValue);
return query.getContextBoolean(BY_SEGMENT_KEY, defaultValue);
}
public static <T> boolean isPopulateCache(Query<T> query)
@ -166,7 +167,7 @@ public class QueryContexts
public static <T> boolean isPopulateCache(Query<T> query, boolean defaultValue)
{
return parseBoolean(query, POPULATE_CACHE_KEY, defaultValue);
return query.getContextBoolean(POPULATE_CACHE_KEY, defaultValue);
}
public static <T> boolean isUseCache(Query<T> query)
@ -176,7 +177,7 @@ public class QueryContexts
public static <T> boolean isUseCache(Query<T> query, boolean defaultValue)
{
return parseBoolean(query, USE_CACHE_KEY, defaultValue);
return query.getContextBoolean(USE_CACHE_KEY, defaultValue);
}
public static <T> boolean isPopulateResultLevelCache(Query<T> query)
@ -186,7 +187,7 @@ public class QueryContexts
public static <T> boolean isPopulateResultLevelCache(Query<T> query, boolean defaultValue)
{
return parseBoolean(query, POPULATE_RESULT_LEVEL_CACHE_KEY, defaultValue);
return query.getContextBoolean(POPULATE_RESULT_LEVEL_CACHE_KEY, defaultValue);
}
public static <T> boolean isUseResultLevelCache(Query<T> query)
@ -196,22 +197,23 @@ public class QueryContexts
public static <T> boolean isUseResultLevelCache(Query<T> query, boolean defaultValue)
{
return parseBoolean(query, USE_RESULT_LEVEL_CACHE_KEY, defaultValue);
return query.getContextBoolean(USE_RESULT_LEVEL_CACHE_KEY, defaultValue);
}
public static <T> boolean isFinalize(Query<T> query, boolean defaultValue)
{
return parseBoolean(query, FINALIZE_KEY, defaultValue);
return query.getContextBoolean(FINALIZE_KEY, defaultValue);
}
public static <T> boolean isSerializeDateTimeAsLong(Query<T> query, boolean defaultValue)
{
return parseBoolean(query, SERIALIZE_DATE_TIME_AS_LONG_KEY, defaultValue);
return query.getContextBoolean(SERIALIZE_DATE_TIME_AS_LONG_KEY, defaultValue);
}
public static <T> boolean isSerializeDateTimeAsLongInner(Query<T> query, boolean defaultValue)
{
return parseBoolean(query, SERIALIZE_DATE_TIME_AS_LONG_INNER_KEY, defaultValue);
return query.getContextBoolean(SERIALIZE_DATE_TIME_AS_LONG_INNER_KEY, defaultValue);
}
public static <T> Vectorize getVectorize(Query<T> query)
@ -221,7 +223,7 @@ public class QueryContexts
public static <T> Vectorize getVectorize(Query<T> query, Vectorize defaultValue)
{
return parseEnum(query, VECTORIZE_KEY, Vectorize.class, defaultValue);
return query.getQueryContext().getAsEnum(VECTORIZE_KEY, Vectorize.class, defaultValue);
}
public static <T> Vectorize getVectorizeVirtualColumns(Query<T> query)
@ -231,7 +233,7 @@ public class QueryContexts
public static <T> Vectorize getVectorizeVirtualColumns(Query<T> query, Vectorize defaultValue)
{
return parseEnum(query, VECTORIZE_VIRTUAL_COLUMNS_KEY, Vectorize.class, defaultValue);
return query.getQueryContext().getAsEnum(VECTORIZE_VIRTUAL_COLUMNS_KEY, Vectorize.class, defaultValue);
}
public static <T> int getVectorSize(Query<T> query)
@ -241,12 +243,12 @@ public class QueryContexts
public static <T> int getVectorSize(Query<T> query, int defaultSize)
{
return parseInt(query, VECTOR_SIZE_KEY, defaultSize);
return query.getQueryContext().getAsInt(VECTOR_SIZE_KEY, defaultSize);
}
public static <T> int getMaxSubqueryRows(Query<T> query, int defaultSize)
{
return parseInt(query, MAX_SUBQUERY_ROWS_KEY, defaultSize);
return query.getQueryContext().getAsInt(MAX_SUBQUERY_ROWS_KEY, defaultSize);
}
public static <T> int getUncoveredIntervalsLimit(Query<T> query)
@ -256,7 +258,7 @@ public class QueryContexts
public static <T> int getUncoveredIntervalsLimit(Query<T> query, int defaultValue)
{
return parseInt(query, UNCOVERED_INTERVALS_LIMIT_KEY, defaultValue);
return query.getQueryContext().getAsInt(UNCOVERED_INTERVALS_LIMIT_KEY, defaultValue);
}
public static <T> int getPriority(Query<T> query)
@ -266,38 +268,37 @@ public class QueryContexts
public static <T> int getPriority(Query<T> query, int defaultValue)
{
return parseInt(query, PRIORITY_KEY, defaultValue);
return query.getQueryContext().getAsInt(PRIORITY_KEY, defaultValue);
}
public static <T> String getLane(Query<T> query)
{
return (String) query.getContextValue(LANE_KEY);
return query.getQueryContext().getAsString(LANE_KEY);
}
public static <T> boolean getEnableParallelMerges(Query<T> query)
{
return parseBoolean(query, BROKER_PARALLEL_MERGE_KEY, DEFAULT_ENABLE_PARALLEL_MERGE);
return query.getContextBoolean(BROKER_PARALLEL_MERGE_KEY, DEFAULT_ENABLE_PARALLEL_MERGE);
}
public static <T> int getParallelMergeInitialYieldRows(Query<T> query, int defaultValue)
{
return parseInt(query, BROKER_PARALLEL_MERGE_INITIAL_YIELD_ROWS_KEY, defaultValue);
return query.getQueryContext().getAsInt(BROKER_PARALLEL_MERGE_INITIAL_YIELD_ROWS_KEY, defaultValue);
}
public static <T> int getParallelMergeSmallBatchRows(Query<T> query, int defaultValue)
{
return parseInt(query, BROKER_PARALLEL_MERGE_SMALL_BATCH_ROWS_KEY, defaultValue);
return query.getQueryContext().getAsInt(BROKER_PARALLEL_MERGE_SMALL_BATCH_ROWS_KEY, defaultValue);
}
public static <T> int getParallelMergeParallelism(Query<T> query, int defaultValue)
{
return parseInt(query, BROKER_PARALLELISM, defaultValue);
return query.getQueryContext().getAsInt(BROKER_PARALLELISM, defaultValue);
}
public static <T> boolean getEnableJoinFilterRewriteValueColumnFilters(Query<T> query)
{
return parseBoolean(
query,
return query.getContextBoolean(
JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY,
DEFAULT_ENABLE_JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS
);
@ -305,8 +306,7 @@ public class QueryContexts
public static <T> boolean getEnableRewriteJoinToFilter(Query<T> query)
{
return parseBoolean(
query,
return query.getContextBoolean(
REWRITE_JOIN_TO_FILTER_ENABLE_KEY,
DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER
);
@ -314,32 +314,32 @@ public class QueryContexts
public static <T> long getJoinFilterRewriteMaxSize(Query<T> query)
{
return parseLong(query, JOIN_FILTER_REWRITE_MAX_SIZE_KEY, DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE);
return query.getQueryContext().getAsLong(JOIN_FILTER_REWRITE_MAX_SIZE_KEY, DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE);
}
public static <T> boolean getEnableJoinFilterPushDown(Query<T> query)
{
return parseBoolean(query, JOIN_FILTER_PUSH_DOWN_KEY, DEFAULT_ENABLE_JOIN_FILTER_PUSH_DOWN);
return query.getContextBoolean(JOIN_FILTER_PUSH_DOWN_KEY, DEFAULT_ENABLE_JOIN_FILTER_PUSH_DOWN);
}
public static <T> boolean getEnableJoinFilterRewrite(Query<T> query)
{
return parseBoolean(query, JOIN_FILTER_REWRITE_ENABLE_KEY, DEFAULT_ENABLE_JOIN_FILTER_REWRITE);
return query.getContextBoolean(JOIN_FILTER_REWRITE_ENABLE_KEY, DEFAULT_ENABLE_JOIN_FILTER_REWRITE);
}
public static <T> boolean getEnableJoinLeftScanDirect(Map<String, Object> context)
public static boolean getEnableJoinLeftScanDirect(Map<String, Object> context)
{
return parseBoolean(context, SQL_JOIN_LEFT_SCAN_DIRECT, DEFAULT_ENABLE_SQL_JOIN_LEFT_SCAN_DIRECT);
}
public static <T> boolean isSecondaryPartitionPruningEnabled(Query<T> query)
{
return parseBoolean(query, SECONDARY_PARTITION_PRUNING_KEY, DEFAULT_SECONDARY_PARTITION_PRUNING);
return query.getContextBoolean(SECONDARY_PARTITION_PRUNING_KEY, DEFAULT_SECONDARY_PARTITION_PRUNING);
}
public static <T> boolean isDebug(Query<T> query)
{
return parseBoolean(query, ENABLE_DEBUG, DEFAULT_ENABLE_DEBUG);
return query.getContextBoolean(ENABLE_DEBUG, DEFAULT_ENABLE_DEBUG);
}
public static boolean isDebug(Map<String, Object> queryContext)
@ -364,11 +364,10 @@ public class QueryContexts
public static <T> Query<T> withMaxScatterGatherBytes(Query<T> query, long maxScatterGatherBytesLimit)
{
Object obj = query.getContextValue(MAX_SCATTER_GATHER_BYTES_KEY);
if (obj == null) {
Long curr = query.getQueryContext().getAsLong(MAX_SCATTER_GATHER_BYTES_KEY);
if (curr == null) {
return query.withOverriddenContext(ImmutableMap.of(MAX_SCATTER_GATHER_BYTES_KEY, maxScatterGatherBytesLimit));
} else {
long curr = ((Number) obj).longValue();
if (curr > maxScatterGatherBytesLimit) {
throw new IAE(
"configured [%s = %s] is more than enforced limit of [%s].",
@ -399,12 +398,12 @@ public class QueryContexts
public static <T> long getMaxQueuedBytes(Query<T> query, long defaultValue)
{
return parseLong(query, MAX_QUEUED_BYTES_KEY, defaultValue);
return query.getQueryContext().getAsLong(MAX_QUEUED_BYTES_KEY, defaultValue);
}
public static <T> long getMaxScatterGatherBytes(Query<T> query)
{
return parseLong(query, MAX_SCATTER_GATHER_BYTES_KEY, Long.MAX_VALUE);
return query.getQueryContext().getAsLong(MAX_SCATTER_GATHER_BYTES_KEY, Long.MAX_VALUE);
}
public static <T> boolean hasTimeout(Query<T> query)
@ -420,11 +419,11 @@ public class QueryContexts
public static <T> long getTimeout(Query<T> query, long defaultTimeout)
{
try {
final long timeout = parseLong(query, TIMEOUT_KEY, defaultTimeout);
final long timeout = query.getQueryContext().getAsLong(TIMEOUT_KEY, defaultTimeout);
Preconditions.checkState(timeout >= 0, "Timeout must be a non negative value, but was [%s]", timeout);
return timeout;
}
catch (NumberFormatException e) {
catch (IAE e) {
throw new BadQueryContextException(e);
}
}
@ -441,14 +440,14 @@ public class QueryContexts
static <T> long getDefaultTimeout(Query<T> query)
{
final long defaultTimeout = parseLong(query, DEFAULT_TIMEOUT_KEY, DEFAULT_TIMEOUT_MILLIS);
final long defaultTimeout = query.getQueryContext().getAsLong(DEFAULT_TIMEOUT_KEY, DEFAULT_TIMEOUT_MILLIS);
Preconditions.checkState(defaultTimeout >= 0, "Timeout must be a non negative value, but was [%s]", defaultTimeout);
return defaultTimeout;
}
public static <T> int getNumRetriesOnMissingSegments(Query<T> query, int defaultValue)
{
return query.getContextValue(NUM_RETRIES_ON_MISSING_SEGMENTS_KEY, defaultValue);
return query.getQueryContext().getAsInt(NUM_RETRIES_ON_MISSING_SEGMENTS_KEY, defaultValue);
}
public static <T> boolean allowReturnPartialResults(Query<T> query, boolean defaultValue)
@ -461,39 +460,24 @@ public class QueryContexts
return queryContext == null ? null : (String) queryContext.get(BROKER_SERVICE_NAME);
}
static <T> long parseLong(Query<T> query, String key, long defaultValue)
{
return getAsLong(key, query.getContextValue(key), defaultValue);
}
@SuppressWarnings("unused")
static <T> long parseLong(Map<String, Object> context, String key, long defaultValue)
{
return getAsLong(key, context.get(key), defaultValue);
}
static <T> int parseInt(Query<T> query, String key, int defaultValue)
{
return getAsInt(key, query.getContextValue(key), defaultValue);
}
static int parseInt(Map<String, Object> context, String key, int defaultValue)
{
return getAsInt(key, context.get(key), defaultValue);
}
static <T> boolean parseBoolean(Query<T> query, String key, boolean defaultValue)
{
return getAsBoolean(key, query.getContextValue(key), defaultValue);
}
static boolean parseBoolean(Map<String, Object> context, String key, boolean defaultValue)
{
return getAsBoolean(key, context.get(key), defaultValue);
}
public static String getAsString(
final String parameter,
final String key,
final Object value,
final String defaultValue
)
@ -503,7 +487,24 @@ public class QueryContexts
} else if (value instanceof String) {
return (String) value;
} else {
throw new IAE("Expected parameter [%s] to be String", parameter);
throw new IAE("Expected key [%s] to be a String, but got [%s]", key, value.getClass().getName());
}
}
@Nullable
public static Boolean getAsBoolean(
final String parameter,
final Object value
)
{
if (value == null) {
return null;
} else if (value instanceof String) {
return Boolean.parseBoolean((String) value);
} else if (value instanceof Boolean) {
return (Boolean) value;
} else {
throw new IAE("Expected parameter [%s] to be a Boolean, but got [%s]", parameter, value.getClass().getName());
}
}
@ -512,20 +513,32 @@ public class QueryContexts
* to be {@code null}, a string or a {@code Boolean} object.
*/
public static boolean getAsBoolean(
final String parameter,
final String key,
final Object value,
final boolean defaultValue
)
{
if (value == null) {
return defaultValue;
} else if (value instanceof String) {
return Boolean.parseBoolean((String) value);
} else if (value instanceof Boolean) {
return (Boolean) value;
} else {
throw new IAE("Expected parameter [%s] to be a boolean", parameter);
Boolean val = getAsBoolean(key, value);
return val == null ? defaultValue : val;
}
@Nullable
public static Integer getAsInt(String key, Object value)
{
if (value == null) {
return null;
} else if (value instanceof Number) {
return ((Number) value).intValue();
} else if (value instanceof String) {
try {
return Numbers.parseInt(value);
}
catch (NumberFormatException ignored) {
throw new IAE("Expected key [%s] in integer format, but got [%s]", key, value);
}
}
throw new IAE("Expected key [%s] to be an Integer, but got [%s]", key, value.getClass().getName());
}
/**
@ -533,20 +546,31 @@ public class QueryContexts
* to be {@code null}, a string or a {@code Number} object.
*/
public static int getAsInt(
final String parameter,
final String ke,
final Object value,
final int defaultValue
)
{
if (value == null) {
return defaultValue;
} else if (value instanceof String) {
return Numbers.parseInt(value);
} else if (value instanceof Number) {
return ((Number) value).intValue();
} else {
throw new IAE("Expected parameter [%s] to be an integer", parameter);
Integer val = getAsInt(ke, value);
return val == null ? defaultValue : val;
}
@Nullable
public static Long getAsLong(String key, Object value)
{
if (value == null) {
return null;
} else if (value instanceof Number) {
return ((Number) value).longValue();
} else if (value instanceof String) {
try {
return Numbers.parseLong(value);
}
catch (NumberFormatException ignored) {
throw new IAE("Expected key [%s] in long format, but got [%s]", key, value);
}
}
throw new IAE("Expected key [%s] to be a Long, but got [%s]", key, value.getClass().getName());
}
/**
@ -554,19 +578,13 @@ public class QueryContexts
* to be {@code null}, a string or a {@code Number} object.
*/
public static long getAsLong(
final String parameter,
final String key,
final Object value,
final long defaultValue)
final long defaultValue
)
{
if (value == null) {
return defaultValue;
} else if (value instanceof String) {
return Numbers.parseLong(value);
} else if (value instanceof Number) {
return ((Number) value).longValue();
} else {
throw new IAE("Expected parameter [%s] to be a long", parameter);
}
Long val = getAsLong(key, value);
return val == null ? defaultValue : val;
}
public static HumanReadableBytes getAsHumanReadableBytes(
@ -580,10 +598,32 @@ public class QueryContexts
} else if (value instanceof Number) {
return HumanReadableBytes.valueOf(Numbers.parseLong(value));
} else if (value instanceof String) {
return new HumanReadableBytes((String) value);
} else {
throw new IAE("Expected parameter [%s] to be in human readable format", parameter);
try {
return HumanReadableBytes.valueOf(HumanReadableBytes.parse((String) value));
}
catch (IAE e) {
throw new IAE("Expected key [%s] in human readable format, but got [%s]", parameter, value);
}
}
throw new IAE("Expected key [%s] to be a human readable number, but got [%s]", parameter, value.getClass().getName());
}
public static float getAsFloat(String key, Object value, float defaultValue)
{
if (null == value) {
return defaultValue;
} else if (value instanceof Number) {
return ((Number) value).floatValue();
} else if (value instanceof String) {
try {
return Float.parseFloat((String) value);
}
catch (NumberFormatException ignored) {
throw new IAE("Expected key [%s] in float format, but got [%s]", key, value);
}
}
throw new IAE("Expected key [%s] to be a Float, but got [%s]", key, value.getClass().getName());
}
public static Map<String, Object> override(
@ -604,18 +644,31 @@ public class QueryContexts
{
}
static <T, E extends Enum<E>> E parseEnum(Query<T> query, String key, Class<E> clazz, E defaultValue)
public static <E extends Enum<E>> E getAsEnum(String key, Object val, Class<E> clazz, E defaultValue)
{
Object val = query.getContextValue(key);
if (val == null) {
return defaultValue;
}
try {
if (val instanceof String) {
return Enum.valueOf(clazz, StringUtils.toUpperCase((String) val));
} else if (val instanceof Boolean) {
return Enum.valueOf(clazz, StringUtils.toUpperCase(String.valueOf(val)));
} else {
throw new ISE("Unknown type [%s]. Cannot parse!", val.getClass());
}
}
catch (IllegalArgumentException e) {
throw new IAE("Expected key [%s] must be value of enum [%s], but got [%s].",
key,
clazz.getName(),
val.toString());
}
throw new ISE(
"Expected key [%s] must be type of [%s], actual type is [%s].",
key,
clazz.getName(),
val.getClass()
);
}
}

View File

@ -748,7 +748,7 @@ public class GroupByQuery extends BaseQuery<ResultRow>
@Nullable
private DateTime computeUniversalTimestamp()
{
final String timestampStringFromContext = getContextValue(CTX_KEY_FUDGE_TIMESTAMP, "");
final String timestampStringFromContext = getQueryContext().getAsString(CTX_KEY_FUDGE_TIMESTAMP, "");
final Granularity granularity = getGranularity();
if (!timestampStringFromContext.isEmpty()) {

View File

@ -335,25 +335,25 @@ public class GroupByQueryConfig
public GroupByQueryConfig withOverrides(final GroupByQuery query)
{
final GroupByQueryConfig newConfig = new GroupByQueryConfig();
newConfig.defaultStrategy = query.getContextValue(CTX_KEY_STRATEGY, getDefaultStrategy());
newConfig.singleThreaded = query.getContextBoolean(CTX_KEY_IS_SINGLE_THREADED, isSingleThreaded());
newConfig.defaultStrategy = query.getQueryContext().getAsString(CTX_KEY_STRATEGY, getDefaultStrategy());
newConfig.singleThreaded = query.getQueryContext().getAsBoolean(CTX_KEY_IS_SINGLE_THREADED, isSingleThreaded());
newConfig.maxIntermediateRows = Math.min(
query.getContextValue(CTX_KEY_MAX_INTERMEDIATE_ROWS, getMaxIntermediateRows()),
query.getQueryContext().getAsInt(CTX_KEY_MAX_INTERMEDIATE_ROWS, getMaxIntermediateRows()),
getMaxIntermediateRows()
);
newConfig.maxResults = Math.min(
query.getContextValue(CTX_KEY_MAX_RESULTS, getMaxResults()),
query.getQueryContext().getAsInt(CTX_KEY_MAX_RESULTS, getMaxResults()),
getMaxResults()
);
newConfig.bufferGrouperMaxSize = Math.min(
query.getContextValue(CTX_KEY_BUFFER_GROUPER_MAX_SIZE, getBufferGrouperMaxSize()),
query.getQueryContext().getAsInt(CTX_KEY_BUFFER_GROUPER_MAX_SIZE, getBufferGrouperMaxSize()),
getBufferGrouperMaxSize()
);
newConfig.bufferGrouperMaxLoadFactor = query.getContextValue(
newConfig.bufferGrouperMaxLoadFactor = query.getQueryContext().getAsFloat(
CTX_KEY_BUFFER_GROUPER_MAX_LOAD_FACTOR,
getBufferGrouperMaxLoadFactor()
);
newConfig.bufferGrouperInitialBuckets = query.getContextValue(
newConfig.bufferGrouperInitialBuckets = query.getQueryContext().getAsInt(
CTX_KEY_BUFFER_GROUPER_INITIAL_BUCKETS,
getBufferGrouperInitialBuckets()
);
@ -362,7 +362,7 @@ public class GroupByQueryConfig
// choose a default value lower than the max allowed when the context key is missing in the client query.
newConfig.maxOnDiskStorage = HumanReadableBytes.valueOf(
Math.min(
query.getContextHumanReadableBytes(CTX_KEY_MAX_ON_DISK_STORAGE, getDefaultOnDiskStorage()).getBytes(),
query.getContextAsHumanReadableBytes(CTX_KEY_MAX_ON_DISK_STORAGE, getDefaultOnDiskStorage()).getBytes(),
getMaxOnDiskStorage().getBytes()
)
);
@ -378,11 +378,11 @@ public class GroupByQueryConfig
CTX_KEY_FORCE_PUSH_DOWN_NESTED_QUERY,
isForcePushDownNestedQuery()
);
newConfig.intermediateCombineDegree = query.getContextValue(
newConfig.intermediateCombineDegree = query.getQueryContext().getAsInt(
CTX_KEY_INTERMEDIATE_COMBINE_DEGREE,
getIntermediateCombineDegree()
);
newConfig.numParallelCombineThreads = query.getContextValue(
newConfig.numParallelCombineThreads = query.getQueryContext().getAsInt(
CTX_KEY_NUM_PARALLEL_COMBINE_THREADS,
getNumParallelCombineThreads()
);

View File

@ -96,7 +96,7 @@ public class GroupByQueryEngine
"Null storage adapter found. Probably trying to issue a query against a segment being memory unmapped."
);
}
if (!query.getContextValue(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true)) {
if (!query.getContextBoolean(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true)) {
throw new UOE(
"GroupBy v1 does not support %s as false. Set %s to true or use groupBy v2",
GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING,

View File

@ -100,7 +100,7 @@ public class GroupByQueryHelper
);
final IncrementalIndex index;
final boolean sortResults = query.getContextValue(CTX_KEY_SORT_RESULTS, true);
final boolean sortResults = query.getContextBoolean(CTX_KEY_SORT_RESULTS, true);
// All groupBy dimensions are strings, for now.
final List<DimensionSchema> dimensionSchemas = new ArrayList<>();
@ -118,7 +118,7 @@ public class GroupByQueryHelper
final AppendableIndexBuilder indexBuilder;
if (query.getContextValue("useOffheap", false)) {
if (query.getContextBoolean("useOffheap", false)) {
throw new UnsupportedOperationException(
"The 'useOffheap' option is no longer available for groupBy v1. Please move to the newer groupBy engine, "
+ "which always operates off-heap, by removing any custom 'druid.query.groupBy.defaultStrategy' runtime "

View File

@ -141,7 +141,7 @@ public class GroupByQueryEngineV2
try {
final String fudgeTimestampString = NullHandling.emptyToNullIfNeeded(
query.getContextValue(GroupByStrategyV2.CTX_KEY_FUDGE_TIMESTAMP, null)
query.getQueryContext().getAsString(GroupByStrategyV2.CTX_KEY_FUDGE_TIMESTAMP)
);
final DateTime fudgeTimestamp = fudgeTimestampString == null

View File

@ -232,9 +232,9 @@ public class DefaultLimitSpec implements LimitSpec
}
if (!sortingNeeded) {
String timestampField = query.getContextValue(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD);
String timestampField = query.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD);
if (timestampField != null && !timestampField.isEmpty()) {
int timestampResultFieldIndex = query.getContextValue(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX);
int timestampResultFieldIndex = query.getQueryContext().getAsInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX);
sortingNeeded = query.getContextSortByDimsFirst()
? timestampResultFieldIndex != query.getDimensions().size() - 1
: timestampResultFieldIndex != 0;

View File

@ -221,7 +221,7 @@ public class GroupByStrategyV2 implements GroupByStrategy
Granularity granularity = query.getGranularity();
List<DimensionSpec> dimensionSpecs = query.getDimensions();
// the CTX_TIMESTAMP_RESULT_FIELD is set in DruidQuery.java
final String timestampResultField = query.getContextValue(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD);
final String timestampResultField = query.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD);
final boolean hasTimestampResultField = (timestampResultField != null && !timestampResultField.isEmpty())
&& query.getContextBoolean(CTX_KEY_OUTERMOST, true)
&& !query.isApplyLimitPushDown();
@ -258,7 +258,7 @@ public class GroupByStrategyV2 implements GroupByStrategy
granularity = timestampResultFieldGranularity;
// when timestampResultField is the last dimension, should set sortByDimsFirst=true,
// otherwise the downstream is sorted by row's timestamp first which makes the final ordering not as expected
timestampResultFieldIndex = query.getContextValue(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX);
timestampResultFieldIndex = query.getQueryContext().getAsInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX);
if (!query.getContextSortByDimsFirst() && timestampResultFieldIndex == query.getDimensions().size() - 1) {
context.put(GroupByQuery.CTX_KEY_SORT_BY_DIMS_FIRST, true);
}

View File

@ -264,7 +264,7 @@ public class ScanQuery extends BaseQuery<ScanResultValue>
private Integer validateAndGetMaxRowsQueuedForOrdering()
{
final Integer maxRowsQueuedForOrdering =
getContextValue(ScanQueryConfig.CTX_KEY_MAX_ROWS_QUEUED_FOR_ORDERING, null);
getQueryContext().getAsInt(ScanQueryConfig.CTX_KEY_MAX_ROWS_QUEUED_FOR_ORDERING);
Preconditions.checkArgument(
maxRowsQueuedForOrdering == null || maxRowsQueuedForOrdering > 0,
"maxRowsQueuedForOrdering must be greater than 0"
@ -275,7 +275,7 @@ public class ScanQuery extends BaseQuery<ScanResultValue>
private Integer validateAndGetMaxSegmentPartitionsOrderedInMemory()
{
final Integer maxSegmentPartitionsOrderedInMemory =
getContextValue(ScanQueryConfig.CTX_KEY_MAX_SEGMENT_PARTITIONS_FOR_ORDERING, null);
getQueryContext().getAsInt(ScanQueryConfig.CTX_KEY_MAX_SEGMENT_PARTITIONS_FOR_ORDERING);
Preconditions.checkArgument(
maxSegmentPartitionsOrderedInMemory == null || maxSegmentPartitionsOrderedInMemory > 0,
"maxRowsQueuedForOrdering must be greater than 0"

View File

@ -55,7 +55,7 @@ public class SearchQueryConfig
{
final SearchQueryConfig newConfig = new SearchQueryConfig();
newConfig.maxSearchLimit = query.getLimit();
newConfig.searchStrategy = query.getContextValue(CTX_KEY_STRATEGY, searchStrategy);
newConfig.searchStrategy = query.getQueryContext().getAsString(CTX_KEY_STRATEGY, searchStrategy);
return newConfig;
}
}

View File

@ -21,7 +21,6 @@ package org.apache.druid.query.select;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.google.common.collect.Ordering;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.Query;
@ -117,30 +116,6 @@ public class SelectQuery implements Query<Object>
throw new RuntimeException(REMOVED_ERROR_MESSAGE);
}
@Override
public <ContextType> ContextType getContextValue(String key)
{
throw new RuntimeException(REMOVED_ERROR_MESSAGE);
}
@Override
public <ContextType> ContextType getContextValue(String key, ContextType defaultValue)
{
throw new RuntimeException(REMOVED_ERROR_MESSAGE);
}
@Override
public boolean getContextBoolean(String key, boolean defaultValue)
{
throw new RuntimeException(REMOVED_ERROR_MESSAGE);
}
@Override
public HumanReadableBytes getContextHumanReadableBytes(String key, HumanReadableBytes defaultValue)
{
throw new RuntimeException(REMOVED_ERROR_MESSAGE);
}
@Override
public boolean isDescending()
{

View File

@ -233,8 +233,8 @@ public class TimeBoundaryQueryQueryToolChest
if (query.isMinTime() || query.isMaxTime()) {
RowSignature.Builder builder = RowSignature.builder();
String outputName = query.isMinTime() ?
query.getContextValue(TimeBoundaryQuery.MIN_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MIN_TIME) :
query.getContextValue(TimeBoundaryQuery.MAX_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MAX_TIME);
query.getQueryContext().getAsString(TimeBoundaryQuery.MIN_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MIN_TIME) :
query.getQueryContext().getAsString(TimeBoundaryQuery.MAX_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MAX_TIME);
return builder.add(outputName, ColumnType.LONG).build();
}
return super.resultArraySignature(query);

View File

@ -159,7 +159,7 @@ public class TimeseriesQuery extends BaseQuery<Result<TimeseriesResultValue>>
public String getTimestampResultField()
{
return getContextValue(CTX_TIMESTAMP_RESULT_FIELD);
return getQueryContext().getAsString(CTX_TIMESTAMP_RESULT_FIELD);
}
public boolean isSkipEmptyBuckets()

View File

@ -574,7 +574,7 @@ public class TopNQueryQueryToolChest extends QueryToolChest<Result<TopNResultVal
}
final TopNQuery query = (TopNQuery) input;
final int minTopNThreshold = query.getContextValue("minTopNThreshold", config.getMinTopNThreshold());
final int minTopNThreshold = query.getQueryContext().getAsInt("minTopNThreshold", config.getMinTopNThreshold());
if (query.getThreshold() > minTopNThreshold) {
return runner.run(queryPlus, responseContext);
}

View File

@ -26,7 +26,6 @@ import nl.jqno.equalsverifier.Warning;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.Numbers;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
@ -80,11 +79,13 @@ public class QueryContextTest
public void testGetString()
{
final QueryContext context = new QueryContext(
ImmutableMap.of("key", "val")
ImmutableMap.of("key", "val",
"key2", 2)
);
Assert.assertEquals("val", context.get("key"));
Assert.assertEquals("val", context.getAsString("key"));
Assert.assertEquals("2", context.getAsString("key2"));
Assert.assertNull(context.getAsString("non-exist"));
}
@ -109,13 +110,16 @@ public class QueryContextTest
final QueryContext context = new QueryContext(
ImmutableMap.of(
"key1", "100",
"key2", 100
"key2", 100,
"key3", "abc"
)
);
Assert.assertEquals(100, context.getAsInt("key1", 0));
Assert.assertEquals(100, context.getAsInt("key2", 0));
Assert.assertEquals(0, context.getAsInt("non-exist", 0));
Assert.assertThrows(IAE.class, () -> context.getAsInt("key3", 5));
}
@Test
@ -124,24 +128,57 @@ public class QueryContextTest
final QueryContext context = new QueryContext(
ImmutableMap.of(
"key1", "100",
"key2", 100
"key2", 100,
"key3", "abc"
)
);
Assert.assertEquals(100L, context.getAsLong("key1", 0));
Assert.assertEquals(100L, context.getAsLong("key2", 0));
Assert.assertEquals(0L, context.getAsLong("non-exist", 0));
Assert.assertThrows(IAE.class, () -> context.getAsLong("key3", 5));
}
@Test
public void testGetFloat()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"f1", "500",
"f2", 500,
"f3", 500.1,
"f4", "ab"
)
);
Assert.assertEquals(0, Float.compare(500, context.getAsFloat("f1", 100)));
Assert.assertEquals(0, Float.compare(500, context.getAsFloat("f2", 100)));
Assert.assertEquals(0, Float.compare(500.1f, context.getAsFloat("f3", 100)));
Assert.assertThrows(IAE.class, () -> context.getAsLong("f4", 5));
}
@Test
public void testGetHumanReadableBytes()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"maxOnDiskStorage", "500M"
)
ImmutableMap.<String, Object>builder()
.put("m1", 500_000_000)
.put("m2", "500M")
.put("m3", "500Mi")
.put("m4", "500MiB")
.put("m5", "500000000")
.put("m6", "abc")
.build()
);
Assert.assertEquals(500_000_000, context.getAsHumanReadableBytes("maxOnDiskStorage", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(500_000_000, context.getAsHumanReadableBytes("m1", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(500_000_000, context.getAsHumanReadableBytes("m2", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(500 * 1024 * 1024L, context.getAsHumanReadableBytes("m3", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(500 * 1024 * 1024L, context.getAsHumanReadableBytes("m4", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(500_000_000, context.getAsHumanReadableBytes("m5", HumanReadableBytes.ZERO).getBytes());
Assert.assertThrows(IAE.class, () -> context.getAsHumanReadableBytes("m6", HumanReadableBytes.ZERO));
}
@Test
@ -390,14 +427,14 @@ public class QueryContextTest
}
@Override
public HumanReadableBytes getContextHumanReadableBytes(String key, HumanReadableBytes defaultValue)
public HumanReadableBytes getContextAsHumanReadableBytes(String key, HumanReadableBytes defaultValue)
{
if (null == context || !context.containsKey(key)) {
return defaultValue;
}
Object value = context.get(key);
if (value instanceof Number) {
return HumanReadableBytes.valueOf(Numbers.parseLong(value));
return HumanReadableBytes.valueOf(((Number) value).longValue());
} else if (value instanceof String) {
return new HumanReadableBytes((String) value);
} else {
@ -463,15 +500,6 @@ public class QueryContextTest
return new LegacyContextQuery(contextOverride);
}
@Override
public Object getContextValue(String key, Object defaultValue)
{
if (!context.containsKey(key)) {
return defaultValue;
}
return context.get(key);
}
@Override
public Object getContextValue(String key)
{

View File

@ -137,15 +137,19 @@ public class QueryContextsTest
@Test
public void testDefaultInSubQueryThreshold()
{
Assert.assertEquals(QueryContexts.DEFAULT_IN_SUB_QUERY_THRESHOLD,
QueryContexts.getInSubQueryThreshold(ImmutableMap.of()));
Assert.assertEquals(
QueryContexts.DEFAULT_IN_SUB_QUERY_THRESHOLD,
QueryContexts.getInSubQueryThreshold(ImmutableMap.of())
);
}
@Test
public void testDefaultPlanTimeBoundarySql()
{
Assert.assertEquals(QueryContexts.DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING,
QueryContexts.isTimeBoundaryPlanningEnabled(ImmutableMap.of()));
Assert.assertEquals(
QueryContexts.DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING,
QueryContexts.isTimeBoundaryPlanningEnabled(ImmutableMap.of())
);
}
@Test
@ -279,8 +283,43 @@ public class QueryContextsTest
@Test
public void testGetAsHumanReadableBytes()
{
Assert.assertEquals(new HumanReadableBytes("500M").getBytes(), QueryContexts.getAsHumanReadableBytes("maxOnDiskStorage", 500_000_000, HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(new HumanReadableBytes("500M").getBytes(), QueryContexts.getAsHumanReadableBytes("maxOnDiskStorage", "500000000", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(new HumanReadableBytes("500M").getBytes(), QueryContexts.getAsHumanReadableBytes("maxOnDiskStorage", "500M", HumanReadableBytes.ZERO).getBytes());
Assert.assertEquals(
new HumanReadableBytes("500M").getBytes(),
QueryContexts.getAsHumanReadableBytes("maxOnDiskStorage", 500_000_000, HumanReadableBytes.ZERO)
.getBytes()
);
Assert.assertEquals(
new HumanReadableBytes("500M").getBytes(),
QueryContexts.getAsHumanReadableBytes("maxOnDiskStorage", "500000000", HumanReadableBytes.ZERO)
.getBytes()
);
Assert.assertEquals(
new HumanReadableBytes("500M").getBytes(),
QueryContexts.getAsHumanReadableBytes("maxOnDiskStorage", "500M", HumanReadableBytes.ZERO)
.getBytes()
);
}
@Test
public void testGetEnum()
{
Query<?> query = new TestQuery(
new TableDataSource("test"),
new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))),
false,
ImmutableMap.of("e1", "FORCE",
"e2", "INVALID_ENUM"
)
);
Assert.assertEquals(
QueryContexts.Vectorize.FORCE,
query.getQueryContext().getAsEnum("e1", QueryContexts.Vectorize.class, QueryContexts.Vectorize.FALSE)
);
Assert.assertThrows(
IAE.class,
() -> query.getQueryContext().getAsEnum("e2", QueryContexts.Vectorize.class, QueryContexts.Vectorize.FALSE)
);
}
}

View File

@ -102,10 +102,10 @@ public class DataSourceMetadataQueryTest
), Query.class
);
Assert.assertEquals((Integer) 1, serdeQuery.getContextValue(QueryContexts.PRIORITY_KEY));
Assert.assertEquals(true, serdeQuery.getContextValue(QueryContexts.USE_CACHE_KEY));
Assert.assertEquals("true", serdeQuery.getContextValue(QueryContexts.POPULATE_CACHE_KEY));
Assert.assertEquals(true, serdeQuery.getContextValue(QueryContexts.FINALIZE_KEY));
Assert.assertEquals((Integer) 1, serdeQuery.getQueryContext().getAsInt(QueryContexts.PRIORITY_KEY));
Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.USE_CACHE_KEY));
Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.POPULATE_CACHE_KEY));
Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.FINALIZE_KEY));
Assert.assertEquals(true, serdeQuery.getContextBoolean(QueryContexts.USE_CACHE_KEY, false));
Assert.assertEquals(true, serdeQuery.getContextBoolean(QueryContexts.POPULATE_CACHE_KEY, false));
Assert.assertEquals(true, serdeQuery.getContextBoolean(QueryContexts.FINALIZE_KEY, false));

View File

@ -78,11 +78,10 @@ public class TimeBoundaryQueryTest
), TimeBoundaryQuery.class
);
Assert.assertEquals(new Integer(1), serdeQuery.getContextValue(QueryContexts.PRIORITY_KEY));
Assert.assertEquals(true, serdeQuery.getContextValue(QueryContexts.USE_CACHE_KEY));
Assert.assertEquals(true, serdeQuery.getContextValue(QueryContexts.POPULATE_CACHE_KEY));
Assert.assertEquals(true, serdeQuery.getContextValue(QueryContexts.FINALIZE_KEY));
Assert.assertEquals(new Integer(1), serdeQuery.getQueryContext().getAsInt(QueryContexts.PRIORITY_KEY));
Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.USE_CACHE_KEY));
Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.POPULATE_CACHE_KEY));
Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.FINALIZE_KEY));
}
@Test
@ -117,9 +116,9 @@ public class TimeBoundaryQueryTest
);
Assert.assertEquals("1", serdeQuery.getContextValue(QueryContexts.PRIORITY_KEY));
Assert.assertEquals("true", serdeQuery.getContextValue(QueryContexts.USE_CACHE_KEY));
Assert.assertEquals("true", serdeQuery.getContextValue(QueryContexts.POPULATE_CACHE_KEY));
Assert.assertEquals("true", serdeQuery.getContextValue(QueryContexts.FINALIZE_KEY));
Assert.assertEquals("1", serdeQuery.getQueryContext().getAsString(QueryContexts.PRIORITY_KEY));
Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.USE_CACHE_KEY));
Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.POPULATE_CACHE_KEY));
Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.FINALIZE_KEY));
}
}

View File

@ -163,7 +163,7 @@ public class DirectDruidClient<T> implements QueryRunner<T>
log.debug("Querying queryId[%s] url[%s]", query.getId(), url);
final long requestStartTimeNs = System.nanoTime();
final long timeoutAt = query.getContextValue(QUERY_FAIL_TIME);
final long timeoutAt = query.getQueryContext().getAsLong(QUERY_FAIL_TIME);
final long maxScatterGatherBytes = QueryContexts.getMaxScatterGatherBytes(query);
final AtomicLong totalBytesGathered = context.getTotalBytes();
final long maxQueuedBytes = QueryContexts.getMaxQueuedBytes(query, 0);

View File

@ -75,7 +75,7 @@ public class JsonParserIterator<T> implements Iterator<T>, Closeable
this.future = future;
this.url = url;
if (query != null) {
this.timeoutAt = query.<Long>getContextValue(DirectDruidClient.QUERY_FAIL_TIME, -1L);
this.timeoutAt = query.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME, -1L);
this.queryId = query.getId();
} else {
this.timeoutAt = -1;

View File

@ -161,7 +161,7 @@ public class SinkQuerySegmentWalker implements QuerySegmentWalker
}
final QueryToolChest<T, Query<T>> toolChest = factory.getToolchest();
final boolean skipIncrementalSegment = query.getContextValue(CONTEXT_SKIP_INCREMENTAL_SEGMENT, false);
final boolean skipIncrementalSegment = query.getContextBoolean(CONTEXT_SKIP_INCREMENTAL_SEGMENT, false);
final AtomicLong cpuTimeAccumulator = new AtomicLong(0L);
// Make sure this query type can handle the subquery, if present.

View File

@ -431,7 +431,7 @@ public class ClientQuerySegmentWalker implements QuerySegmentWalker
.emitCPUTimeMetric(emitter)
.postProcess(
objectMapper.convertValue(
query.<String>getContextValue("postProcessing"),
query.getQueryContext().getAsString("postProcessing"),
new TypeReference<PostProcessingOperator<T>>() {}
)
)

View File

@ -70,10 +70,10 @@ public class HiLoQueryLaningStrategy implements QueryLaningStrategy
// QueryContexts.getPriority gives a default, but it can parse the value to integer. Before calling QueryContexts.getPriority
// we make sure that priority has been set.
Integer priority = null;
if (null != theQuery.getContextValue(QueryContexts.PRIORITY_KEY)) {
if (theQuery.getContextValue(QueryContexts.PRIORITY_KEY) != null) {
priority = QueryContexts.getPriority(theQuery);
}
final String lane = theQuery.getContextValue(QueryContexts.LANE_KEY);
final String lane = theQuery.getQueryContext().getAsString(QueryContexts.LANE_KEY);
if (lane == null && priority != null && priority < 0) {
return Optional.of(LOW);
}

View File

@ -43,6 +43,7 @@ import org.junit.runner.RunWith;
import java.util.Optional;
import java.util.Set;
import static org.apache.druid.query.QueryContexts.DEFAULT_BY_SEGMENT;
import static org.easymock.EasyMock.expect;
import static org.easymock.EasyMock.replay;
import static org.easymock.EasyMock.reset;
@ -66,7 +67,7 @@ public class CachingClusteredClientCacheKeyManagerTest extends EasyMockSupport
public void setup()
{
expect(strategy.computeCacheKey(query)).andReturn(QUERY_CACHE_KEY).anyTimes();
expect(query.getContextValue(QueryContexts.BY_SEGMENT_KEY)).andReturn(false).anyTimes();
expect(query.getContextBoolean(QueryContexts.BY_SEGMENT_KEY, DEFAULT_BY_SEGMENT)).andReturn(false).anyTimes();
}
@After
@ -202,7 +203,7 @@ public class CachingClusteredClientCacheKeyManagerTest extends EasyMockSupport
{
expect(dataSourceAnalysis.isJoin()).andReturn(false);
reset(query);
expect(query.getContextValue(QueryContexts.BY_SEGMENT_KEY)).andReturn(true).anyTimes();
expect(query.getContextBoolean(QueryContexts.BY_SEGMENT_KEY, DEFAULT_BY_SEGMENT)).andReturn(true).anyTimes();
replayAll();
CachingClusteredClient.CacheKeyManager<Object> keyManager = makeKeyManager();
Set<SegmentServerSelector> selectors = ImmutableSet.of(
@ -271,7 +272,7 @@ public class CachingClusteredClientCacheKeyManagerTest extends EasyMockSupport
public void testSegmentQueryCacheKey_noCachingIfBySegment()
{
reset(query);
expect(query.getContextValue(QueryContexts.BY_SEGMENT_KEY)).andReturn(true).anyTimes();
expect(query.getContextBoolean(QueryContexts.BY_SEGMENT_KEY, DEFAULT_BY_SEGMENT)).andReturn(true).anyTimes();
replayAll();
byte[] cacheKey = makeKeyManager().computeSegmentLevelQueryCacheKey();
Assert.assertNull(cacheKey);

View File

@ -2298,11 +2298,11 @@ public class CachingClusteredClientTest
QueryPlus capturedQueryPlus = (QueryPlus) queryCapture.getValue();
Query capturedQuery = capturedQueryPlus.getQuery();
if (expectBySegment) {
Assert.assertEquals(true, capturedQuery.getContextValue(QueryContexts.BY_SEGMENT_KEY));
Assert.assertEquals(true, capturedQuery.getQueryContext().getAsBoolean(QueryContexts.BY_SEGMENT_KEY));
} else {
Assert.assertTrue(
capturedQuery.getContextValue(QueryContexts.BY_SEGMENT_KEY) == null ||
capturedQuery.getContextValue(QueryContexts.BY_SEGMENT_KEY).equals(false)
capturedQuery.getQueryContext().getAsBoolean(QueryContexts.BY_SEGMENT_KEY).equals(false)
);
}
}

View File

@ -307,7 +307,8 @@ public class JsonParserIteratorTest
{
Query<?> query = Mockito.mock(Query.class);
Mockito.when(query.getId()).thenReturn(queryId);
Mockito.when(query.getContextValue(ArgumentMatchers.eq(DirectDruidClient.QUERY_FAIL_TIME), ArgumentMatchers.eq(-1L)))
Mockito.when(query.getQueryContext().getAsLong(ArgumentMatchers.eq(DirectDruidClient.QUERY_FAIL_TIME),
ArgumentMatchers.eq(-1L)))
.thenReturn(timeoutAt);
return query;
}

View File

@ -58,7 +58,7 @@ public class SetAndVerifyContextQueryRunnerTest
// time + 1 at the time the method was called
// this means that after sleeping for 1 millis, the fail time should be less than the current time when checking
Assert.assertTrue(
System.currentTimeMillis() > (Long) transformed.getContextValue(DirectDruidClient.QUERY_FAIL_TIME)
System.currentTimeMillis() > transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME)
);
}
@ -85,7 +85,7 @@ public class SetAndVerifyContextQueryRunnerTest
Query<ScanResultValue> transformed = queryRunner.withTimeoutAndMaxScatterGatherBytes(query, defaultConfig);
// timeout is not set, default timeout has been set to long.max, make sure timeout is still in the future
Assert.assertEquals((Long) Long.MAX_VALUE, transformed.getContextValue(DirectDruidClient.QUERY_FAIL_TIME));
Assert.assertEquals((Long) Long.MAX_VALUE, transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME));
}
@Test
@ -107,7 +107,7 @@ public class SetAndVerifyContextQueryRunnerTest
// timeout is set to 0, so withTimeoutAndMaxScatterGatherBytes should set QUERY_FAIL_TIME to be the current
// time + max query timeout at the time the method was called
// since default is long max, expect long max since current time would overflow
Assert.assertEquals((Long) Long.MAX_VALUE, transformed.getContextValue(DirectDruidClient.QUERY_FAIL_TIME));
Assert.assertEquals((Long) Long.MAX_VALUE, transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME));
}
@Test
@ -137,7 +137,7 @@ public class SetAndVerifyContextQueryRunnerTest
// time + max query timeout at the time the method was called
// this means that the fail time should be greater than the current time when checking
Assert.assertTrue(
System.currentTimeMillis() < (Long) transformed.getContextValue(DirectDruidClient.QUERY_FAIL_TIME)
System.currentTimeMillis() < (Long) transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME)
);
}
}

View File

@ -1788,7 +1788,7 @@ public class SqlResourceTest extends CalciteTestBase
Assert.assertNotNull(queryContextException);
Assert.assertEquals(BadQueryContextException.ERROR_CODE, queryContextException.getErrorCode());
Assert.assertEquals(BadQueryContextException.ERROR_CLASS, queryContextException.getErrorClass());
Assert.assertTrue(queryContextException.getMessage().contains("For input string: \"2000'\""));
Assert.assertTrue(queryContextException.getMessage().contains("2000'"));
checkSqlRequestLog(false);
Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty());
}