fix bugs related to SQL type inference return type nullability (#11327)

* fix a bunch of type inference nullability bugs

* fixes

* style

* fix test

* fix concat
This commit is contained in:
Clint Wylie 2021-06-15 12:26:59 -07:00 committed by GitHub
parent 920aa414ca
commit bfbd7ec432
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 232 additions and 54 deletions

View File

@ -2721,6 +2721,9 @@ public interface Function
@Override
protected ExprEval eval(String x, int y)
{
if (x == null) {
return ExprEval.of(null);
}
return ExprEval.of(y < 1 ? NullHandling.defaultStringValue() : StringUtils.repeat(x, y));
}
}

View File

@ -596,6 +596,16 @@ public class FunctionTest extends InitializedNullHandlingTest
assertExpr("bitwiseConvertDoubleToLongBits(null)", null);
}
@Test
public void testRepeat()
{
assertExpr("repeat('hello', 2)", "hellohello");
assertExpr("repeat('hello', -1)", null);
assertExpr("repeat(null, 10)", null);
assertExpr("repeat(nonexistent, 10)", null);
}
private void assertExpr(final String expression, @Nullable final Object expectedResult)
{
final Expr expr = Parser.parse(expression, ExprMacroTable.nil());

View File

@ -303,7 +303,7 @@ columns in this mode are not nullable; any null or missing values will be treate
In SQL compatible mode (`false`), NULLs are treated more closely to the SQL standard. The property affects both storage
and querying, so for correct behavior, it should be set on all Druid service types to be available at both ingestion
time and query time. There is some overhead associated with the ability to handle NULLs; see
the [segment internals](../design/segments.md#sql-compatible-null-handling)documentation for more details.
the [segment internals](../design/segments.md#sql-compatible-null-handling) documentation for more details.
## Aggregation functions

View File

@ -99,7 +99,11 @@ public class TimestampShiftExprMacro implements ExprMacroTable.ExprMacro
@Override
public ExprEval eval(final ObjectBinding bindings)
{
return ExprEval.of(chronology.add(period, args.get(0).eval(bindings).asLong(), step));
ExprEval timestamp = args.get(0).eval(bindings);
if (timestamp.isNumericNull()) {
return ExprEval.of(null);
}
return ExprEval.of(chronology.add(period, timestamp.asLong(), step));
}
@Override
@ -128,10 +132,14 @@ public class TimestampShiftExprMacro implements ExprMacroTable.ExprMacro
@Override
public ExprEval eval(final ObjectBinding bindings)
{
ExprEval timestamp = args.get(0).eval(bindings);
if (timestamp.isNumericNull()) {
return ExprEval.of(null);
}
final Period period = getPeriod(args, bindings);
final Chronology chronology = getTimeZone(args, bindings);
final int step = getStep(args, bindings);
return ExprEval.of(chronology.add(period, args.get(0).eval(bindings).asLong(), step));
return ExprEval.of(chronology.add(period, timestamp.asLong(), step));
}
@Override

View File

@ -20,6 +20,7 @@
package org.apache.druid.query.expression;
import com.google.common.collect.ImmutableList;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.math.expr.Expr;
@ -219,6 +220,24 @@ public class TimestampShiftMacroTest extends MacroTestBase
);
}
@Test
public void testNull()
{
Expr expr = apply(
ImmutableList.of(
ExprEval.ofLong(null).toExpr(),
ExprEval.of("P1M").toExpr(),
ExprEval.of(1L).toExpr()
)
);
if (NullHandling.replaceWithDefault()) {
Assert.assertEquals(2678400000L, expr.eval(ExprUtils.nilBindings()).value());
} else {
Assert.assertNull(expr.eval(ExprUtils.nilBindings()).value());
}
}
private static class NotLiteralExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
{
NotLiteralExpr(String name)

View File

@ -48,6 +48,7 @@ import org.apache.calcite.sql.type.SqlOperandTypeInference;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeTransforms;
import org.apache.calcite.util.Static;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
@ -255,11 +256,12 @@ public class OperatorConversions
}
/**
* Sets the return type of the operator to "typeName", marked as non-nullable.
* Sets the return type of the operator to "typeName", marked as non-nullable. If this method is used it implies the
* operator should never, ever, return null.
*
* One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or
* {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods
* cannot be mixed; you must call exactly one.
* One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)}
* {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before
* calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeNonNull(final SqlTypeName typeName)
{
@ -274,9 +276,9 @@ public class OperatorConversions
/**
* Sets the return type of the operator to "typeName", marked as nullable.
*
* One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or
* {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods
* cannot be mixed; you must call exactly one.
* One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)}
* {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before
* calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeNullable(final SqlTypeName typeName)
{
@ -287,12 +289,27 @@ public class OperatorConversions
);
return this;
}
/**
* Sets the return type of the operator to "typeName", marked as nullable if any of its operands are nullable.
*
* One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)}
* {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before
* calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeCascadeNullable(final SqlTypeName typeName)
{
Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times");
this.returnTypeInference = ReturnTypes.cascade(ReturnTypes.explicit(typeName), SqlTypeTransforms.TO_NULLABLE);
return this;
}
/**
* Sets the return type of the operator to an array type with elements of "typeName", marked as nullable.
*
* One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or
* {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods
* cannot be mixed; you must call exactly one.
* One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)}
* {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before
* calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeNullableArray(final SqlTypeName elementTypeName)
{
@ -308,9 +325,9 @@ public class OperatorConversions
/**
* Provides customized return type inference logic.
*
* One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or
* {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods
* cannot be mixed; you must call exactly one.
* One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)}
* {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before
* calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeInference(final SqlReturnTypeInference returnTypeInference)
{

View File

@ -43,7 +43,7 @@ public class ArrayLengthOperatorConversion implements SqlOperatorConversion
)
)
.functionCategory(SqlFunctionCategory.STRING)
.returnTypeNonNull(SqlTypeName.INTEGER)
.returnTypeCascadeNullable(SqlTypeName.INTEGER)
.build();
@Override

View File

@ -47,7 +47,7 @@ public class ArrayOffsetOfOperatorConversion implements SqlOperatorConversion
)
)
.functionCategory(SqlFunctionCategory.STRING)
.returnTypeNonNull(SqlTypeName.INTEGER)
.returnTypeNullable(SqlTypeName.INTEGER)
.build();
@Override

View File

@ -47,7 +47,7 @@ public class ArrayOrdinalOfOperatorConversion implements SqlOperatorConversion
)
)
.functionCategory(SqlFunctionCategory.STRING)
.returnTypeNonNull(SqlTypeName.INTEGER)
.returnTypeCascadeNullable(SqlTypeName.INTEGER)
.build();
@Override

View File

@ -47,7 +47,7 @@ public class ArrayToStringOperatorConversion implements SqlOperatorConversion
)
)
.functionCategory(SqlFunctionCategory.STRING)
.returnTypeNonNull(SqlTypeName.VARCHAR)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.build();
@Override

View File

@ -37,7 +37,7 @@ public class BTrimOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("BTRIM")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.returnTypeNonNull(SqlTypeName.VARCHAR)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(1)
.build();

View File

@ -22,29 +22,22 @@ package org.apache.druid.sql.calcite.expression.builtin;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.expression.SqlOperatorConversion;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
public class ConcatOperatorConversion implements SqlOperatorConversion
{
private static final SqlFunction SQL_FUNCTION = new SqlFunction(
"CONCAT",
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(
factory -> Calcites.createSqlType(factory, SqlTypeName.VARCHAR)
),
null,
OperandTypes.SAME_VARIADIC,
SqlFunctionCategory.STRING
);
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("CONCAT")
.operandTypeChecker(OperandTypes.SAME_VARIADIC)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING)
.build();
@Override
public SqlFunction calciteOperator()

View File

@ -67,7 +67,7 @@ public class DateTruncOperatorConversion implements SqlOperatorConversion
.operatorBuilder("DATE_TRUNC")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP)
.requiredOperands(2)
.returnTypeNonNull(SqlTypeName.TIMESTAMP)
.returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -37,7 +37,7 @@ public class LPadOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("LPAD")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
.returnTypeNonNull(SqlTypeName.VARCHAR)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(2)
.build();

View File

@ -37,7 +37,7 @@ public class LTrimOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("LTRIM")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.returnTypeNonNull(SqlTypeName.VARCHAR)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(1)
.build();

View File

@ -39,7 +39,7 @@ public class LeftOperatorConversion implements SqlOperatorConversion
.operatorBuilder("LEFT")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER)
.functionCategory(SqlFunctionCategory.STRING)
.returnTypeNonNull(SqlTypeName.VARCHAR)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.build();
@Override

View File

@ -39,7 +39,7 @@ public class MillisToTimestampOperatorConversion implements SqlOperatorConversio
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("MILLIS_TO_TIMESTAMP")
.operandTypes(SqlTypeFamily.EXACT_NUMERIC)
.returnTypeNonNull(SqlTypeName.TIMESTAMP)
.returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -38,7 +38,7 @@ public class ParseLongOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder(NAME)
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER)
.returnTypeNonNull(SqlTypeName.BIGINT)
.returnTypeCascadeNullable(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(1)
.build();

View File

@ -37,7 +37,7 @@ public class RPadOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("RPAD")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
.returnTypeNonNull(SqlTypeName.VARCHAR)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(2)
.build();

View File

@ -37,7 +37,7 @@ public class RTrimOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("RTRIM")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.returnTypeNonNull(SqlTypeName.VARCHAR)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(1)
.build();

View File

@ -39,7 +39,7 @@ public class RepeatOperatorConversion implements SqlOperatorConversion
.operatorBuilder("REPEAT")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER)
.functionCategory(SqlFunctionCategory.STRING)
.returnTypeNonNull(SqlTypeName.VARCHAR)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.build();
@Override

View File

@ -37,7 +37,7 @@ public class ReverseOperatorConversion implements SqlOperatorConversion
.operatorBuilder("REVERSE")
.operandTypes(SqlTypeFamily.CHARACTER)
.functionCategory(SqlFunctionCategory.STRING)
.returnTypeNonNull(SqlTypeName.VARCHAR)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.build();
@Override

View File

@ -39,7 +39,7 @@ public class RightOperatorConversion implements SqlOperatorConversion
.operatorBuilder("RIGHT")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER)
.functionCategory(SqlFunctionCategory.STRING)
.returnTypeNonNull(SqlTypeName.VARCHAR)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.build();
@Override

View File

@ -42,7 +42,7 @@ public class StringFormatOperatorConversion implements SqlOperatorConversion
.operatorBuilder("STRING_FORMAT")
.operandTypeChecker(new StringFormatOperandTypeChecker())
.functionCategory(SqlFunctionCategory.STRING)
.returnTypeNonNull(SqlTypeName.VARCHAR)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.build();
@Override

View File

@ -38,7 +38,7 @@ public class StrposOperatorConversion implements SqlOperatorConversion
.operatorBuilder("STRPOS")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.functionCategory(SqlFunctionCategory.STRING)
.returnTypeNonNull(SqlTypeName.INTEGER)
.returnTypeCascadeNullable(SqlTypeName.INTEGER)
.build();
@Override

View File

@ -36,7 +36,7 @@ public class TextcatOperatorConversion implements SqlOperatorConversion
.operatorBuilder("textcat")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperands(2)
.returnTypeNonNull(SqlTypeName.VARCHAR)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING)
.build();

View File

@ -41,7 +41,7 @@ public class TimeCeilOperatorConversion implements SqlOperatorConversion
.operatorBuilder("TIME_CEIL")
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER)
.requiredOperands(2)
.returnTypeNonNull(SqlTypeName.TIMESTAMP)
.returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -44,7 +44,7 @@ public class TimeExtractOperatorConversion implements SqlOperatorConversion
.operatorBuilder("TIME_EXTRACT")
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperands(2)
.returnTypeNonNull(SqlTypeName.BIGINT)
.returnTypeCascadeNullable(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -56,7 +56,7 @@ public class TimeFloorOperatorConversion implements SqlOperatorConversion
.operatorBuilder("TIME_FLOOR")
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER)
.requiredOperands(2)
.returnTypeNonNull(SqlTypeName.TIMESTAMP)
.returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -47,7 +47,7 @@ public class TimeFormatOperatorConversion implements SqlOperatorConversion
.operatorBuilder("TIME_FORMAT")
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperands(1)
.returnTypeNonNull(SqlTypeName.VARCHAR)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -45,7 +45,7 @@ public class TimeShiftOperatorConversion implements SqlOperatorConversion
.operatorBuilder("TIME_SHIFT")
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
.requiredOperands(3)
.returnTypeNonNull(SqlTypeName.TIMESTAMP)
.returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -39,7 +39,7 @@ public class TimestampToMillisOperatorConversion implements SqlOperatorConversio
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("TIMESTAMP_TO_MILLIS")
.operandTypes(SqlTypeFamily.TIMESTAMP)
.returnTypeNonNull(SqlTypeName.BIGINT)
.returnTypeCascadeNullable(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -17562,4 +17562,53 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.build()),
ImmutableList.of(new Object[]{6L}));
}
@Test
public void testExpressionCounts() throws Exception
{
cannotVectorize();
testQuery(
"SELECT\n"
+ " COUNT(reverse(dim2)),\n"
+ " COUNT(left(dim2, 5)),\n"
+ " COUNT(strpos(dim2, 'a'))\n"
+ "FROM druid.numfoo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.virtualColumns(
expressionVirtualColumn("v0", "reverse(\"dim2\")", ValueType.STRING),
expressionVirtualColumn("v1", "left(\"dim2\",5)", ValueType.STRING),
expressionVirtualColumn("v2", "(strpos(\"dim2\",'a') + 1)", ValueType.LONG)
)
.aggregators(
aggregators(
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
not(selector("v0", null, null))
),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a1"),
not(selector("v1", null, null))
),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a2"),
not(selector("v2", null, null))
)
)
)
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
useDefault
// in default mode strpos is 6 because the '+ 1' of the expression (no null numbers in
// default mode so is 0 + 1 for null rows)
? new Object[]{3L, 3L, 6L}
: new Object[]{4L, 4L, 4L}
)
);
}
}

View File

@ -31,12 +31,14 @@ import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.sql.calcite.expression.OperatorConversions.DefaultOperandTypeChecker;
import org.apache.druid.sql.calcite.planner.DruidTypeSystem;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
@ -275,6 +277,69 @@ public class OperatorConversionsTest
);
}
@Test
public void testNullForNullableOperandNonNullOutput()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNullableNonnull")
.operandTypes(SqlTypeFamily.CHARACTER)
.requiredOperands(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
SqlCallBinding binding = mockCallBinding(
function,
ImmutableList.of(
new OperandSpec(SqlTypeName.CHAR, false, true)
)
);
Assert.assertTrue(typeChecker.checkOperandTypes(binding, true));
RelDataType returnType = function.getReturnTypeInference().inferReturnType(binding);
Assert.assertFalse(returnType.isNullable());
}
@Test
public void testNullForNullableOperandCascadeNullOutput()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNullableCascade")
.operandTypes(SqlTypeFamily.CHARACTER)
.requiredOperands(1)
.returnTypeCascadeNullable(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
SqlCallBinding binding = mockCallBinding(
function,
ImmutableList.of(
new OperandSpec(SqlTypeName.CHAR, false, true)
)
);
Assert.assertTrue(typeChecker.checkOperandTypes(binding, true));
RelDataType returnType = function.getReturnTypeInference().inferReturnType(binding);
Assert.assertTrue(returnType.isNullable());
}
@Test
public void testNullForNullableOperandAlwaysNullableOutput()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNullableNonnull")
.operandTypes(SqlTypeFamily.CHARACTER)
.requiredOperands(1)
.returnTypeNullable(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
SqlCallBinding binding = mockCallBinding(
function,
ImmutableList.of(
new OperandSpec(SqlTypeName.CHAR, false, false)
)
);
Assert.assertTrue(typeChecker.checkOperandTypes(binding, true));
RelDataType returnType = function.getReturnTypeInference().inferReturnType(binding);
Assert.assertTrue(returnType.isNullable());
}
@Test
public void testNullForNonNullableOperand()
{
@ -359,6 +424,7 @@ public class OperatorConversionsTest
)
{
SqlValidator validator = Mockito.mock(SqlValidator.class);
Mockito.when(validator.getTypeFactory()).thenReturn(new SqlTypeFactoryImpl(DruidTypeSystem.INSTANCE));
List<SqlNode> operands = new ArrayList<>(actualOperands.size());
for (OperandSpec operand : actualOperands) {
final SqlNode node;
@ -368,6 +434,12 @@ public class OperatorConversionsTest
node = Mockito.mock(SqlNode.class);
}
RelDataType relDataType = Mockito.mock(RelDataType.class);
if (operand.isNullable) {
Mockito.when(relDataType.isNullable()).thenReturn(true);
} else {
Mockito.when(relDataType.isNullable()).thenReturn(false);
}
Mockito.when(validator.deriveType(ArgumentMatchers.any(), ArgumentMatchers.eq(node)))
.thenReturn(relDataType);
Mockito.when(relDataType.getSqlTypeName()).thenReturn(operand.type);
@ -394,11 +466,18 @@ public class OperatorConversionsTest
{
private final SqlTypeName type;
private final boolean isLiteral;
private final boolean isNullable;
private OperandSpec(SqlTypeName type, boolean isLiteral)
{
this(type, isLiteral, type == SqlTypeName.NULL);
}
private OperandSpec(SqlTypeName type, boolean isLiteral, boolean isNullable)
{
this.type = type;
this.isLiteral = isLiteral;
this.isNullable = isNullable;
}
}
}