Add native filter conversion for SCALAR_IN_ARRAY. (#16312)

* Add native filter conversion for SCALAR_IN_ARRAY.

Main changes:

1) Add an implementation of "toDruidFilter" in ScalarInArrayOperatorConversion.

2) Split up Expressions.literalToDruidExpression into two functions, so the first
   half (literalToExprEval) can be used by ScalarInArrayOperatorConversion to more
   efficiently create the list of match values.

* Fix type in time arithmetic conversion.

* Test updates.

* Update test cases to use null instead of '' in default-value mode.

* Switch test from msqIncompatible to compatible with a different result.

* Update one more test.

* Fix test.

* Update tests.

* Use ExprEvalWrapper to differentiate between empty string and null.

* Fix tests some more.

* Fix test.

* Additional comment.

* Style adjustment.

* Fix tests.

* trueValue -> actualValue.

* Use different approach, DruidLiteral instead of ExprEvalWrapper.

* Revert changes in ArrayOfDoublesSketchSqlAggregatorTest.
This commit is contained in:
Gian Merlino 2024-05-03 13:00:33 -07:00 committed by GitHub
parent 1b107ff695
commit 588d442422
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 726 additions and 152 deletions

View File

@ -26,6 +26,9 @@ import com.google.common.io.BaseEncoding;
import com.google.common.primitives.Chars;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprType;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
@ -180,6 +183,19 @@ public class DruidExpression
);
}
/**
 * Create a literal expression from a {@link DruidLiteral}.
 */
