From bfbd7ec4329b4f083e1c2044b847a04b9f8912cc Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 15 Jun 2021 12:26:59 -0700 Subject: [PATCH] fix a bugs related to SQL type inference return type nullability (#11327) * fix a bunch of type inference nullability bugs * fixes * style * fix test * fix concat --- .../org/apache/druid/math/expr/Function.java | 3 + .../apache/druid/math/expr/FunctionTest.java | 10 +++ docs/querying/sql.md | 2 +- .../expression/TimestampShiftExprMacro.java | 12 ++- .../expression/TimestampShiftMacroTest.java | 19 +++++ .../expression/OperatorConversions.java | 43 +++++++--- .../ArrayLengthOperatorConversion.java | 2 +- .../ArrayOffsetOfOperatorConversion.java | 2 +- .../ArrayOrdinalOfOperatorConversion.java | 2 +- .../ArrayToStringOperatorConversion.java | 2 +- .../builtin/BTrimOperatorConversion.java | 2 +- .../builtin/ConcatOperatorConversion.java | 19 ++--- .../builtin/DateTruncOperatorConversion.java | 2 +- .../builtin/LPadOperatorConversion.java | 2 +- .../builtin/LTrimOperatorConversion.java | 2 +- .../builtin/LeftOperatorConversion.java | 2 +- .../MillisToTimestampOperatorConversion.java | 2 +- .../builtin/ParseLongOperatorConversion.java | 2 +- .../builtin/RPadOperatorConversion.java | 2 +- .../builtin/RTrimOperatorConversion.java | 2 +- .../builtin/RepeatOperatorConversion.java | 2 +- .../builtin/ReverseOperatorConversion.java | 2 +- .../builtin/RightOperatorConversion.java | 2 +- .../StringFormatOperatorConversion.java | 2 +- .../builtin/StrposOperatorConversion.java | 2 +- .../builtin/TextcatOperatorConversion.java | 2 +- .../builtin/TimeCeilOperatorConversion.java | 2 +- .../TimeExtractOperatorConversion.java | 2 +- .../builtin/TimeFloorOperatorConversion.java | 2 +- .../builtin/TimeFormatOperatorConversion.java | 2 +- .../builtin/TimeShiftOperatorConversion.java | 2 +- .../TimestampToMillisOperatorConversion.java | 2 +- .../druid/sql/calcite/CalciteQueryTest.java | 49 ++++++++++++ .../expression/OperatorConversionsTest.java | 79 +++++++++++++++++++ 34 files changed, 232 insertions(+), 54 deletions(-) diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index 711acadf08b..a0b8e59a17f 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -2721,6 +2721,9 @@ public interface Function @Override protected ExprEval eval(String x, int y) { + if (x == null) { + return ExprEval.of(null); + } return ExprEval.of(y < 1 ? NullHandling.defaultStringValue() : StringUtils.repeat(x, y)); } } diff --git a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java index 5ded90fd25f..8ee989846d4 100644 --- a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java @@ -596,6 +596,16 @@ public class FunctionTest extends InitializedNullHandlingTest assertExpr("bitwiseConvertDoubleToLongBits(null)", null); } + @Test + public void testRepeat() + { + assertExpr("repeat('hello', 2)", "hellohello"); + assertExpr("repeat('hello', -1)", null); + assertExpr("repeat(null, 10)", null); + assertExpr("repeat(nonexistent, 10)", null); + } + + private void assertExpr(final String expression, @Nullable final Object expectedResult) { final Expr expr = Parser.parse(expression, ExprMacroTable.nil()); diff --git a/docs/querying/sql.md b/docs/querying/sql.md index 6f837de8ec1..12d0da989d0 100644 --- a/docs/querying/sql.md +++ b/docs/querying/sql.md @@ -303,7 +303,7 @@ columns in this mode are not nullable; any null or missing values will be treate In SQL compatible mode (`false`), NULLs are treated more closely to the SQL standard. The property affects both storage and querying, so for correct behavior, it should be set on all Druid service types to be available at both ingestion time and query time. There is some overhead associated with the ability to handle NULLs; see -the [segment internals](../design/segments.md#sql-compatible-null-handling)documentation for more details. +the [segment internals](../design/segments.md#sql-compatible-null-handling) documentation for more details. ## Aggregation functions diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java index e7c39d00a85..de1293aef00 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java @@ -99,7 +99,11 @@ public class TimestampShiftExprMacro implements ExprMacroTable.ExprMacro @Override public ExprEval eval(final ObjectBinding bindings) { - return ExprEval.of(chronology.add(period, args.get(0).eval(bindings).asLong(), step)); + ExprEval timestamp = args.get(0).eval(bindings); + if (timestamp.isNumericNull()) { + return ExprEval.of(null); + } + return ExprEval.of(chronology.add(period, timestamp.asLong(), step)); } @Override @@ -128,10 +132,14 @@ public class TimestampShiftExprMacro implements ExprMacroTable.ExprMacro @Override public ExprEval eval(final ObjectBinding bindings) { + ExprEval timestamp = args.get(0).eval(bindings); + if (timestamp.isNumericNull()) { + return ExprEval.of(null); + } final Period period = getPeriod(args, bindings); final Chronology chronology = getTimeZone(args, bindings); final int step = getStep(args, bindings); - return ExprEval.of(chronology.add(period, args.get(0).eval(bindings).asLong(), step)); + return ExprEval.of(chronology.add(period, timestamp.asLong(), step)); } @Override diff --git a/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java index c4710f9c360..05945b1cc70 100644 --- a/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java +++ b/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java @@ -20,6 +20,7 @@ package org.apache.druid.query.expression; import com.google.common.collect.ImmutableList; +import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.IAE; import org.apache.druid.math.expr.Expr; @@ -219,6 +220,24 @@ public class TimestampShiftMacroTest extends MacroTestBase ); } + @Test + public void testNull() + { + Expr expr = apply( + ImmutableList.of( + ExprEval.ofLong(null).toExpr(), + ExprEval.of("P1M").toExpr(), + ExprEval.of(1L).toExpr() + ) + ); + + if (NullHandling.replaceWithDefault()) { + Assert.assertEquals(2678400000L, expr.eval(ExprUtils.nilBindings()).value()); + } else { + Assert.assertNull(expr.eval(ExprUtils.nilBindings()).value()); + } + } + private static class NotLiteralExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr { NotLiteralExpr(String name) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java index 6f060f19c92..36607bfcb43 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java @@ -48,6 +48,7 @@ import org.apache.calcite.sql.type.SqlOperandTypeInference; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.type.SqlTypeTransforms; import org.apache.calcite.util.Static; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; @@ -255,11 +256,12 @@ public class OperatorConversions } /** - * Sets the return type of the operator to "typeName", marked as non-nullable. + * Sets the return type of the operator to "typeName", marked as non-nullable. If this method is used it implies the + * operator should never, ever, return null. * - * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or - * {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods - * cannot be mixed; you must call exactly one. + * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)} + * {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before + * calling {@link #build()}. These methods cannot be mixed; you must call exactly one. */ public OperatorBuilder returnTypeNonNull(final SqlTypeName typeName) { @@ -274,9 +276,9 @@ public class OperatorConversions /** * Sets the return type of the operator to "typeName", marked as nullable. * - * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or - * {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods - * cannot be mixed; you must call exactly one. + * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)} + * {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before + * calling {@link #build()}. These methods cannot be mixed; you must call exactly one. */ public OperatorBuilder returnTypeNullable(final SqlTypeName typeName) { @@ -287,12 +289,27 @@ public class OperatorConversions ); return this; } + + /** + * Sets the return type of the operator to "typeName", marked as nullable if any of its operands are nullable. + * + * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)} + * {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before + * calling {@link #build()}. These methods cannot be mixed; you must call exactly one. + */ + public OperatorBuilder returnTypeCascadeNullable(final SqlTypeName typeName) + { + Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times"); + this.returnTypeInference = ReturnTypes.cascade(ReturnTypes.explicit(typeName), SqlTypeTransforms.TO_NULLABLE); + return this; + } + /** * Sets the return type of the operator to an array type with elements of "typeName", marked as nullable. * - * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or - * {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods - * cannot be mixed; you must call exactly one. + * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)} + * {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before + * calling {@link #build()}. These methods cannot be mixed; you must call exactly one. */ public OperatorBuilder returnTypeNullableArray(final SqlTypeName elementTypeName) { @@ -308,9 +325,9 @@ public class OperatorConversions /** * Provides customized return type inference logic. * - * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or - * {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods - * cannot be mixed; you must call exactly one. + * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)} + * {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before + * calling {@link #build()}. These methods cannot be mixed; you must call exactly one. */ public OperatorBuilder returnTypeInference(final SqlReturnTypeInference returnTypeInference) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayLengthOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayLengthOperatorConversion.java index 073d93556d8..9e67cc3ea31 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayLengthOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayLengthOperatorConversion.java @@ -43,7 +43,7 @@ public class ArrayLengthOperatorConversion implements SqlOperatorConversion ) ) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.INTEGER) + .returnTypeCascadeNullable(SqlTypeName.INTEGER) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOffsetOfOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOffsetOfOperatorConversion.java index 51cad2feda4..ca026c5e10b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOffsetOfOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOffsetOfOperatorConversion.java @@ -47,7 +47,7 @@ public class ArrayOffsetOfOperatorConversion implements SqlOperatorConversion ) ) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.INTEGER) + .returnTypeNullable(SqlTypeName.INTEGER) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOfOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOfOperatorConversion.java index 12edb572743..dfc1501d52b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOfOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOfOperatorConversion.java @@ -47,7 +47,7 @@ public class ArrayOrdinalOfOperatorConversion implements SqlOperatorConversion ) ) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.INTEGER) + .returnTypeCascadeNullable(SqlTypeName.INTEGER) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayToStringOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayToStringOperatorConversion.java index 5d316a59116..285993b399c 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayToStringOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayToStringOperatorConversion.java @@ -47,7 +47,7 @@ public class ArrayToStringOperatorConversion implements SqlOperatorConversion ) ) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BTrimOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BTrimOperatorConversion.java index d77c20b5d78..648d54b9380 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BTrimOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BTrimOperatorConversion.java @@ -37,7 +37,7 @@ public class BTrimOperatorConversion implements SqlOperatorConversion private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("BTRIM") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.STRING) .requiredOperands(1) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java index e7dbf504614..7ffc47dd17e 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java @@ -22,29 +22,22 @@ package org.apache.druid.sql.calcite.expression.builtin; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; -import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.OperandTypes; -import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.expression.SqlOperatorConversion; -import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; public class ConcatOperatorConversion implements SqlOperatorConversion { - private static final SqlFunction SQL_FUNCTION = new SqlFunction( - "CONCAT", - SqlKind.OTHER_FUNCTION, - ReturnTypes.explicit( - factory -> Calcites.createSqlType(factory, SqlTypeName.VARCHAR) - ), - null, - OperandTypes.SAME_VARIADIC, - SqlFunctionCategory.STRING - ); + private static final SqlFunction SQL_FUNCTION = OperatorConversions + .operatorBuilder("CONCAT") + .operandTypeChecker(OperandTypes.SAME_VARIADIC) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) + .functionCategory(SqlFunctionCategory.STRING) + .build(); @Override public SqlFunction calciteOperator() diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/DateTruncOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/DateTruncOperatorConversion.java index f496e0ab967..574fa2fd46b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/DateTruncOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/DateTruncOperatorConversion.java @@ -67,7 +67,7 @@ public class DateTruncOperatorConversion implements SqlOperatorConversion .operatorBuilder("DATE_TRUNC") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP) .requiredOperands(2) - .returnTypeNonNull(SqlTypeName.TIMESTAMP) + .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LPadOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LPadOperatorConversion.java index 2d13b02fcbb..3d98d3e9f05 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LPadOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LPadOperatorConversion.java @@ -37,7 +37,7 @@ public class LPadOperatorConversion implements SqlOperatorConversion private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("LPAD") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.STRING) .requiredOperands(2) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LTrimOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LTrimOperatorConversion.java index 70ec0c97e62..233ded0acb6 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LTrimOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LTrimOperatorConversion.java @@ -37,7 +37,7 @@ public class LTrimOperatorConversion implements SqlOperatorConversion private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("LTRIM") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.STRING) .requiredOperands(1) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeftOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeftOperatorConversion.java index 252343cddba..deeffa50760 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeftOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeftOperatorConversion.java @@ -39,7 +39,7 @@ public class LeftOperatorConversion implements SqlOperatorConversion .operatorBuilder("LEFT") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MillisToTimestampOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MillisToTimestampOperatorConversion.java index e8b8e748a64..2456f059f50 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MillisToTimestampOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MillisToTimestampOperatorConversion.java @@ -39,7 +39,7 @@ public class MillisToTimestampOperatorConversion implements SqlOperatorConversio private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("MILLIS_TO_TIMESTAMP") .operandTypes(SqlTypeFamily.EXACT_NUMERIC) - .returnTypeNonNull(SqlTypeName.TIMESTAMP) + .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ParseLongOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ParseLongOperatorConversion.java index 9fd710fb1cb..4de200002d0 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ParseLongOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ParseLongOperatorConversion.java @@ -38,7 +38,7 @@ public class ParseLongOperatorConversion implements SqlOperatorConversion private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder(NAME) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) - .returnTypeNonNull(SqlTypeName.BIGINT) + .returnTypeCascadeNullable(SqlTypeName.BIGINT) .functionCategory(SqlFunctionCategory.STRING) .requiredOperands(1) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RPadOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RPadOperatorConversion.java index 47c8eadc2f8..5ab8454643c 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RPadOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RPadOperatorConversion.java @@ -37,7 +37,7 @@ public class RPadOperatorConversion implements SqlOperatorConversion private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("RPAD") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.STRING) .requiredOperands(2) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RTrimOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RTrimOperatorConversion.java index 6aa8f1b28a6..bc96610d126 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RTrimOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RTrimOperatorConversion.java @@ -37,7 +37,7 @@ public class RTrimOperatorConversion implements SqlOperatorConversion private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("RTRIM") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.STRING) .requiredOperands(1) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RepeatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RepeatOperatorConversion.java index 9521a0443bc..55b01be9c5b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RepeatOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RepeatOperatorConversion.java @@ -39,7 +39,7 @@ public class RepeatOperatorConversion implements SqlOperatorConversion .operatorBuilder("REPEAT") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReverseOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReverseOperatorConversion.java index 70280abf2f9..6014231ab54 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReverseOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReverseOperatorConversion.java @@ -37,7 +37,7 @@ public class ReverseOperatorConversion implements SqlOperatorConversion .operatorBuilder("REVERSE") .operandTypes(SqlTypeFamily.CHARACTER) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RightOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RightOperatorConversion.java index 863bbccd557..5f454a5f980 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RightOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RightOperatorConversion.java @@ -39,7 +39,7 @@ public class RightOperatorConversion implements SqlOperatorConversion .operatorBuilder("RIGHT") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StringFormatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StringFormatOperatorConversion.java index b2aabbb2d11..133d6226dc8 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StringFormatOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StringFormatOperatorConversion.java @@ -42,7 +42,7 @@ public class StringFormatOperatorConversion implements SqlOperatorConversion .operatorBuilder("STRING_FORMAT") .operandTypeChecker(new StringFormatOperandTypeChecker()) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StrposOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StrposOperatorConversion.java index e18c0896a5d..c36405f0662 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StrposOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StrposOperatorConversion.java @@ -38,7 +38,7 @@ public class StrposOperatorConversion implements SqlOperatorConversion .operatorBuilder("STRPOS") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.INTEGER) + .returnTypeCascadeNullable(SqlTypeName.INTEGER) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java index ee160d6b3ef..c44375c131d 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java @@ -36,7 +36,7 @@ public class TextcatOperatorConversion implements SqlOperatorConversion .operatorBuilder("textcat") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .requiredOperands(2) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.STRING) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeCeilOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeCeilOperatorConversion.java index 81b2dfa12ae..359612c0852 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeCeilOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeCeilOperatorConversion.java @@ -41,7 +41,7 @@ public class TimeCeilOperatorConversion implements SqlOperatorConversion .operatorBuilder("TIME_CEIL") .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER) .requiredOperands(2) - .returnTypeNonNull(SqlTypeName.TIMESTAMP) + .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeExtractOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeExtractOperatorConversion.java index 35accd1f9b3..000923c4fd6 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeExtractOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeExtractOperatorConversion.java @@ -44,7 +44,7 @@ public class TimeExtractOperatorConversion implements SqlOperatorConversion .operatorBuilder("TIME_EXTRACT") .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .requiredOperands(2) - .returnTypeNonNull(SqlTypeName.BIGINT) + .returnTypeCascadeNullable(SqlTypeName.BIGINT) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFloorOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFloorOperatorConversion.java index 87c07f25b7e..20377a03aac 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFloorOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFloorOperatorConversion.java @@ -56,7 +56,7 @@ public class TimeFloorOperatorConversion implements SqlOperatorConversion .operatorBuilder("TIME_FLOOR") .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER) .requiredOperands(2) - .returnTypeNonNull(SqlTypeName.TIMESTAMP) + .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFormatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFormatOperatorConversion.java index 1f7b6f95d32..e44734f84ba 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFormatOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFormatOperatorConversion.java @@ -47,7 +47,7 @@ public class TimeFormatOperatorConversion implements SqlOperatorConversion .operatorBuilder("TIME_FORMAT") .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .requiredOperands(1) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeShiftOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeShiftOperatorConversion.java index 25b05c40f1d..a4fd210aa4b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeShiftOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeShiftOperatorConversion.java @@ -45,7 +45,7 @@ public class TimeShiftOperatorConversion implements SqlOperatorConversion .operatorBuilder("TIME_SHIFT") .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER) .requiredOperands(3) - .returnTypeNonNull(SqlTypeName.TIMESTAMP) + .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimestampToMillisOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimestampToMillisOperatorConversion.java index ae4565579fb..ece14e2dd63 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimestampToMillisOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimestampToMillisOperatorConversion.java @@ -39,7 +39,7 @@ public class TimestampToMillisOperatorConversion implements SqlOperatorConversio private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("TIMESTAMP_TO_MILLIS") .operandTypes(SqlTypeFamily.TIMESTAMP) - .returnTypeNonNull(SqlTypeName.BIGINT) + .returnTypeCascadeNullable(SqlTypeName.BIGINT) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index f724cd0797e..5014a35cca7 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -17562,4 +17562,53 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .build()), ImmutableList.of(new Object[]{6L})); } + + @Test + public void testExpressionCounts() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT\n" + + " COUNT(reverse(dim2)),\n" + + " COUNT(left(dim2, 5)),\n" + + " COUNT(strpos(dim2, 'a'))\n" + + "FROM druid.numfoo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .virtualColumns( + expressionVirtualColumn("v0", "reverse(\"dim2\")", ValueType.STRING), + expressionVirtualColumn("v1", "left(\"dim2\",5)", ValueType.STRING), + expressionVirtualColumn("v2", "(strpos(\"dim2\",'a') + 1)", ValueType.LONG) + ) + .aggregators( + aggregators( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + not(selector("v0", null, null)) + ), + new FilteredAggregatorFactory( + new CountAggregatorFactory("a1"), + not(selector("v1", null, null)) + ), + new FilteredAggregatorFactory( + new CountAggregatorFactory("a2"), + not(selector("v2", null, null)) + ) + ) + ) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + useDefault + // in default mode strpos is 6 because the '+ 1' of the expression (no null numbers in + // default mode so is 0 + 1 for null rows) + ? new Object[]{3L, 3L, 6L} + : new Object[]{4L, 4L, 4L} + ) + ); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java index 0268bb636d8..5f70dc5902a 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java @@ -31,12 +31,14 @@ import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperandCountRange; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlTypeFactoryImpl; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql.validate.SqlValidatorScope; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.sql.calcite.expression.OperatorConversions.DefaultOperandTypeChecker; +import org.apache.druid.sql.calcite.planner.DruidTypeSystem; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -275,6 +277,69 @@ public class OperatorConversionsTest ); } + @Test + public void testNullForNullableOperandNonNullOutput() + { + SqlFunction function = OperatorConversions + .operatorBuilder("testNullForNullableNonnull") + .operandTypes(SqlTypeFamily.CHARACTER) + .requiredOperands(1) + .returnTypeNonNull(SqlTypeName.CHAR) + .build(); + SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); + SqlCallBinding binding = mockCallBinding( + function, + ImmutableList.of( + new OperandSpec(SqlTypeName.CHAR, false, true) + ) + ); + Assert.assertTrue(typeChecker.checkOperandTypes(binding, true)); + RelDataType returnType = function.getReturnTypeInference().inferReturnType(binding); + Assert.assertFalse(returnType.isNullable()); + } + + @Test + public void testNullForNullableOperandCascadeNullOutput() + { + SqlFunction function = OperatorConversions + .operatorBuilder("testNullForNullableCascade") + .operandTypes(SqlTypeFamily.CHARACTER) + .requiredOperands(1) + .returnTypeCascadeNullable(SqlTypeName.CHAR) + .build(); + SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); + SqlCallBinding binding = mockCallBinding( + function, + ImmutableList.of( + new OperandSpec(SqlTypeName.CHAR, false, true) + ) + ); + Assert.assertTrue(typeChecker.checkOperandTypes(binding, true)); + RelDataType returnType = function.getReturnTypeInference().inferReturnType(binding); + Assert.assertTrue(returnType.isNullable()); + } + + @Test + public void testNullForNullableOperandAlwaysNullableOutput() + { + SqlFunction function = OperatorConversions + .operatorBuilder("testNullForNullableNonnull") + .operandTypes(SqlTypeFamily.CHARACTER) + .requiredOperands(1) + .returnTypeNullable(SqlTypeName.CHAR) + .build(); + SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); + SqlCallBinding binding = mockCallBinding( + function, + ImmutableList.of( + new OperandSpec(SqlTypeName.CHAR, false, false) + ) + ); + Assert.assertTrue(typeChecker.checkOperandTypes(binding, true)); + RelDataType returnType = function.getReturnTypeInference().inferReturnType(binding); + Assert.assertTrue(returnType.isNullable()); + } + @Test public void testNullForNonNullableOperand() { @@ -359,6 +424,7 @@ public class OperatorConversionsTest ) { SqlValidator validator = Mockito.mock(SqlValidator.class); + Mockito.when(validator.getTypeFactory()).thenReturn(new SqlTypeFactoryImpl(DruidTypeSystem.INSTANCE)); List operands = new ArrayList<>(actualOperands.size()); for (OperandSpec operand : actualOperands) { final SqlNode node; @@ -368,6 +434,12 @@ public class OperatorConversionsTest node = Mockito.mock(SqlNode.class); } RelDataType relDataType = Mockito.mock(RelDataType.class); + + if (operand.isNullable) { + Mockito.when(relDataType.isNullable()).thenReturn(true); + } else { + Mockito.when(relDataType.isNullable()).thenReturn(false); + } Mockito.when(validator.deriveType(ArgumentMatchers.any(), ArgumentMatchers.eq(node))) .thenReturn(relDataType); Mockito.when(relDataType.getSqlTypeName()).thenReturn(operand.type); @@ -394,11 +466,18 @@ public class OperatorConversionsTest { private final SqlTypeName type; private final boolean isLiteral; + private final boolean isNullable; private OperandSpec(SqlTypeName type, boolean isLiteral) + { + this(type, isLiteral, type == SqlTypeName.NULL); + } + + private OperandSpec(SqlTypeName type, boolean isLiteral, boolean isNullable) { this.type = type; this.isLiteral = isLiteral; + this.isNullable = isNullable; } } }