Improve the DruidRexExecutor w.r.t handling of numeric arrays (#11968)

When reducing arrays, especially numeric arrays, DruidRexExecutor doesn't convert the values from the ExprResult's type to BigDecimal, which causes makeLiteral to cast the values. Also, if NaN or infinite values are present in the array, the error is a generic NumberFormatException. For example:

SELECT ARRAY[1.11, 2.22] returns [1, 2]
SELECT SQRT(-1) throws a generic NumberFormatException instead of IAE

This PR introduces a change that converts the numeric values to BigDecimal, since Calcite's library understands that type directly and doesn't perform casts on it.
This commit is contained in:
Laksh Singla 2021-11-23 11:40:59 +05:30 committed by GitHub
parent ed0606db69
commit b5a25f24f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 89 additions and 21 deletions

View File

@ -39,6 +39,7 @@ import org.apache.druid.sql.calcite.table.RowSignatures;
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
/**
* A Calcite {@code RexExecutor} that reduces Calcite expressions by evaluating them using Druid's own built-in
@ -145,7 +146,31 @@ public class DruidRexExecutor implements RexExecutor
}
} else if (sqlTypeName == SqlTypeName.ARRAY) {
assert exprResult.isArray();
literal = rexBuilder.makeLiteral(Arrays.asList(exprResult.asArray()), constExp.getType(), true);
if (SqlTypeName.NUMERIC_TYPES.contains(constExp.getType().getComponentType().getSqlTypeName())) {
if (exprResult.type().getElementType().is(ExprType.LONG)) {
List<BigDecimal> resultAsBigDecimalList = Arrays.stream(exprResult.asLongArray())
.map(BigDecimal::valueOf)
.collect(Collectors.toList());
literal = rexBuilder.makeLiteral(resultAsBigDecimalList, constExp.getType(), true);
} else {
List<BigDecimal> resultAsBigDecimalList = Arrays.stream(exprResult.asDoubleArray()).map(
doubleVal -> {
if (Double.isNaN(doubleVal) || Double.isInfinite(doubleVal)) {
String expression = druidExpression.getExpression();
throw new IAE(
"'%s' contains an element that evaluates to '%s' which is not supported in SQL. You can either cast the element in the array to bigint or char or change the expression itself",
expression,
Double.toString(doubleVal)
);
}
return BigDecimal.valueOf(doubleVal);
}
).collect(Collectors.toList());
literal = rexBuilder.makeLiteral(resultAsBigDecimalList, constExp.getType(), true);
}
} else {
literal = rexBuilder.makeLiteral(Arrays.asList(exprResult.asArray()), constExp.getType(), true);
}
} else if (sqlTypeName == SqlTypeName.OTHER && constExp.getType() instanceof RowSignatures.ComplexSqlType) {
// complex constant is not reducible, so just leave it as an expression
literal = constExp;

View File

@ -194,8 +194,7 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest
@Test
public void testSomeArrayFunctionsWithScanQuery() throws Exception
{
// array constructor turns decimals into ints for some reason, this needs fixed in the future
// also, yes these outputs are strange sometimes, arrays are in a partial state of existence so end up a bit
// Yes these outputs are strange sometimes, arrays are in a partial state of existence so end up a bit
// stringy for now this is because virtual column selectors are coercing values back to stringish so that
// multi-valued string dimensions can be grouped on.
List<Object[]> expectedResults;
@ -211,13 +210,13 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest
0.0,
"[\"a\",\"b\",\"c\"]",
"[1,2,3]",
"[1,2,4]",
"[1.9,2.2,4.3]",
"[\"a\",\"b\",\"foo\"]",
"[\"foo\",\"a\"]",
"[1,2,7]",
"[0,1,2]",
"[1,2,1]",
"[0,1,2]",
"[1.2,2.2,1.0]",
"[0.0,1.1,2.2]",
"[\"a\",\"a\",\"b\"]",
"[7,0]",
"[1.0,0.0]",
@ -239,13 +238,13 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest
null,
"[\"a\",\"b\",\"c\"]",
"[1,2,3]",
"[1,2,4]",
"[1.9,2.2,4.3]",
"[\"a\",\"b\",\"foo\"]",
"[\"foo\",\"a\"]",
"[1,2,7]",
"[null,1,2]",
"[1,2,1]",
"[null,1,2]",
"[1.2,2.2,1.0]",
"[null,1.1,2.2]",
"[\"a\",\"a\",\"b\"]",
"[7,null]",
"[1.0,null]",
@ -297,13 +296,13 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest
expressionVirtualColumn("v13", "array_offset(array(\"d1\"),0)", ColumnType.STRING),
expressionVirtualColumn("v14", "array_ordinal(array(\"l1\"),1)", ColumnType.STRING),
expressionVirtualColumn("v15", "array_ordinal(array(\"d1\"),1)", ColumnType.STRING),
expressionVirtualColumn("v2", "array(1,2,4)", ColumnType.STRING),
expressionVirtualColumn("v2", "array(1.9,2.2,4.3)", ColumnType.STRING),
expressionVirtualColumn("v3", "array_append(\"dim3\",'foo')", ColumnType.STRING),
expressionVirtualColumn("v4", "array_prepend('foo',array(\"dim2\"))", ColumnType.STRING),
expressionVirtualColumn("v5", "array_append(array(1,2),\"l1\")", ColumnType.STRING),
expressionVirtualColumn("v6", "array_prepend(\"l2\",array(1,2))", ColumnType.STRING),
expressionVirtualColumn("v7", "array_append(array(1,2),\"d1\")", ColumnType.STRING),
expressionVirtualColumn("v8", "array_prepend(\"d2\",array(1,2))", ColumnType.STRING),
expressionVirtualColumn("v7", "array_append(array(1.2,2.2),\"d1\")", ColumnType.STRING),
expressionVirtualColumn("v8", "array_prepend(\"d2\",array(1.1,2.2))", ColumnType.STRING),
expressionVirtualColumn("v9", "array_concat(\"dim2\",\"dim3\")", ColumnType.STRING)
)
.columns(
@ -357,13 +356,13 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest
"[\"a\",\"b\"]",
Arrays.asList("a", "b", "c"),
Arrays.asList(1L, 2L, 3L),
Arrays.asList(1L, 2L, 4L),
Arrays.asList(1.9, 2.2, 4.3),
"[\"a\",\"b\",\"foo\"]",
Arrays.asList("foo", "a"),
Arrays.asList(1L, 2L, 7L),
Arrays.asList(0L, 1L, 2L),
Arrays.asList(1L, 2L, 1L),
Arrays.asList(0L, 1L, 2L),
Arrays.asList(1.2, 2.2, 1.0),
Arrays.asList(0.0, 1.1, 2.2),
"[\"a\",\"a\",\"b\"]",
Arrays.asList(7L, 0L),
Arrays.asList(1.0, 0.0)
@ -377,13 +376,13 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest
"[\"a\",\"b\"]",
Arrays.asList("a", "b", "c"),
Arrays.asList(1L, 2L, 3L),
Arrays.asList(1L, 2L, 4L),
Arrays.asList(1.9, 2.2, 4.3),
"[\"a\",\"b\",\"foo\"]",
Arrays.asList("foo", "a"),
Arrays.asList(1L, 2L, 7L),
Arrays.asList(null, 1L, 2L),
Arrays.asList(1L, 2L, 1L),
Arrays.asList(null, 1L, 2L),
Arrays.asList(1.2, 2.2, 1.0),
Arrays.asList(null, 1.1, 2.2),
"[\"a\",\"a\",\"b\"]",
Arrays.asList(7L, null),
Arrays.asList(1.0, null)
@ -420,13 +419,13 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest
expressionVirtualColumn("v1", "array(1,2,3)", ColumnType.STRING),
expressionVirtualColumn("v10", "array_concat(array(\"l1\"),array(\"l2\"))", ColumnType.STRING),
expressionVirtualColumn("v11", "array_concat(array(\"d1\"),array(\"d2\"))", ColumnType.STRING),
expressionVirtualColumn("v2", "array(1,2,4)", ColumnType.STRING),
expressionVirtualColumn("v2", "array(1.9,2.2,4.3)", ColumnType.STRING),
expressionVirtualColumn("v3", "array_append(\"dim3\",'foo')", ColumnType.STRING),
expressionVirtualColumn("v4", "array_prepend('foo',array(\"dim2\"))", ColumnType.STRING),
expressionVirtualColumn("v5", "array_append(array(1,2),\"l1\")", ColumnType.STRING),
expressionVirtualColumn("v6", "array_prepend(\"l2\",array(1,2))", ColumnType.STRING),
expressionVirtualColumn("v7", "array_append(array(1,2),\"d1\")", ColumnType.STRING),
expressionVirtualColumn("v8", "array_prepend(\"d2\",array(1,2))", ColumnType.STRING),
expressionVirtualColumn("v7", "array_append(array(1.2,2.2),\"d1\")", ColumnType.STRING),
expressionVirtualColumn("v8", "array_prepend(\"d2\",array(1.1,2.2))", ColumnType.STRING),
expressionVirtualColumn("v9", "array_concat(\"dim2\",\"dim3\")", ColumnType.STRING)
)
.columns(

View File

@ -32,6 +32,8 @@ import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.ArraySqlType;
import org.apache.calcite.sql.type.BasicSqlType;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.calcite.sql.type.SqlTypeFamily;
@ -136,4 +138,46 @@ public class DruidRexExecutorTest extends InitializedNullHandlingTest
)
);
}
/**
 * Builds an ARRAY literal whose element type is DECIMAL from the values 50.12 and 12.1, runs it
 * through {@link DruidRexExecutor#reduce}, and checks that exactly one node is produced and that
 * it translates to the Druid expression {@code array(50.12,12.1)}, i.e. the fractional digits of
 * the elements are still present after reduction.
 */
@Test
public void testArrayOfDoublesReduction()
{
  final DruidRexExecutor executor = new DruidRexExecutor(PLANNER_CONTEXT);

  // ARRAY<DECIMAL> type used for the literal under test.
  final ArraySqlType decimalArrayType = new ArraySqlType(
      new BasicSqlType(DruidTypeSystem.INSTANCE, SqlTypeName.DECIMAL),
      false
  );
  final List<BigDecimal> decimalValues = ImmutableList.of(
      BigDecimal.valueOf(50.12),
      BigDecimal.valueOf(12.1)
  );
  final RexNode arrayLiteral = rexBuilder.makeLiteral(decimalValues, decimalArrayType, true);

  // The executor appends its output to this list.
  final List<RexNode> reducedNodes = new ArrayList<>();
  executor.reduce(rexBuilder, ImmutableList.of(arrayLiteral), reducedNodes);

  Assert.assertEquals(1, reducedNodes.size());

  final DruidExpression actualExpression = Expressions.toDruidExpression(
      PLANNER_CONTEXT,
      RowSignature.empty(),
      reducedNodes.get(0)
  );
  Assert.assertEquals(DruidExpression.fromExpression("array(50.12,12.1)"), actualExpression);
}
/**
 * Builds an ARRAY literal whose element type is INTEGER from the values 50 and 12, runs it
 * through {@link DruidRexExecutor#reduce}, and checks that exactly one node is produced and that
 * it translates to the Druid expression {@code array(50,12)}.
 */
@Test
public void testArrayOfLongsReduction()
{
  final DruidRexExecutor executor = new DruidRexExecutor(PLANNER_CONTEXT);

  // ARRAY<INTEGER> type used for the literal under test.
  final ArraySqlType integerArrayType = new ArraySqlType(
      new BasicSqlType(DruidTypeSystem.INSTANCE, SqlTypeName.INTEGER),
      false
  );
  final List<BigDecimal> integerValues = ImmutableList.of(
      BigDecimal.valueOf(50),
      BigDecimal.valueOf(12)
  );
  final RexNode arrayLiteral = rexBuilder.makeLiteral(integerValues, integerArrayType, true);

  // The executor appends its output to this list.
  final List<RexNode> reducedNodes = new ArrayList<>();
  executor.reduce(rexBuilder, ImmutableList.of(arrayLiteral), reducedNodes);

  Assert.assertEquals(1, reducedNodes.size());

  final DruidExpression actualExpression = Expressions.toDruidExpression(
      PLANNER_CONTEXT,
      RowSignature.empty(),
      reducedNodes.get(0)
  );
  Assert.assertEquals(DruidExpression.fromExpression("array(50,12)"), actualExpression);
}
}