public static DruidExpression ofLiteral(final DruidLiteral literal)
{
  if (literal.type() != null && literal.type().is(ExprType.STRING)) {
    // Strings take a dedicated path; per DruidLiteral's contract its string values are not
    // normalized through emptyToNullIfNeeded, so empty string and null stay distinct here.
    return ofStringLiteral((String) literal.value());
  } else {
    // type() may be null (untyped null literal); in that case the Druid type is unknown and
    // ExprEval.ofType infers a best-effort representation before stringifying.
    final ColumnType evalColumnType = literal.type() != null ? ExpressionType.toColumnType(literal.type()) : null;
    return ofLiteral(evalColumnType, ExprEval.ofType(literal.type(), literal.value()).toExpr().stringify());
  }
}
public static DruidExpression ofStringLiteral(final String s)
{
return ofLiteral(ColumnType.STRING, stringLiteral(s));

View File

@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.expression;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExpressionType;
import javax.annotation.Nullable;
/**
 * Literal value, plus an {@link ExpressionType} that represents how to interpret the literal value.
 *
 * These are similar to {@link ExprEval}, but not identical: unlike {@link ExprEval}, string values in this class
 * are not normalized through {@link NullHandling#emptyToNullIfNeeded(String)}. This allows us to differentiate
 * between null and empty-string literals even when {@link NullHandling#replaceWithDefault()}.
 */
public class DruidLiteral
{
  // Type used to interpret "value". Null when the literal's type could not be determined
  // (e.g. a typeless null literal).
  @Nullable
  private final ExpressionType type;

  @Nullable
  private final Object value;

  // "@Nullable" added on "type": callers (e.g. Expressions#calciteLiteralToDruidLiteral) pass a
  // possibly-null expression type for null literals, matching the nullable field.
  DruidLiteral(@Nullable final ExpressionType type, @Nullable final Object value)
  {
    this.type = type;
    this.value = value;
  }

  /**
   * Type of this literal. Null when the type is unknown.
   */
  @Nullable
  public ExpressionType type()
  {
    return type;
  }

  /**
   * Value of this literal. May be null even when {@link #type()} is nonnull.
   */
  @Nullable
  public Object value()
  {
    return value;
  }

  /**
   * Returns a literal equivalent to this one, cast to {@code toType}. Returns {@code this} when the
   * type already matches.
   *
   * NOTE(review): this dereferences {@link #type} without a null check and will throw
   * NullPointerException for an untyped literal — confirm callers never cast untyped literals.
   */
  public DruidLiteral castTo(final ExpressionType toType)
  {
    if (type.equals(toType)) {
      return this;
    }

    final ExprEval<?> castEval = ExprEval.ofType(type, value).castTo(toType);
    return new DruidLiteral(castEval.type(), castEval.value());
  }
}

View File

@ -38,6 +38,7 @@ import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.expression.TimestampFloorExprMacro;
import org.apache.druid.query.extraction.ExtractionFn;
@ -240,7 +241,8 @@ public class Expressions
} else if (rexNode instanceof RexCall) {
return rexCallToDruidExpression(plannerContext, rowSignature, rexNode, postAggregatorVisitor);
} else if (kind == SqlKind.LITERAL) {
return literalToDruidExpression(plannerContext, rexNode);
final DruidLiteral eval = calciteLiteralToDruidLiteral(plannerContext, rexNode);
return eval != null ? DruidExpression.ofLiteral(eval) : null;
} else {
// Can't translate.
return null;
@ -306,61 +308,85 @@ public class Expressions
}
}
/**
* Create a {@link DruidLiteral} from a literal {@link RexNode}. Necessary because Calcite represents literals using
* different Java classes than Druid does.
*
* @param plannerContext planner context
* @param rexNode Calcite literal
*
* @return converted literal, or null if the literal cannot be converted
*/
@Nullable
static DruidExpression literalToDruidExpression(
public static DruidLiteral calciteLiteralToDruidLiteral(
final PlannerContext plannerContext,
final RexNode rexNode
)
{
final SqlTypeName sqlTypeName = rexNode.getType().getSqlTypeName();
if (rexNode.isA(SqlKind.CAST)) {
if (SqlTypeFamily.DATE.contains(rexNode.getType())) {
// Cast to DATE suggests some timestamp flooring. We don't deal with that here, so return null.
return null;
}
final DruidLiteral innerLiteral =
calciteLiteralToDruidLiteral(plannerContext, ((RexCall) rexNode).getOperands().get(0));
if (innerLiteral == null) {
return null;
}
final ColumnType castToColumnType = Calcites.getColumnTypeForRelDataType(rexNode.getType());
if (castToColumnType == null) {
return null;
}
final ExpressionType castToExprType = ExpressionType.fromColumnType(castToColumnType);
if (castToExprType == null) {
return null;
}
return innerLiteral.castTo(castToExprType);
}
// Translate literal.
final ColumnType columnType = Calcites.getColumnTypeForRelDataType(rexNode.getType());
final SqlTypeName sqlTypeName = rexNode.getType().getSqlTypeName();
final DruidLiteral retVal;
if (RexLiteral.isNullLiteral(rexNode)) {
return DruidExpression.ofLiteral(columnType, DruidExpression.nullLiteral());
final ColumnType columnType = Calcites.getColumnTypeForRelDataType(rexNode.getType());
final ExpressionType expressionType = columnType == null ? null : ExpressionType.fromColumnTypeStrict(columnType);
retVal = new DruidLiteral(expressionType, null);
} else if (SqlTypeName.INT_TYPES.contains(sqlTypeName)) {
final Number number = (Number) RexLiteral.value(rexNode);
return DruidExpression.ofLiteral(
columnType,
number == null ? DruidExpression.nullLiteral() : DruidExpression.longLiteral(number.longValue())
);
retVal = new DruidLiteral(ExpressionType.LONG, number == null ? null : number.longValue());
} else if (SqlTypeName.NUMERIC_TYPES.contains(sqlTypeName)) {
// Numeric, non-INT, means we represent it as a double.
final Number number = (Number) RexLiteral.value(rexNode);
return DruidExpression.ofLiteral(
columnType,
number == null ? DruidExpression.nullLiteral() : DruidExpression.doubleLiteral(number.doubleValue())
);
retVal = new DruidLiteral(ExpressionType.DOUBLE, number == null ? null : number.doubleValue());
} else if (SqlTypeFamily.INTERVAL_DAY_TIME == sqlTypeName.getFamily()) {
// Calcite represents DAY-TIME intervals in milliseconds.
final long milliseconds = ((Number) RexLiteral.value(rexNode)).longValue();
return DruidExpression.ofLiteral(columnType, DruidExpression.longLiteral(milliseconds));
retVal = new DruidLiteral(ExpressionType.LONG, milliseconds);
} else if (SqlTypeFamily.INTERVAL_YEAR_MONTH == sqlTypeName.getFamily()) {
// Calcite represents YEAR-MONTH intervals in months.
final long months = ((Number) RexLiteral.value(rexNode)).longValue();
return DruidExpression.ofLiteral(columnType, DruidExpression.longLiteral(months));
retVal = new DruidLiteral(ExpressionType.LONG, months);
} else if (SqlTypeName.STRING_TYPES.contains(sqlTypeName)) {
return DruidExpression.ofStringLiteral(RexLiteral.stringValue(rexNode));
final String s = RexLiteral.stringValue(rexNode);
retVal = new DruidLiteral(ExpressionType.STRING, s);
} else if (SqlTypeName.TIMESTAMP == sqlTypeName || SqlTypeName.DATE == sqlTypeName) {
if (RexLiteral.isNullLiteral(rexNode)) {
return DruidExpression.ofLiteral(columnType, DruidExpression.nullLiteral());
} else {
return DruidExpression.ofLiteral(
columnType,
DruidExpression.longLiteral(
Calcites.calciteDateTimeLiteralToJoda(rexNode, plannerContext.getTimeZone()).getMillis()
)
);
}
} else if (SqlTypeName.BOOLEAN == sqlTypeName) {
return DruidExpression.ofLiteral(
columnType,
DruidExpression.longLiteral(RexLiteral.booleanValue(rexNode) ? 1 : 0)
retVal = new DruidLiteral(
ExpressionType.LONG,
Calcites.calciteDateTimeLiteralToJoda(rexNode, plannerContext.getTimeZone()).getMillis()
);
} else if (SqlTypeName.BOOLEAN == sqlTypeName) {
retVal = new DruidLiteral(ExpressionType.LONG, RexLiteral.booleanValue(rexNode) ? 1L : 0L);
} else {
// Can't translate other literals.
return null;
}
return retVal;
}
/**
@ -647,8 +673,8 @@ public class Expressions
final DruidExpression rhsExpression = toDruidExpression(plannerContext, rowSignature, rhs);
final Expr rhsParsed = rhsExpression != null
? plannerContext.parseExpression(rhsExpression.getExpression())
: null;
? plannerContext.parseExpression(rhsExpression.getExpression())
: null;
// rhs must be a literal
if (rhsParsed == null || !rhsParsed.isLiteral()) {
return null;
@ -815,7 +841,9 @@ public class Expressions
}
} else if (rexNode instanceof RexCall) {
final SqlOperator operator = ((RexCall) rexNode).getOperator();
final SqlOperatorConversion conversion = plannerContext.getPlannerToolbox().operatorTable().lookupOperatorConversion(operator);
final SqlOperatorConversion conversion = plannerContext.getPlannerToolbox()
.operatorTable()
.lookupOperatorConversion(operator);
if (conversion == null) {
return null;

View File

@ -25,7 +25,6 @@ import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.math.expr.Evals;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
@ -34,10 +33,8 @@ import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.query.filter.ArrayContainsElementFilter;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.filter.EqualityFilter;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.filter.NullFilter;
import org.apache.druid.query.filter.OrDimFilter;
import org.apache.druid.query.filter.TypedInFilter;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.DruidExpression;
@ -158,27 +155,13 @@ public class ArrayOverlapOperatorConversion extends BaseExpressionDimFilterOpera
);
}
} else {
if (plannerContext.isUseBoundsAndSelectors() || NullHandling.replaceWithDefault() || !simpleExtractionExpr.isDirectColumnAccess()) {
final InDimFilter.ValuesSet valuesSet = InDimFilter.ValuesSet.create();
for (final Object arrayElement : arrayElements) {
valuesSet.add(Evals.asString(arrayElement));
}
return new InDimFilter(
simpleExtractionExpr.getSimpleExtraction().getColumn(),
valuesSet,
simpleExtractionExpr.getSimpleExtraction().getExtractionFn(),
null
);
} else {
return new TypedInFilter(
simpleExtractionExpr.getSimpleExtraction().getColumn(),
ExpressionType.toColumnType((ExpressionType) exprEval.type().getElementType()),
Arrays.asList(arrayElements),
null,
null
);
}
return ScalarInArrayOperatorConversion.makeInFilter(
plannerContext,
simpleExtractionExpr.getSimpleExtraction().getColumn(),
simpleExtractionExpr.getSimpleExtraction().getExtractionFn(),
Arrays.asList(arrayElements),
ExpressionType.toColumnType((ExpressionType) exprEval.type().getElementType())
);
}
}

View File

@ -19,32 +19,135 @@
package org.apache.druid.sql.calcite.expression.builtin;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.druid.math.expr.Evals;
import org.apache.druid.query.extraction.ExtractionFn;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.filter.TypedInFilter;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.DirectOperatorConversion;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.DruidLiteral;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
public class ScalarInArrayOperatorConversion extends DirectOperatorConversion
{
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("SCALAR_IN_ARRAY")
.operandTypeChecker(
OperandTypes.sequence(
"'SCALAR_IN_ARRAY(expr, array)'",
OperandTypes.or(
OperandTypes.family(SqlTypeFamily.CHARACTER),
OperandTypes.family(SqlTypeFamily.NUMERIC)
),
OperandTypes.family(SqlTypeFamily.ARRAY)
)
public static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("SCALAR_IN_ARRAY")
.operandTypeChecker(
OperandTypes.sequence(
"'SCALAR_IN_ARRAY(expr, array)'",
OperandTypes.or(
OperandTypes.family(SqlTypeFamily.CHARACTER),
OperandTypes.family(SqlTypeFamily.NUMERIC)
),
OperandTypes.family(SqlTypeFamily.ARRAY)
)
.returnTypeInference(ReturnTypes.BOOLEAN_NULLABLE)
.build();
)
.returnTypeInference(ReturnTypes.BOOLEAN_NULLABLE)
.build();
public ScalarInArrayOperatorConversion()
{
  // Binds the SCALAR_IN_ARRAY SqlFunction to the native "scalar_in_array" expression function.
  super(SQL_FUNCTION, "scalar_in_array");
}
/**
 * Converts SCALAR_IN_ARRAY(expr, array) into a native IN-style filter when the array operand is
 * composed entirely of convertible literals. Returns null when conversion is not possible, in which
 * case the planner falls back to other strategies (such as an expression filter).
 */
@Nullable
@Override
public DimFilter toDruidFilter(
    final PlannerContext plannerContext,
    final RowSignature rowSignature,
    @Nullable final VirtualColumnRegistry virtualColumnRegistry,
    final RexNode rexNode
)
{
  final RexCall call = (RexCall) rexNode;
  final RexNode scalarOperand = call.getOperands().get(0);
  final RexNode arrayOperand = call.getOperands().get(1);
  final DruidExpression scalarExpression = Expressions.toDruidExpression(plannerContext, rowSignature, scalarOperand);
  final String scalarColumn;
  final ExtractionFn scalarExtractionFn;

  if (scalarExpression == null) {
    // The scalar operand itself could not be converted; give up on filter conversion.
    return null;
  }

  if (scalarExpression.isDirectColumnAccess()) {
    scalarColumn = scalarExpression.getDirectColumn();
    scalarExtractionFn = null;
  } else if (scalarExpression.isSimpleExtraction() && plannerContext.isUseLegacyInFilter()) {
    // Legacy InDimFilter supports extractionFns natively, so no virtual column is needed.
    scalarColumn = scalarExpression.getSimpleExtraction().getColumn();
    scalarExtractionFn = scalarExpression.getSimpleExtraction().getExtractionFn();
  } else {
    // NOTE(review): virtualColumnRegistry is declared @Nullable but dereferenced here without a
    // null check — confirm callers always supply a registry when filter conversion is attempted.
    scalarColumn = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(
        scalarExpression,
        scalarExpression.getDruidType()
    );
    scalarExtractionFn = null;
  }

  if (Calcites.isLiteral(arrayOperand, true, true)) {
    final RelDataType elementType = arrayOperand.getType().getComponentType();
    final List<RexNode> arrayElements = ((RexCall) arrayOperand).getOperands();
    final List<Object> arrayElementLiteralValues = new ArrayList<>(arrayElements.size());

    for (final RexNode arrayElement : arrayElements) {
      final DruidLiteral arrayElementEval = Expressions.calciteLiteralToDruidLiteral(plannerContext, arrayElement);
      if (arrayElementEval == null) {
        // Any single non-convertible element aborts conversion of the entire filter.
        return null;
      }

      arrayElementLiteralValues.add(arrayElementEval.value());
    }

    return makeInFilter(
        plannerContext,
        scalarColumn,
        scalarExtractionFn,
        arrayElementLiteralValues,
        Calcites.getColumnTypeForRelDataType(elementType)
    );
  }

  return null;
}
/**
 * Create an {@link InDimFilter} or {@link TypedInFilter} based on a list of provided values.
 *
 * {@link InDimFilter} is used when the planner requests legacy filters, or when an extractionFn is
 * present (which {@link TypedInFilter} does not support); otherwise {@link TypedInFilter} is used.
 */
public static DimFilter makeInFilter(
    final PlannerContext plannerContext,
    final String columnName,
    @Nullable final ExtractionFn extractionFn,
    final List<Object> matchValues,
    final ColumnType matchValueType
)
{
  final boolean useLegacyFilter = plannerContext.isUseLegacyInFilter() || extractionFn != null;

  if (!useLegacyFilter) {
    return new TypedInFilter(columnName, matchValueType, matchValues, null, null);
  }

  // Legacy path: coerce every match value to its string form.
  final InDimFilter.ValuesSet stringValues = InDimFilter.ValuesSet.create();
  for (final Object value : matchValues) {
    stringValues.add(Evals.asString(value));
  }

  return new InDimFilter(columnName, stringValues, extractionFn, null);
}
}

View File

@ -100,7 +100,8 @@ public abstract class TimeArithmeticOperatorConversion implements SqlOperatorCon
expression ->
rightRexNode.isA(SqlKind.LITERAL) ?
StringUtils.format("'P%sM'", RexLiteral.value(rightRexNode)) :
StringUtils.format("concat('P', %s, 'M')", expression)
StringUtils.format("concat('P', %s, 'M')", expression),
ColumnType.STRING
),
DruidExpression.ofLiteral(ColumnType.LONG, DruidExpression.longLiteral(direction > 0 ? 1 : -1)),
DruidExpression.ofStringLiteral(plannerContext.getTimeZone().getID())

