From 3924f0eff4e0a9990896f88ce2ed46d180f7b7eb Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 7 Mar 2023 13:12:15 -0800 Subject: [PATCH] use Calcites.getColumnTypeForRelDataType for SQL CAST operator conversion (#13890) * use Calcites.getColumnTypeForRelDataType for SQL CAST operator conversion * fix comment * intervals are strings but also longs --- .../builtin/CastOperatorConversion.java | 66 ++++---------- .../ReductionOperatorConversionHelper.java | 13 ++- .../druid/sql/calcite/planner/Calcites.java | 8 +- .../CalciteMultiValueStringQueryTest.java | 86 +++++++++++++++++++ .../calcite/expression/ExpressionsTest.java | 6 +- .../expression/GreatestExpressionTest.java | 8 +- .../expression/LeastExpressionTest.java | 8 +- 7 files changed, 129 insertions(+), 66 deletions(-) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java index 9062a32d0ba..7f8c1ddee88 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java @@ -20,7 +20,6 @@ package org.apache.druid.sql.calcite.expression.builtin; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlOperator; @@ -30,6 +29,7 @@ import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.PeriodGranularity; import org.apache.druid.math.expr.ExprType; +import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.expression.DruidExpression; @@ -39,46 +39,10 @@ import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.joda.time.Period; -import java.util.Map; import java.util.function.Function; public class CastOperatorConversion implements SqlOperatorConversion { - private static final Map EXPRESSION_TYPES; - - static { - final ImmutableMap.Builder builder = ImmutableMap.builder(); - - for (SqlTypeName type : SqlTypeName.FRACTIONAL_TYPES) { - builder.put(type, ExprType.DOUBLE); - } - - for (SqlTypeName type : SqlTypeName.INT_TYPES) { - builder.put(type, ExprType.LONG); - } - - for (SqlTypeName type : SqlTypeName.STRING_TYPES) { - builder.put(type, ExprType.STRING); - } - - // Booleans are treated as longs in Druid expressions, using two-value logic (positive = true, nonpositive = false). - builder.put(SqlTypeName.BOOLEAN, ExprType.LONG); - - // Timestamps are treated as longs (millis since the epoch) in Druid expressions. - builder.put(SqlTypeName.TIMESTAMP, ExprType.LONG); - builder.put(SqlTypeName.DATE, ExprType.LONG); - - for (SqlTypeName type : SqlTypeName.DAY_INTERVAL_TYPES) { - builder.put(type, ExprType.LONG); - } - - for (SqlTypeName type : SqlTypeName.YEAR_INTERVAL_TYPES) { - builder.put(type, ExprType.LONG); - } - - EXPRESSION_TYPES = builder.build(); - } - @Override public SqlOperator calciteOperator() { @@ -118,28 +82,34 @@ public class CastOperatorConversion implements SqlOperatorConversion } else { // Handle other casts. If either type is ANY, use the other type instead. If both are ANY, this means nulls // downstream, Druid will try its best - final ExprType fromExprType = SqlTypeName.ANY.equals(fromType) - ? EXPRESSION_TYPES.get(toType) - : EXPRESSION_TYPES.get(fromType); - final ExprType toExprType = SqlTypeName.ANY.equals(toType) - ? EXPRESSION_TYPES.get(fromType) - : EXPRESSION_TYPES.get(toType); + final ColumnType fromDruidType = Calcites.getColumnTypeForRelDataType(operand.getType()); + final ColumnType toDruidType = Calcites.getColumnTypeForRelDataType(rexNode.getType()); - if (fromExprType == null || toExprType == null) { + final ExpressionType fromExpressionType = SqlTypeName.ANY.equals(fromType) + ? ExpressionType.fromColumnType(toDruidType) + : ExpressionType.fromColumnType(fromDruidType); + final ExpressionType toExpressionType = SqlTypeName.ANY.equals(toType) + ? ExpressionType.fromColumnType(fromDruidType) + : ExpressionType.fromColumnType(toDruidType); + + if (fromExpressionType == null || toExpressionType == null) { // We have no runtime type for these SQL types. return null; } final DruidExpression typeCastExpression; - if (fromExprType != toExprType) { + if (fromExpressionType.equals(toExpressionType)) { + typeCastExpression = operandExpression; + } else if (SqlTypeName.INTERVAL_TYPES.contains(fromType) && toExpressionType.is(ExprType.LONG)) { + // intervals can be longs without an explicit cast + typeCastExpression = operandExpression; + } else { // Ignore casts for simple extractions (use Function.identity) since it is ok in many cases. typeCastExpression = operandExpression.map( Function.identity(), - expression -> StringUtils.format("CAST(%s, '%s')", expression, toExprType.toString()) + expression -> StringUtils.format("CAST(%s, '%s')", expression, toExpressionType.asTypeString()) ); - } else { - typeCastExpression = operandExpression; } if (toType == SqlTypeName.DATE) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java index a747e9d27da..427c93a2878 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java @@ -56,9 +56,16 @@ class ReductionOperatorConversionHelper boolean hasDouble = false; boolean isString = false; for (int i = 0; i < n; i++) { - RelDataType type = opBinding.getOperandType(i); - SqlTypeName sqlTypeName = type.getSqlTypeName(); - ColumnType valueType = Calcites.getColumnTypeForRelDataType(type); + final RelDataType type = opBinding.getOperandType(i); + final SqlTypeName sqlTypeName = type.getSqlTypeName(); + final ColumnType valueType; + + if (SqlTypeName.INTERVAL_TYPES.contains(type.getSqlTypeName())) { + // handle intervals as a LONG type even though it is a string + valueType = ColumnType.LONG; + } else { + valueType = Calcites.getColumnTypeForRelDataType(type); + } // Return types are listed in order of preference: if (valueType != null) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java index aa95fee2871..331a61a1f50 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java @@ -160,7 +160,7 @@ public class Calcites return ColumnType.DOUBLE; } else if (isLongType(sqlTypeName)) { return ColumnType.LONG; - } else if (SqlTypeName.CHAR_TYPES.contains(sqlTypeName)) { + } else if (isStringType(sqlTypeName)) { return ColumnType.STRING; } else if (SqlTypeName.OTHER == sqlTypeName) { if (type instanceof RowSignatures.ComplexSqlType) { @@ -178,6 +178,12 @@ public class Calcites } } + public static boolean isStringType(SqlTypeName sqlTypeName) + { + return SqlTypeName.CHAR_TYPES.contains(sqlTypeName) || + SqlTypeName.INTERVAL_TYPES.contains(sqlTypeName); + } + public static boolean isDoubleType(SqlTypeName sqlTypeName) { return SqlTypeName.FRACTIONAL_TYPES.contains(sqlTypeName) || SqlTypeName.APPROX_TYPES.contains(sqlTypeName); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java index 534f334b834..449a6ad7550 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java @@ -29,8 +29,10 @@ import org.apache.druid.query.Druids; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.filter.AndDimFilter; +import org.apache.druid.query.filter.ExpressionDimFilter; import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.LikeDimFilter; +import org.apache.druid.query.filter.OrDimFilter; import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.groupby.GroupByQueryConfig; @@ -39,6 +41,7 @@ import org.apache.druid.query.groupby.orderby.OrderByColumnSpec; import org.apache.druid.query.ordering.StringComparators; import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.apache.druid.segment.virtual.ListFilteredVirtualColumn; import org.apache.druid.sql.SqlPlanningException; import org.apache.druid.sql.calcite.filtration.Filtration; @@ -1847,4 +1850,87 @@ public class CalciteMultiValueStringQueryTest extends BaseCalciteQueryTest exception -> exception.expect(RuntimeException.class) ); } + + @Test + public void testMultiValueStringOverlapFilterCoalesceNvl() + { + testQuery( + "SELECT COALESCE(dim3, 'other') FROM druid.numfoo " + + "WHERE MV_OVERLAP(COALESCE(MV_TO_ARRAY(dim3), ARRAY['other']), ARRAY['a', 'b', 'other']) OR " + + "MV_OVERLAP(NVL(MV_TO_ARRAY(dim3), ARRAY['other']), ARRAY['a', 'b', 'other']) LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .eternityInterval() + .virtualColumns( + new ExpressionVirtualColumn( + "v0", + "case_searched(notnull(\"dim3\"),\"dim3\",'other')", + ColumnType.STRING, + queryFramework().macroTable() + ) + ) + .filters( + new OrDimFilter( + new ExpressionDimFilter( + "case_searched(notnull(mv_to_array(\"dim3\")),array_overlap(mv_to_array(\"dim3\"),array('a','b','other')),1)", + null, + queryFramework().macroTable() + ), + new ExpressionDimFilter( + "case_searched(notnull(mv_to_array(\"dim3\")),array_overlap(mv_to_array(\"dim3\"),array('a','b','other')),1)", + null, + queryFramework().macroTable() + ) + ) + ) + .columns("v0") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + NullHandling.replaceWithDefault() + ? ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"}, + new Object[]{"[\"b\",\"c\"]"}, + new Object[]{"other"}, + new Object[]{"other"}, + new Object[]{"other"} + ) + : ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"}, + new Object[]{"[\"b\",\"c\"]"}, + new Object[]{"other"}, + new Object[]{"other"} + ) + ); + } + + @Test + public void testMultiValueStringOverlapFilterInconsistentUsage() + { + testQueryThrows( + "SELECT COALESCE(dim3, 'other') FROM druid.numfoo " + + "WHERE MV_OVERLAP(COALESCE(dim3, ARRAY['other']), ARRAY['a', 'b', 'other']) LIMIT 5", + e -> { + e.expect(SqlPlanningException.class); + e.expectMessage("Illegal mixing of types in CASE or COALESCE statement"); + } + + ); + } + + @Test + public void testMultiValueStringOverlapFilterInconsistentUsage2() + { + testQueryThrows( + "SELECT COALESCE(dim3, 'other') FROM druid.numfoo " + + "WHERE MV_OVERLAP(COALESCE(dim3, 'other'), ARRAY['a', 'b', 'other']) LIMIT 5", + e -> { + e.expect(RuntimeException.class); + e.expectMessage("Invalid expression: (case_searched [(notnull [dim3]), (array_overlap [dim3, [a, b, other]]), 1]); [dim3] used as both scalar and array variables"); + } + ); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java index 91ab2a839f2..f7fec59032f 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java @@ -1751,8 +1751,7 @@ public class ExpressionsTest extends ExpressionTestBase (args) -> "(" + args.get(0).getExpression() + " - " + args.get(1).getExpression() + ")", ImmutableList.of( DruidExpression.ofColumn(ColumnType.LONG, "t"), - // RexNode type of "interval day to minute" is not converted to druid long... yet - DruidExpression.ofLiteral(null, "90060000") + DruidExpression.ofLiteral(ColumnType.STRING, "90060000") ) ), DateTimes.of("2000-02-03T04:05:06").minus(period).getMillis() @@ -1779,8 +1778,7 @@ public class ExpressionsTest extends ExpressionTestBase DruidExpression.functionCall("timestamp_shift"), ImmutableList.of( DruidExpression.ofColumn(ColumnType.LONG, "t"), - // RexNode type "interval year to month" is not reported as ColumnType.STRING - DruidExpression.ofLiteral(null, DruidExpression.stringLiteral("P13M")), + DruidExpression.ofLiteral(ColumnType.STRING, DruidExpression.stringLiteral("P13M")), DruidExpression.ofLiteral(ColumnType.LONG, DruidExpression.longLiteral(-1)), DruidExpression.ofStringLiteral("UTC") ) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/GreatestExpressionTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/GreatestExpressionTest.java index 87ce28ef3f6..8418edf994e 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/GreatestExpressionTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/GreatestExpressionTest.java @@ -246,10 +246,8 @@ public class GreatestExpressionTest extends ExpressionTestBase } @Test - public void testInvalidType() + public void testIntervalYearMonth() { - expectException(IllegalArgumentException.class, "Argument 0 has invalid type: INTERVAL_YEAR_MONTH"); - testExpression( Collections.singletonList( testHelper.makeLiteral( @@ -257,8 +255,8 @@ public class GreatestExpressionTest extends ExpressionTestBase new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO) ) ), - null, - null + buildExpectedExpression(13), + 13L ); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/LeastExpressionTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/LeastExpressionTest.java index 047f6936d30..6702769e927 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/LeastExpressionTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/LeastExpressionTest.java @@ -247,10 +247,8 @@ public class LeastExpressionTest extends ExpressionTestBase } @Test - public void testInvalidType() + public void testIntervalYearMonth() { - expectException(IllegalArgumentException.class, "Argument 0 has invalid type: INTERVAL_YEAR_MONTH"); - testExpression( Collections.singletonList( testHelper.makeLiteral( @@ -258,8 +256,8 @@ public class LeastExpressionTest extends ExpressionTestBase new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO) ) ), - null, - null + buildExpectedExpression(13), + 13L ); }