Fixes for safe_divide with vectorize and datatypes (#15839)

* Fix for save_divide with vectorize

* More fixes

* Update to use expr.eval(null) for both cases when denominator is 0
This commit is contained in:
Soumyava 2024-02-08 01:10:42 -08:00 committed by GitHub
parent 1a5b57df84
commit f3996b96ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 91 additions and 12 deletions

View File

@ -1173,15 +1173,18 @@ public interface Function extends NamedFunction
); );
} }
@Override
public boolean canVectorize(Expr.InputBindingInspector inspector, List<Expr> args)
{
return false;
}
@Override @Override
protected ExprEval eval(final long x, final long y) protected ExprEval eval(final long x, final long y)
{ {
if (y == 0) { if (y == 0) {
if (x != 0) {
return ExprEval.ofLong(null); return ExprEval.ofLong(null);
} }
return ExprEval.ofLong(0);
}
return ExprEval.ofLong(x / y); return ExprEval.ofLong(x / y);
} }

View File

@ -857,11 +857,14 @@ public class FunctionTest extends InitializedNullHandlingTest
assertExpr("safe_divide(4.5, 2)", 2.25); assertExpr("safe_divide(4.5, 2)", 2.25);
assertExpr("safe_divide(3, 0)", null); assertExpr("safe_divide(3, 0)", null);
assertExpr("safe_divide(1, 0.0)", null); assertExpr("safe_divide(1, 0.0)", null);
// NaN and Infinity cases // NaN, Infinity and other weird cases
assertExpr("safe_divide(NaN, 0.0)", null); assertExpr("safe_divide(NaN, 0.0)", null);
assertExpr("safe_divide(0, NaN)", 0.0); assertExpr("safe_divide(0, NaN)", 0.0);
assertExpr("safe_divide(0, POSITIVE_INFINITY)", NullHandling.defaultLongValue()); assertExpr("safe_divide(0, maxLong)", 0L);
assertExpr("safe_divide(POSITIVE_INFINITY,0)", NullHandling.defaultLongValue()); assertExpr("safe_divide(maxLong,0)", null);
assertExpr("safe_divide(0.0, inf)", 0.0);
assertExpr("safe_divide(0.0, -inf)", -0.0);
assertExpr("safe_divide(0,0)", null);
} }
@Test @Test

View File

@ -22,8 +22,9 @@ package org.apache.druid.sql.calcite.expression.builtin;
import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeTransforms;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.Function; import org.apache.druid.math.expr.Function;
import org.apache.druid.sql.calcite.expression.DirectOperatorConversion; import org.apache.druid.sql.calcite.expression.DirectOperatorConversion;
@ -33,9 +34,10 @@ public class SafeDivideOperatorConversion extends DirectOperatorConversion
{ {
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder(StringUtils.toUpperCase(Function.SafeDivide.NAME)) .operatorBuilder(StringUtils.toUpperCase(Function.SafeDivide.NAME))
.operandTypeChecker(OperandTypes.ANY_NUMERIC) .operandTypes(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)
.returnTypeInference(ReturnTypes.QUOTIENT_NULLABLE) .returnTypeInference(ReturnTypes.LEAST_RESTRICTIVE.andThen(SqlTypeTransforms.FORCE_NULLABLE))
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION) .functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
.requiredOperandCount(2)
.build(); .build();
public SafeDivideOperatorConversion() public SafeDivideOperatorConversion()

View File

@ -571,6 +571,31 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
); );
} }
@Test
public void testSafeDivide()
{
skipVectorize();
cannotVectorize();
final Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
testQuery(
"select count(*) c from foo where ((floor(safe_divide(cast(cast(m1 as char) as bigint), 2))) = 0)",
context,
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.virtualColumns(expressionVirtualColumn("v0", "floor(safe_divide(CAST(CAST(\"m1\", 'STRING'), 'LONG'),2))", ColumnType.LONG))
.filters(equality("v0", 0L, ColumnType.LONG))
.granularity(Granularities.ALL)
.aggregators(new CountAggregatorFactory("a0"))
.context(context)
.build()
),
ImmutableList.of(new Object[]{1L})
);
}
@Test @Test
public void testGroupByLimitWrappingOrderByAgg() public void testGroupByLimitWrappingOrderByAgg()
{ {

View File

@ -482,6 +482,52 @@ public class CalciteSelectQueryTest extends BaseCalciteQueryTest
); );
} }
@Test
public void testSafeDivideWithoutTable()
{
skipVectorize();
cannotVectorize();
final Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
testQuery(
"select SAFE_DIVIDE(0, 0), SAFE_DIVIDE(1,0), SAFE_DIVIDE(10,2.5), "
+ " SAFE_DIVIDE(10.5,3.5), SAFE_DIVIDE(10.5,3), SAFE_DIVIDE(10,2)",
context,
ImmutableList.of(
Druids.newScanQueryBuilder()
.dataSource(
InlineDataSource.fromIterable(
ImmutableList.of(
new Object[]{0L}
),
RowSignature.builder().add("ZERO", ColumnType.LONG).build()
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("v0", "v1", "v2", "v3", "v4")
.virtualColumns(
expressionVirtualColumn("v0", NullHandling.sqlCompatible() ? "null" : "0", ColumnType.LONG),
expressionVirtualColumn("v1", "4.0", ColumnType.DOUBLE),
expressionVirtualColumn("v2", "3.0", ColumnType.DOUBLE),
expressionVirtualColumn("v3", "3.5", ColumnType.DOUBLE),
expressionVirtualColumn("v4", "5", ColumnType.LONG)
)
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.legacy(false)
.context(context)
.build()
),
ImmutableList.of(new Object[]{
NullHandling.sqlCompatible() ? null : 0,
NullHandling.sqlCompatible() ? null : 0,
4.0D,
3.0D,
3.5D,
5
})
);
}
@Test @Test
public void testSafeDivideExpressions() public void testSafeDivideExpressions()
{ {
@ -498,8 +544,8 @@ public class CalciteSelectQueryTest extends BaseCalciteQueryTest
} else { } else {
expected = ImmutableList.of( expected = ImmutableList.of(
new Object[]{null, null, null, 7.0F}, new Object[]{null, null, null, 7.0F},
new Object[]{1.0F, 1L, 1.0, 3253230.0F}, new Object[]{1.0F, 1L, 1.0D, 3253230.0F},
new Object[]{0.0F, 0L, 0.0, 0.0F}, new Object[]{0.0F, null, 0.0D, 0.0F},
new Object[]{null, null, null, null}, new Object[]{null, null, null, null},
new Object[]{null, null, null, null}, new Object[]{null, null, null, null},
new Object[]{null, null, null, null} new Object[]{null, null, null, null}