View File

@ -37,6 +37,8 @@ import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.filter.TypedInFilter;
import org.apache.druid.query.lookup.LookupExtractor;
import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider;
import org.apache.druid.query.lookup.RegisteredLookupExtractionFn;
@ -370,13 +372,21 @@ public class PlannerContext
* {@link org.apache.druid.query.filter.EqualityFilter}, and {@link org.apache.druid.query.filter.NullFilter} (false).
*
* Typically true when {@link NullHandling#replaceWithDefault()} and false when {@link NullHandling#sqlCompatible()}.
* Can be overridden by the undocumented context parameter {@link #CTX_SQL_USE_BOUNDS_AND_SELECTORS}.
* Can be overridden by the context parameter {@link #CTX_SQL_USE_BOUNDS_AND_SELECTORS}.
*/
public boolean isUseBoundsAndSelectors()
{
  // Computed once at context creation; see CTX_SQL_USE_BOUNDS_AND_SELECTORS handling elsewhere.
  return useBoundsAndSelectors;
}
/**
 * Whether we should use {@link InDimFilter} (true) or {@link TypedInFilter} (false).
 *
 * Legacy {@link InDimFilter} is used when the query opts into bounds-and-selectors behavior
 * ({@link #isUseBoundsAndSelectors()}) or when running in default-value mode
 * ({@link NullHandling#replaceWithDefault()}).
 */
public boolean isUseLegacyInFilter()
{
  return useBoundsAndSelectors || NullHandling.replaceWithDefault();
}
/**
* Whether we should use {@link AggregatePullUpLookupRule} to pull LOOKUP functions on injective lookups up above
* a GROUP BY.

View File

@ -65,6 +65,7 @@ import org.apache.druid.query.scan.ScanQuery;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.apache.druid.query.topn.DimensionTopNMetricSpec;
import org.apache.druid.query.topn.TopNQueryBuilder;
import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.join.JoinType;
@ -1353,69 +1354,128 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest
@Test
public void testScalarInArrayFilter()
{
msqIncompatible();
testQuery(
"SELECT dim2 FROM druid.numfoo WHERE SCALAR_IN_ARRAY(dim2, ARRAY['a', 'd']) LIMIT 5",
ImmutableList.of(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(querySegmentSpec(Filtration.eternity()))
.filters(
new ExpressionDimFilter("scalar_in_array(\"dim2\",array('a','d'))", ExprMacroTable.nil())
)
.columns("dim2")
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.limit(5)
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{"a"},
new Object[]{"a"}
)
"SELECT dim2 FROM druid.numfoo\n"
+ "WHERE\n"
+ " SCALAR_IN_ARRAY(dim2, ARRAY['a', 'd'])\n"
+ " OR SCALAR_IN_ARRAY(SUBSTRING(dim1, 1, 1), ARRAY[NULL, 'foo', 'bar'])\n"
+ " OR SCALAR_IN_ARRAY(cnt * 2, ARRAY[3])\n",
ImmutableList.of(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(querySegmentSpec(Filtration.eternity()))
.virtualColumns(
VirtualColumns.create(
NullHandling.sqlCompatible()
? ImmutableList.of(
expressionVirtualColumn("v0", "substring(\"dim1\", 0, 1)", ColumnType.STRING),
expressionVirtualColumn("v1", "(\"cnt\" * 2)", ColumnType.LONG)
)
: ImmutableList.of(
expressionVirtualColumn("v0", "(\"cnt\" * 2)", ColumnType.LONG)
)
)
)
.filters(
NullHandling.sqlCompatible()
? or(
in("dim2", Arrays.asList("a", "d")),
in("v0", Arrays.asList(null, "foo", "bar")),
in("v1", ColumnType.LONG, Collections.singletonList(3L))
)
: or(
in("dim2", Arrays.asList("a", "d")),
in("dim1", Arrays.asList(null, "foo", "bar"), new SubstringDimExtractionFn(0, 1)),
in("v0", ColumnType.LONG, Collections.singletonList(3L))
)
)
.columns("dim2")
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{"a"},
new Object[]{"a"}
)
);
}
@Test
public void testNotScalarInArrayFilter()
{
  // NOT SCALAR_IN_ARRAY should plan to a negated native "in" filter rather than an expression filter.
  testQuery(
      "SELECT dim2 FROM druid.numfoo\n"
      + "WHERE NOT SCALAR_IN_ARRAY(dim2, ARRAY['a', 'd'])\n",
      ImmutableList.of(
          newScanQueryBuilder()
              .dataSource(CalciteTests.DATASOURCE3)
              .intervals(querySegmentSpec(Filtration.eternity()))
              .filters(not(in("dim2", Arrays.asList("a", "d"))))
              .columns("dim2")
              .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
              .context(QUERY_CONTEXT_DEFAULT)
              .build()
      ),
      // Expected rows differ by null-handling mode: in default-value mode, null dim2 values surface
      // as '' and (not being 'a' or 'd') pass the NOT filter, producing the extra "" rows.
      NullHandling.sqlCompatible()
      ? ImmutableList.of(
          new Object[]{""},
          new Object[]{"abc"}
      )
      : ImmutableList.of(
          new Object[]{""},
          new Object[]{""},
          new Object[]{"abc"},
          new Object[]{""}
      )
  );
}
@Test
public void testArrayScalarInFilter_MVD()
{
msqIncompatible();
testBuilder()
.sql(
"SELECT dim3, (CASE WHEN scalar_in_array(dim3, Array['a', 'b', 'd']) THEN 'abd' ELSE 'not abd' END) " +
"FROM druid.numfoo"
)
.expectedQueries(
ImmutableList.of(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(querySegmentSpec(Filtration.eternity()))
.virtualColumns(
new ExpressionVirtualColumn(
"v0",
"case_searched(scalar_in_array(\"dim3\",array('a','b','d')),'abd','not abd')",
ColumnType.STRING,
ExprMacroTable.nil()
)
)
.columns("dim3", "v0")
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.context(QUERY_CONTEXT_DEFAULT)
.build()
)
)
.expectedResults(ResultMatchMode.RELAX_NULLS,
ImmutableList.of(
new Object[]{"[\"a\",\"b\"]", "[\"abd\",\"abd\"]"},
new Object[]{"[\"b\",\"c\"]", "[\"abd\",\"not abd\"]"},
new Object[]{"d", "abd"},
new Object[]{"", "not abd"},
new Object[]{null, "not abd"},
new Object[]{null, "not abd"}
)
)
.run();
// In the fifth row, dim3 is an empty list. The Scan query in MSQ reads this with makeDimensionSelector, whereas
// the Scan query in native reads this makeColumnValueSelector. Behavior of those selectors is inconsistent.
// The DimensionSelector returns an empty list; the ColumnValueSelector returns a list containing a single null.
final String expectedValueForEmptyMvd =
queryFramework().engine().name().equals("msq-task")
? NullHandling.defaultStringValue()
: "not abd";
testBuilder()
.sql(
"SELECT dim3, (CASE WHEN scalar_in_array(dim3, Array['a', 'b', 'd']) THEN 'abd' ELSE 'not abd' END) " +
"FROM druid.numfoo"
)
.expectedQueries(
ImmutableList.of(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(querySegmentSpec(Filtration.eternity()))
.virtualColumns(
expressionVirtualColumn(
"v0",
"case_searched(scalar_in_array(\"dim3\",array('a','b','d')),'abd','not abd')",
ColumnType.STRING
)
)
.columns("dim3", "v0")
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.context(QUERY_CONTEXT_DEFAULT)
.build()
)
)
.expectedResults(
ImmutableList.of(
new Object[]{"[\"a\",\"b\"]", "[\"abd\",\"abd\"]"},
new Object[]{"[\"b\",\"c\"]", "[\"abd\",\"not abd\"]"},
new Object[]{"d", "abd"},
new Object[]{"", "not abd"},
new Object[]{NullHandling.defaultStringValue(), expectedValueForEmptyMvd},
new Object[]{NullHandling.defaultStringValue(), "not abd"}
)
)
.run();
}
@Test

View File

@ -19,9 +19,12 @@
package org.apache.druid.sql.calcite.expression;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.hamcrest.CoreMatchers;
import org.hamcrest.MatcherAssert;
@ -90,7 +93,7 @@ public class DruidExpressionTest extends InitializedNullHandlingTest
}
@Test
public void longLiteral_roundTrip()
public void test_longLiteral_roundTrip()
{
final long[] longs = {
0,
@ -107,4 +110,124 @@ public class DruidExpressionTest extends InitializedNullHandlingTest
Assert.assertEquals(n, ((Number) expr.getLiteralValue()).longValue());
}
}
@Test
public void test_ofLiteral_nullString()
{
  // Null value typed as STRING stringifies to the "null" literal while retaining its Druid type.
  final DruidExpression expression = DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.STRING, null));
  Assert.assertEquals(ColumnType.STRING, expression.getDruidType());
  Assert.assertEquals("null", expression.getExpression());
}
@Test
public void test_ofLiteral_nullLong()
{
  // Null LONG literal: type is preserved, expression is the "null" literal.
  final DruidExpression expression = DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.LONG, null));
  Assert.assertEquals(ColumnType.LONG, expression.getDruidType());
  Assert.assertEquals("null", expression.getExpression());
}
@Test
public void test_ofLiteral_nullDouble()
{
  // Null DOUBLE literal: type is preserved, expression is the "null" literal.
  final DruidExpression expression = DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.DOUBLE, null));
  Assert.assertEquals(ColumnType.DOUBLE, expression.getDruidType());
  Assert.assertEquals("null", expression.getExpression());
}
@Test
public void test_ofLiteral_nullArray()
{
  // Null STRING_ARRAY literal: array types also round-trip through ofLiteral.
  final DruidExpression expression =
      DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.STRING_ARRAY, null));
  Assert.assertEquals(ColumnType.STRING_ARRAY, expression.getDruidType());
  Assert.assertEquals("null", expression.getExpression());
}
@Test
public void test_ofLiteral_string()
{
  // Non-ASCII, quotes, backslashes, and control characters must all be escaped, and the escaped
  // expression must parse back to the original string.
  final String s = "abcdé\n \\\" ' \uD83E\uDD20 \txyz";
  final DruidExpression expression = DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.STRING, s));
  Assert.assertEquals(ColumnType.STRING, expression.getDruidType());
  Assert.assertEquals("'abcdé\\u000A \\u005C\\u0022 \\u0027 \\uD83E\\uDD20 \\u0009xyz'", expression.getExpression());
  Assert.assertEquals(s, Parser.parse(expression.getExpression(), ExprMacroTable.nil()).getLiteralValue());
}
@Test
public void test_ofLiteral_emptyString()
{
  // Empty string is emitted as '' (not null); parsing it back applies mode-dependent
  // empty-to-null normalization, hence the emptyToNullIfNeeded in the expected value.
  final String s = "";
  final DruidExpression expression = DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.STRING, s));
  Assert.assertEquals(ColumnType.STRING, expression.getDruidType());
  Assert.assertEquals("''", expression.getExpression());
  Assert.assertEquals(
      NullHandling.emptyToNullIfNeeded(s),
      Parser.parse(expression.getExpression(), ExprMacroTable.nil()).getLiteralValue()
  );
}
@Test
public void test_ofLiteral_long()
{
  // Negative long literal round-trips through stringification and parsing.
  final DruidExpression expression = DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.LONG, -123));
  Assert.assertEquals(ColumnType.LONG, expression.getDruidType());
  Assert.assertEquals("-123", expression.getExpression());
  Assert.assertEquals(-123L, Parser.parse(expression.getExpression(), ExprMacroTable.nil()).getLiteralValue());
}
@Test
public void test_ofLiteral_double()
{
  // Negative double literal round-trips through stringification and parsing.
  final DruidExpression expression = DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.DOUBLE, -123.4));
  Assert.assertEquals(ColumnType.DOUBLE, expression.getDruidType());
  Assert.assertEquals("-123.4", expression.getExpression());
  Assert.assertEquals(-123.4, Parser.parse(expression.getExpression(), ExprMacroTable.nil()).getLiteralValue());
}
@Test
public void test_ofLiteral_doubleNan()
{
  // NaN stringifies to the bare token "NaN" and parses back to Double.NaN.
  final DruidExpression expression = DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.DOUBLE, Double.NaN));
  Assert.assertEquals(ColumnType.DOUBLE, expression.getDruidType());
  Assert.assertEquals("NaN", expression.getExpression());
  Assert.assertEquals(Double.NaN, Parser.parse(expression.getExpression(), ExprMacroTable.nil()).getLiteralValue());
}
@Test
public void test_ofLiteral_doubleNegativeInfinity()
{
  // Negative infinity stringifies to "-Infinity" and parses back to the same value.
  final DruidExpression expression =
      DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.DOUBLE, Double.NEGATIVE_INFINITY));
  Assert.assertEquals(ColumnType.DOUBLE, expression.getDruidType());
  Assert.assertEquals("-Infinity", expression.getExpression());
  Assert.assertEquals(
      Double.NEGATIVE_INFINITY,
      Parser.parse(expression.getExpression(), ExprMacroTable.nil()).getLiteralValue()
  );
}
@Test
public void test_ofLiteral_doublePositiveInfinity()
{
  // Positive infinity stringifies to "Infinity" and parses back to the same value.
  final DruidExpression expression =
      DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.DOUBLE, Double.POSITIVE_INFINITY));
  Assert.assertEquals(ColumnType.DOUBLE, expression.getDruidType());
  Assert.assertEquals("Infinity", expression.getExpression());
  Assert.assertEquals(
      Double.POSITIVE_INFINITY,
      Parser.parse(expression.getExpression(), ExprMacroTable.nil()).getLiteralValue()
  );
}
}

View File

@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.calcite.avatica.util.TimeUnit;
import org.apache.calcite.avatica.util.TimeUnitRange;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlOperator;
@ -33,6 +34,8 @@ import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.error.DruidException;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.query.expression.TestExprMacroTable;
import org.apache.druid.query.extraction.RegexDimExtractionFn;
import org.apache.druid.query.extraction.SubstringDimExtractionFn;
@ -65,12 +68,17 @@ import org.apache.druid.sql.calcite.expression.builtin.TimeFormatOperatorConvers
import org.apache.druid.sql.calcite.expression.builtin.TimeParseOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.TimeShiftOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.TruncateOperatorConversion;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.DruidOperatorTable;
import org.apache.druid.sql.calcite.planner.DruidTypeSystem;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.util.CalciteTestBase;
import org.joda.time.DateTimeZone;
import org.joda.time.Period;
import org.junit.Assert;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import java.math.BigDecimal;
import java.util.Collections;
@ -104,29 +112,32 @@ public class ExpressionsTest extends CalciteTestBase
.build();
private static final Map<String, Object> BINDINGS = ImmutableMap.<String, Object>builder()
.put("t", DateTimes.of("2000-02-03T04:05:06").getMillis())
.put("a", 10)
.put("b", 25)
.put("p", 3)
.put("x", 2.25)
.put("y", 3.0)
.put("z", -2.25)
.put("o", 0)
.put("nan", Double.NaN)
.put("inf", Double.POSITIVE_INFINITY)
.put("-inf", Double.NEGATIVE_INFINITY)
.put("fnan", Float.NaN)
.put("finf", Float.POSITIVE_INFINITY)
.put("-finf", Float.NEGATIVE_INFINITY)
.put("s", "foo")
.put("hexstr", "EF")
.put("intstr", "-100")
.put("spacey", " hey there ")
.put("newliney", "beep\nboop")
.put("tstr", "2000-02-03 04:05:06")
.put("dstr", "2000-02-03")
.put("timezone", "America/Los_Angeles")
.build();
.put(
"t",
DateTimes.of("2000-02-03T04:05:06").getMillis()
)
.put("a", 10)
.put("b", 25)
.put("p", 3)
.put("x", 2.25)
.put("y", 3.0)
.put("z", -2.25)
.put("o", 0)
.put("nan", Double.NaN)
.put("inf", Double.POSITIVE_INFINITY)
.put("-inf", Double.NEGATIVE_INFINITY)
.put("fnan", Float.NaN)
.put("finf", Float.POSITIVE_INFINITY)
.put("-finf", Float.NEGATIVE_INFINITY)
.put("s", "foo")
.put("hexstr", "EF")
.put("intstr", "-100")
.put("spacey", " hey there ")
.put("newliney", "beep\nboop")
.put("tstr", "2000-02-03 04:05:06")
.put("dstr", "2000-02-03")
.put("timezone", "America/Los_Angeles")
.build();
private ExpressionTestHelper testHelper;
@ -1923,7 +1934,7 @@ public class ExpressionsTest extends CalciteTestBase
(args) -> "(" + args.get(0).getExpression() + " - " + args.get(1).getExpression() + ")",
ImmutableList.of(
DruidExpression.ofColumn(ColumnType.LONG, "t"),
DruidExpression.ofLiteral(ColumnType.STRING, "90060000")
DruidExpression.ofLiteral(ColumnType.LONG, "90060000")
)
),
DateTimes.of("2000-02-03T04:05:06").minus(period).getMillis()
@ -2815,4 +2826,173 @@ public class ExpressionsTest extends CalciteTestBase
"45.678 KB"
);
}
@Test
public void testCalciteLiteralToDruidLiteral()
{
  final RexBuilder builder = new RexBuilder(DruidTypeSystem.TYPE_FACTORY);
  final PlannerContext context = Mockito.mock(PlannerContext.class);
  Mockito.when(context.getTimeZone()).thenReturn(DateTimeZone.UTC);

  // Typed NULL literals: the Druid literal keeps the corresponding Druid type, with a null value.
  assertDruidLiteral(
      new DruidLiteral(ExpressionType.STRING, null),
      Expressions.calciteLiteralToDruidLiteral(
          context,
          builder.makeNullLiteral(builder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR))
      )
  );

  // Empty string is preserved as an empty string, distinct from null.
  assertDruidLiteral(
      new DruidLiteral(ExpressionType.STRING, ""),
      Expressions.calciteLiteralToDruidLiteral(context, builder.makeLiteral(""))
  );

  assertDruidLiteral(
      new DruidLiteral(ExpressionType.LONG, null),
      Expressions.calciteLiteralToDruidLiteral(
          context,
          builder.makeNullLiteral(builder.getTypeFactory().createSqlType(SqlTypeName.BIGINT))
      )
  );

  // A literal of SQL type NULL carries no Druid type at all.
  assertDruidLiteral(
      new DruidLiteral(null, null),
      Expressions.calciteLiteralToDruidLiteral(
          context,
          builder.makeNullLiteral(builder.getTypeFactory().createSqlType(SqlTypeName.NULL))
      )
  );

  assertDruidLiteral(
      new DruidLiteral(ExpressionType.STRING, "abc"),
      Expressions.calciteLiteralToDruidLiteral(context, builder.makeLiteral("abc"))
  );

  // BOOLEAN true becomes LONG 1.
  assertDruidLiteral(
      new DruidLiteral(ExpressionType.LONG, 1L),
      Expressions.calciteLiteralToDruidLiteral(context, builder.makeLiteral(true))
  );

  // Exact numerics: INTEGER maps to LONG, DECIMAL maps to DOUBLE.
  assertDruidLiteral(
      new DruidLiteral(ExpressionType.LONG, 123L),
      Expressions.calciteLiteralToDruidLiteral(
          context,
          builder.makeExactLiteral(
              BigDecimal.valueOf(123L),
              builder.getTypeFactory().createSqlType(SqlTypeName.INTEGER)
          )
      )
  );

  assertDruidLiteral(
      new DruidLiteral(ExpressionType.DOUBLE, 123.0),
      Expressions.calciteLiteralToDruidLiteral(
          context,
          builder.makeExactLiteral(
              BigDecimal.valueOf(123L),
              builder.getTypeFactory().createSqlType(SqlTypeName.DECIMAL)
          )
      )
  );

  // TIMESTAMP and DATE literals both become LONG millis.
  assertDruidLiteral(
      new DruidLiteral(ExpressionType.LONG, DateTimes.of("2000").getMillis()),
      Expressions.calciteLiteralToDruidLiteral(
          context,
          Calcites.jodaToCalciteTimestampLiteral(
              builder,
              DateTimes.of("2000"),
              DateTimeZone.UTC,
              DruidTypeSystem.DEFAULT_TIMESTAMP_PRECISION
          )
      )
  );

  assertDruidLiteral(
      new DruidLiteral(ExpressionType.LONG, DateTimes.of("2000").getMillis()),
      Expressions.calciteLiteralToDruidLiteral(
          context,
          builder.makeDateLiteral(Calcites.jodaToCalciteDateString(DateTimes.of("2000"), DateTimeZone.UTC))
      )
  );

  // Day-time and year-month interval literals both become LONG.
  assertDruidLiteral(
      new DruidLiteral(ExpressionType.LONG, 3L),
      Expressions.calciteLiteralToDruidLiteral(
          context,
          builder.makeIntervalLiteral(
              BigDecimal.valueOf(3),
              new SqlIntervalQualifier(TimeUnit.DAY, TimeUnit.HOUR, SqlParserPos.ZERO)
          )
      )
  );

  assertDruidLiteral(
      new DruidLiteral(ExpressionType.LONG, 3),
      Expressions.calciteLiteralToDruidLiteral(
          context,
          builder.makeIntervalLiteral(
              BigDecimal.valueOf(3),
              new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO)
          )
      )
  );

  // CASTs of literals fold into literals of the target type.
  assertDruidLiteral(
      new DruidLiteral(ExpressionType.STRING, "123"),
      Expressions.calciteLiteralToDruidLiteral(
          context,
          builder.makeCast(
              builder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR),
              builder.makeExactLiteral(
                  BigDecimal.valueOf(123.7),
                  builder.getTypeFactory().createSqlType(SqlTypeName.INTEGER)
              )
          )
      )
  );

  assertDruidLiteral(
      new DruidLiteral(ExpressionType.DOUBLE, 123.0),
      Expressions.calciteLiteralToDruidLiteral(
          context,
          builder.makeCast(
              builder.getTypeFactory().createSqlType(SqlTypeName.DOUBLE),
              builder.makeExactLiteral(
                  BigDecimal.valueOf(123L),
                  builder.getTypeFactory().createSqlType(SqlTypeName.INTEGER)
              )
          )
      )
  );

  // TIMESTAMP cast to DATE does not fold to a Druid literal: null signals "not convertible".
  Assert.assertNull(
      Expressions.calciteLiteralToDruidLiteral(
          context,
          builder.makeCast(
              builder.getTypeFactory().createSqlType(SqlTypeName.DATE),
              Calcites.jodaToCalciteTimestampLiteral(
                  builder,
                  DateTimes.of("2000-01-02T03:04:05"),
                  DateTimeZone.UTC,
                  DruidTypeSystem.DEFAULT_TIMESTAMP_PRECISION
              )
          )
      )
  );
}
/**
 * Asserts that two {@link DruidLiteral} instances are equivalent by comparing their "type: value"
 * string renderings. Comparing rendered strings rather than the objects themselves makes the check
 * insensitive to the boxed Java class of the value (e.g. an Integer 3 and a Long 3 both render as
 * "3"), while still distinguishing the declared type and the value's textual form.
 */
private void assertDruidLiteral(
    final DruidLiteral expected,
    final DruidLiteral actual
)
{
  final String expectedRendering = StringUtils.format("%s: %s", expected.type(), expected.value());
  final String actualRendering = StringUtils.format("%s: %s", actual.type(), actual.value());
  Assert.assertEquals(expectedRendering, actualRendering);
}
}