diff --git a/processing/src/main/java/org/apache/druid/math/expr/Function.java b/processing/src/main/java/org/apache/druid/math/expr/Function.java index 365a53d3362..2c8a26759f3 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/Function.java +++ b/processing/src/main/java/org/apache/druid/math/expr/Function.java @@ -1173,14 +1173,17 @@ public interface Function extends NamedFunction ); } + @Override + public boolean canVectorize(Expr.InputBindingInspector inspector, List args) + { + return false; + } + @Override protected ExprEval eval(final long x, final long y) { if (y == 0) { - if (x != 0) { - return ExprEval.ofLong(null); - } - return ExprEval.ofLong(0); + return ExprEval.ofLong(null); } return ExprEval.ofLong(x / y); } diff --git a/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java index 670dbe93e1f..0338efa4664 100644 --- a/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java +++ b/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java @@ -857,11 +857,14 @@ public class FunctionTest extends InitializedNullHandlingTest assertExpr("safe_divide(4.5, 2)", 2.25); assertExpr("safe_divide(3, 0)", null); assertExpr("safe_divide(1, 0.0)", null); - // NaN and Infinity cases + // NaN, Infinity and other weird cases assertExpr("safe_divide(NaN, 0.0)", null); assertExpr("safe_divide(0, NaN)", 0.0); - assertExpr("safe_divide(0, POSITIVE_INFINITY)", NullHandling.defaultLongValue()); - assertExpr("safe_divide(POSITIVE_INFINITY,0)", NullHandling.defaultLongValue()); + assertExpr("safe_divide(0, maxLong)", 0L); + assertExpr("safe_divide(maxLong,0)", null); + assertExpr("safe_divide(0.0, inf)", 0.0); + assertExpr("safe_divide(0.0, -inf)", -0.0); + assertExpr("safe_divide(0,0)", null); } @Test diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SafeDivideOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SafeDivideOperatorConversion.java index dd09feefadd..13c715316bb 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SafeDivideOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SafeDivideOperatorConversion.java @@ -22,8 +22,9 @@ package org.apache.druid.sql.calcite.expression.builtin; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeTransforms; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.math.expr.Function; import org.apache.druid.sql.calcite.expression.DirectOperatorConversion; @@ -33,9 +34,10 @@ public class SafeDivideOperatorConversion extends DirectOperatorConversion { private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder(StringUtils.toUpperCase(Function.SafeDivide.NAME)) - .operandTypeChecker(OperandTypes.ANY_NUMERIC) - .returnTypeInference(ReturnTypes.QUOTIENT_NULLABLE) + .operandTypes(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC) + .returnTypeInference(ReturnTypes.LEAST_RESTRICTIVE.andThen(SqlTypeTransforms.FORCE_NULLABLE)) .functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION) + .requiredOperandCount(2) .build(); public SafeDivideOperatorConversion() diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index c4a57774ca0..2ce2fff0431 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -571,6 +571,31 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ); } + @Test + public void testSafeDivide() + { + skipVectorize(); + cannotVectorize(); + final Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + + testQuery( + "select count(*) c from foo where ((floor(safe_divide(cast(cast(m1 as char) as bigint), 2))) = 0)", + context, + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns(expressionVirtualColumn("v0", "floor(safe_divide(CAST(CAST(\"m1\", 'STRING'), 'LONG'),2))", ColumnType.LONG)) + .filters(equality("v0", 0L, ColumnType.LONG)) + .granularity(Granularities.ALL) + .aggregators(new CountAggregatorFactory("a0")) + .context(context) + .build() + ), + ImmutableList.of(new Object[]{1L}) + ); + } + @Test public void testGroupByLimitWrappingOrderByAgg() { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java index 2cf8296dfc0..6569b52a90a 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java @@ -482,6 +482,52 @@ public class CalciteSelectQueryTest extends BaseCalciteQueryTest ); } + @Test + public void testSafeDivideWithoutTable() + { + skipVectorize(); + cannotVectorize(); + final Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + + testQuery( + "select SAFE_DIVIDE(0, 0), SAFE_DIVIDE(1,0), SAFE_DIVIDE(10,2.5), " + + " SAFE_DIVIDE(10.5,3.5), SAFE_DIVIDE(10.5,3), SAFE_DIVIDE(10,2)", + context, + ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource( + InlineDataSource.fromIterable( + ImmutableList.of( + new Object[]{0L} + ), + RowSignature.builder().add("ZERO", ColumnType.LONG).build() + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("v0", "v1", "v2", "v3", "v4") + .virtualColumns( + expressionVirtualColumn("v0", NullHandling.sqlCompatible() ? "null" : "0", ColumnType.LONG), + expressionVirtualColumn("v1", "4.0", ColumnType.DOUBLE), + expressionVirtualColumn("v2", "3.0", ColumnType.DOUBLE), + expressionVirtualColumn("v3", "3.5", ColumnType.DOUBLE), + expressionVirtualColumn("v4", "5", ColumnType.LONG) + ) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .context(context) + .build() + ), + ImmutableList.of(new Object[]{ + NullHandling.sqlCompatible() ? null : 0, + NullHandling.sqlCompatible() ? null : 0, + 4.0D, + 3.0D, + 3.5D, + 5 + }) + ); + } + @Test public void testSafeDivideExpressions() { @@ -498,8 +544,8 @@ public class CalciteSelectQueryTest extends BaseCalciteQueryTest } else { expected = ImmutableList.of( new Object[]{null, null, null, 7.0F}, - new Object[]{1.0F, 1L, 1.0, 3253230.0F}, - new Object[]{0.0F, 0L, 0.0, 0.0F}, + new Object[]{1.0F, 1L, 1.0D, 3253230.0F}, + new Object[]{0.0F, null, 0.0D, 0.0F}, new Object[]{null, null, null, null}, new Object[]{null, null, null, null}, new Object[]{null, null, null, null}