use Calcites.getColumnTypeForRelDataType for SQL CAST operator conversion (#13890)

* use Calcites.getColumnTypeForRelDataType for SQL CAST operator conversion

* fix comment

* intervals are strings but also longs
This commit is contained in:
Clint Wylie 2023-03-07 13:12:15 -08:00 committed by GitHub
parent ca4df85941
commit 3924f0eff4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 129 additions and 66 deletions

View File

@ -20,7 +20,6 @@
package org.apache.druid.sql.calcite.expression.builtin;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlOperator;
@ -30,6 +29,7 @@ import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.granularity.PeriodGranularity;
import org.apache.druid.math.expr.ExprType;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.DruidExpression;
@ -39,46 +39,10 @@ import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.joda.time.Period;
import java.util.Map;
import java.util.function.Function;
public class CastOperatorConversion implements SqlOperatorConversion
{
private static final Map<SqlTypeName, ExprType> EXPRESSION_TYPES;
static {
final ImmutableMap.Builder<SqlTypeName, ExprType> builder = ImmutableMap.builder();
for (SqlTypeName type : SqlTypeName.FRACTIONAL_TYPES) {
builder.put(type, ExprType.DOUBLE);
}
for (SqlTypeName type : SqlTypeName.INT_TYPES) {
builder.put(type, ExprType.LONG);
}
for (SqlTypeName type : SqlTypeName.STRING_TYPES) {
builder.put(type, ExprType.STRING);
}
// Booleans are treated as longs in Druid expressions, using two-value logic (positive = true, nonpositive = false).
builder.put(SqlTypeName.BOOLEAN, ExprType.LONG);
// Timestamps are treated as longs (millis since the epoch) in Druid expressions.
builder.put(SqlTypeName.TIMESTAMP, ExprType.LONG);
builder.put(SqlTypeName.DATE, ExprType.LONG);
for (SqlTypeName type : SqlTypeName.DAY_INTERVAL_TYPES) {
builder.put(type, ExprType.LONG);
}
for (SqlTypeName type : SqlTypeName.YEAR_INTERVAL_TYPES) {
builder.put(type, ExprType.LONG);
}
EXPRESSION_TYPES = builder.build();
}
@Override
public SqlOperator calciteOperator()
{
@ -118,28 +82,34 @@ public class CastOperatorConversion implements SqlOperatorConversion
} else {
// Handle other casts. If either type is ANY, use the other type instead. If both are ANY, this means nulls
// downstream, Druid will try its best
final ExprType fromExprType = SqlTypeName.ANY.equals(fromType)
? EXPRESSION_TYPES.get(toType)
: EXPRESSION_TYPES.get(fromType);
final ExprType toExprType = SqlTypeName.ANY.equals(toType)
? EXPRESSION_TYPES.get(fromType)
: EXPRESSION_TYPES.get(toType);
final ColumnType fromDruidType = Calcites.getColumnTypeForRelDataType(operand.getType());
final ColumnType toDruidType = Calcites.getColumnTypeForRelDataType(rexNode.getType());
if (fromExprType == null || toExprType == null) {
final ExpressionType fromExpressionType = SqlTypeName.ANY.equals(fromType)
? ExpressionType.fromColumnType(toDruidType)
: ExpressionType.fromColumnType(fromDruidType);
final ExpressionType toExpressionType = SqlTypeName.ANY.equals(toType)
? ExpressionType.fromColumnType(fromDruidType)
: ExpressionType.fromColumnType(toDruidType);
if (fromExpressionType == null || toExpressionType == null) {
// We have no runtime type for these SQL types.
return null;
}
final DruidExpression typeCastExpression;
if (fromExprType != toExprType) {
if (fromExpressionType.equals(toExpressionType)) {
typeCastExpression = operandExpression;
} else if (SqlTypeName.INTERVAL_TYPES.contains(fromType) && toExpressionType.is(ExprType.LONG)) {
// intervals can be longs without an explicit cast
typeCastExpression = operandExpression;
} else {
// Ignore casts for simple extractions (use Function.identity) since it is ok in many cases.
typeCastExpression = operandExpression.map(
Function.identity(),
expression -> StringUtils.format("CAST(%s, '%s')", expression, toExprType.toString())
expression -> StringUtils.format("CAST(%s, '%s')", expression, toExpressionType.asTypeString())
);
} else {
typeCastExpression = operandExpression;
}
if (toType == SqlTypeName.DATE) {

View File

@ -56,9 +56,16 @@ class ReductionOperatorConversionHelper
boolean hasDouble = false;
boolean isString = false;
for (int i = 0; i < n; i++) {
RelDataType type = opBinding.getOperandType(i);
SqlTypeName sqlTypeName = type.getSqlTypeName();
ColumnType valueType = Calcites.getColumnTypeForRelDataType(type);
final RelDataType type = opBinding.getOperandType(i);
final SqlTypeName sqlTypeName = type.getSqlTypeName();
final ColumnType valueType;
if (SqlTypeName.INTERVAL_TYPES.contains(type.getSqlTypeName())) {
// handle intervals as a LONG type even though it is a string
valueType = ColumnType.LONG;
} else {
valueType = Calcites.getColumnTypeForRelDataType(type);
}
// Return types are listed in order of preference:
if (valueType != null) {

View File

@ -160,7 +160,7 @@ public class Calcites
return ColumnType.DOUBLE;
} else if (isLongType(sqlTypeName)) {
return ColumnType.LONG;
} else if (SqlTypeName.CHAR_TYPES.contains(sqlTypeName)) {
} else if (isStringType(sqlTypeName)) {
return ColumnType.STRING;
} else if (SqlTypeName.OTHER == sqlTypeName) {
if (type instanceof RowSignatures.ComplexSqlType) {
@ -178,6 +178,12 @@ public class Calcites
}
}
public static boolean isStringType(SqlTypeName sqlTypeName)
{
return SqlTypeName.CHAR_TYPES.contains(sqlTypeName) ||
SqlTypeName.INTERVAL_TYPES.contains(sqlTypeName);
}
public static boolean isDoubleType(SqlTypeName sqlTypeName)
{
return SqlTypeName.FRACTIONAL_TYPES.contains(sqlTypeName) || SqlTypeName.APPROX_TYPES.contains(sqlTypeName);

View File

@ -29,8 +29,10 @@ import org.apache.druid.query.Druids;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.filter.AndDimFilter;
import org.apache.druid.query.filter.ExpressionDimFilter;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.filter.LikeDimFilter;
import org.apache.druid.query.filter.OrDimFilter;
import org.apache.druid.query.filter.SelectorDimFilter;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.GroupByQueryConfig;
@ -39,6 +41,7 @@ import org.apache.druid.query.groupby.orderby.OrderByColumnSpec;
import org.apache.druid.query.ordering.StringComparators;
import org.apache.druid.query.scan.ScanQuery;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.segment.virtual.ListFilteredVirtualColumn;
import org.apache.druid.sql.SqlPlanningException;
import org.apache.druid.sql.calcite.filtration.Filtration;
@ -1847,4 +1850,87 @@ public class CalciteMultiValueStringQueryTest extends BaseCalciteQueryTest
exception -> exception.expect(RuntimeException.class)
);
}
@Test
public void testMultiValueStringOverlapFilterCoalesceNvl()
{
testQuery(
"SELECT COALESCE(dim3, 'other') FROM druid.numfoo "
+ "WHERE MV_OVERLAP(COALESCE(MV_TO_ARRAY(dim3), ARRAY['other']), ARRAY['a', 'b', 'other']) OR "
+ "MV_OVERLAP(NVL(MV_TO_ARRAY(dim3), ARRAY['other']), ARRAY['a', 'b', 'other']) LIMIT 5",
ImmutableList.of(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.eternityInterval()
.virtualColumns(
new ExpressionVirtualColumn(
"v0",
"case_searched(notnull(\"dim3\"),\"dim3\",'other')",
ColumnType.STRING,
queryFramework().macroTable()
)
)
.filters(
new OrDimFilter(
new ExpressionDimFilter(
"case_searched(notnull(mv_to_array(\"dim3\")),array_overlap(mv_to_array(\"dim3\"),array('a','b','other')),1)",
null,
queryFramework().macroTable()
),
new ExpressionDimFilter(
"case_searched(notnull(mv_to_array(\"dim3\")),array_overlap(mv_to_array(\"dim3\"),array('a','b','other')),1)",
null,
queryFramework().macroTable()
)
)
)
.columns("v0")
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.limit(5)
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
NullHandling.replaceWithDefault()
? ImmutableList.of(
new Object[]{"[\"a\",\"b\"]"},
new Object[]{"[\"b\",\"c\"]"},
new Object[]{"other"},
new Object[]{"other"},
new Object[]{"other"}
)
: ImmutableList.of(
new Object[]{"[\"a\",\"b\"]"},
new Object[]{"[\"b\",\"c\"]"},
new Object[]{"other"},
new Object[]{"other"}
)
);
}
@Test
public void testMultiValueStringOverlapFilterInconsistentUsage()
{
testQueryThrows(
"SELECT COALESCE(dim3, 'other') FROM druid.numfoo "
+ "WHERE MV_OVERLAP(COALESCE(dim3, ARRAY['other']), ARRAY['a', 'b', 'other']) LIMIT 5",
e -> {
e.expect(SqlPlanningException.class);
e.expectMessage("Illegal mixing of types in CASE or COALESCE statement");
}
);
}
@Test
public void testMultiValueStringOverlapFilterInconsistentUsage2()
{
testQueryThrows(
"SELECT COALESCE(dim3, 'other') FROM druid.numfoo "
+ "WHERE MV_OVERLAP(COALESCE(dim3, 'other'), ARRAY['a', 'b', 'other']) LIMIT 5",
e -> {
e.expect(RuntimeException.class);
e.expectMessage("Invalid expression: (case_searched [(notnull [dim3]), (array_overlap [dim3, [a, b, other]]), 1]); [dim3] used as both scalar and array variables");
}
);
}
}

View File

@ -1751,8 +1751,7 @@ public class ExpressionsTest extends ExpressionTestBase
(args) -> "(" + args.get(0).getExpression() + " - " + args.get(1).getExpression() + ")",
ImmutableList.of(
DruidExpression.ofColumn(ColumnType.LONG, "t"),
// RexNode type of "interval day to minute" is not converted to druid long... yet
DruidExpression.ofLiteral(null, "90060000")
DruidExpression.ofLiteral(ColumnType.STRING, "90060000")
)
),
DateTimes.of("2000-02-03T04:05:06").minus(period).getMillis()
@ -1779,8 +1778,7 @@ public class ExpressionsTest extends ExpressionTestBase
DruidExpression.functionCall("timestamp_shift"),
ImmutableList.of(
DruidExpression.ofColumn(ColumnType.LONG, "t"),
// RexNode type "interval year to month" is not reported as ColumnType.STRING
DruidExpression.ofLiteral(null, DruidExpression.stringLiteral("P13M")),
DruidExpression.ofLiteral(ColumnType.STRING, DruidExpression.stringLiteral("P13M")),
DruidExpression.ofLiteral(ColumnType.LONG, DruidExpression.longLiteral(-1)),
DruidExpression.ofStringLiteral("UTC")
)

View File

@ -246,10 +246,8 @@ public class GreatestExpressionTest extends ExpressionTestBase
}
@Test
public void testInvalidType()
public void testIntervalYearMonth()
{
expectException(IllegalArgumentException.class, "Argument 0 has invalid type: INTERVAL_YEAR_MONTH");
testExpression(
Collections.singletonList(
testHelper.makeLiteral(
@ -257,8 +255,8 @@ public class GreatestExpressionTest extends ExpressionTestBase
new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO)
)
),
null,
null
buildExpectedExpression(13),
13L
);
}

View File

@ -247,10 +247,8 @@ public class LeastExpressionTest extends ExpressionTestBase
}
@Test
public void testInvalidType()
public void testIntervalYearMonth()
{
expectException(IllegalArgumentException.class, "Argument 0 has invalid type: INTERVAL_YEAR_MONTH");
testExpression(
Collections.singletonList(
testHelper.makeLiteral(
@ -258,8 +256,8 @@ public class LeastExpressionTest extends ExpressionTestBase
new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO)
)
),
null,
null
buildExpectedExpression(13),
13L
);
}