diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java index 1fae32de1fb..dda69f7aff1 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java @@ -39,6 +39,7 @@ import org.apache.druid.sql.calcite.table.RowSignatures; import java.math.BigDecimal; import java.util.Arrays; import java.util.List; +import java.util.stream.Collectors; /** * A Calcite {@code RexExecutor} that reduces Calcite expressions by evaluating them using Druid's own built-in @@ -145,7 +146,31 @@ public class DruidRexExecutor implements RexExecutor } } else if (sqlTypeName == SqlTypeName.ARRAY) { assert exprResult.isArray(); - literal = rexBuilder.makeLiteral(Arrays.asList(exprResult.asArray()), constExp.getType(), true); + if (SqlTypeName.NUMERIC_TYPES.contains(constExp.getType().getComponentType().getSqlTypeName())) { + if (exprResult.type().getElementType().is(ExprType.LONG)) { + List resultAsBigDecimalList = Arrays.stream(exprResult.asLongArray()) + .map(BigDecimal::valueOf) + .collect(Collectors.toList()); + literal = rexBuilder.makeLiteral(resultAsBigDecimalList, constExp.getType(), true); + } else { + List resultAsBigDecimalList = Arrays.stream(exprResult.asDoubleArray()).map( + doubleVal -> { + if (Double.isNaN(doubleVal) || Double.isInfinite(doubleVal)) { + String expression = druidExpression.getExpression(); + throw new IAE( + "'%s' contains an element that evaluates to '%s' which is not supported in SQL. You can either cast the element in the array to bigint or char or change the expression itself", + expression, + Double.toString(doubleVal) + ); + } + return BigDecimal.valueOf(doubleVal); + } + ).collect(Collectors.toList()); + literal = rexBuilder.makeLiteral(resultAsBigDecimalList, constExp.getType(), true); + } + } else { + literal = rexBuilder.makeLiteral(Arrays.asList(exprResult.asArray()), constExp.getType(), true); + } } else if (sqlTypeName == SqlTypeName.OTHER && constExp.getType() instanceof RowSignatures.ComplexSqlType) { // complex constant is not reducible, so just leave it as an expression literal = constExp; diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java index d4a9157ef30..61e41c72480 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java @@ -194,8 +194,7 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest @Test public void testSomeArrayFunctionsWithScanQuery() throws Exception { - // array constructor turns decimals into ints for some reason, this needs fixed in the future - // also, yes these outputs are strange sometimes, arrays are in a partial state of existence so end up a bit + // Yes these outputs are strange sometimes, arrays are in a partial state of existence so end up a bit // stringy for now this is because virtual column selectors are coercing values back to stringish so that // multi-valued string dimensions can be grouped on. List expectedResults; @@ -211,13 +210,13 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest 0.0, "[\"a\",\"b\",\"c\"]", "[1,2,3]", - "[1,2,4]", + "[1.9,2.2,4.3]", "[\"a\",\"b\",\"foo\"]", "[\"foo\",\"a\"]", "[1,2,7]", "[0,1,2]", - "[1,2,1]", - "[0,1,2]", + "[1.2,2.2,1.0]", + "[0.0,1.1,2.2]", "[\"a\",\"a\",\"b\"]", "[7,0]", "[1.0,0.0]", @@ -239,13 +238,13 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest null, "[\"a\",\"b\",\"c\"]", "[1,2,3]", - "[1,2,4]", + "[1.9,2.2,4.3]", "[\"a\",\"b\",\"foo\"]", "[\"foo\",\"a\"]", "[1,2,7]", "[null,1,2]", - "[1,2,1]", - "[null,1,2]", + "[1.2,2.2,1.0]", + "[null,1.1,2.2]", "[\"a\",\"a\",\"b\"]", "[7,null]", "[1.0,null]", @@ -297,13 +296,13 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest expressionVirtualColumn("v13", "array_offset(array(\"d1\"),0)", ColumnType.STRING), expressionVirtualColumn("v14", "array_ordinal(array(\"l1\"),1)", ColumnType.STRING), expressionVirtualColumn("v15", "array_ordinal(array(\"d1\"),1)", ColumnType.STRING), - expressionVirtualColumn("v2", "array(1,2,4)", ColumnType.STRING), + expressionVirtualColumn("v2", "array(1.9,2.2,4.3)", ColumnType.STRING), expressionVirtualColumn("v3", "array_append(\"dim3\",'foo')", ColumnType.STRING), expressionVirtualColumn("v4", "array_prepend('foo',array(\"dim2\"))", ColumnType.STRING), expressionVirtualColumn("v5", "array_append(array(1,2),\"l1\")", ColumnType.STRING), expressionVirtualColumn("v6", "array_prepend(\"l2\",array(1,2))", ColumnType.STRING), - expressionVirtualColumn("v7", "array_append(array(1,2),\"d1\")", ColumnType.STRING), - expressionVirtualColumn("v8", "array_prepend(\"d2\",array(1,2))", ColumnType.STRING), + expressionVirtualColumn("v7", "array_append(array(1.2,2.2),\"d1\")", ColumnType.STRING), + expressionVirtualColumn("v8", "array_prepend(\"d2\",array(1.1,2.2))", ColumnType.STRING), expressionVirtualColumn("v9", "array_concat(\"dim2\",\"dim3\")", ColumnType.STRING) ) .columns( @@ -357,13 +356,13 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest "[\"a\",\"b\"]", Arrays.asList("a", "b", "c"), Arrays.asList(1L, 2L, 3L), - Arrays.asList(1L, 2L, 4L), + Arrays.asList(1.9, 2.2, 4.3), "[\"a\",\"b\",\"foo\"]", Arrays.asList("foo", "a"), Arrays.asList(1L, 2L, 7L), Arrays.asList(0L, 1L, 2L), - Arrays.asList(1L, 2L, 1L), - Arrays.asList(0L, 1L, 2L), + Arrays.asList(1.2, 2.2, 1.0), + Arrays.asList(0.0, 1.1, 2.2), "[\"a\",\"a\",\"b\"]", Arrays.asList(7L, 0L), Arrays.asList(1.0, 0.0) @@ -377,13 +376,13 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest "[\"a\",\"b\"]", Arrays.asList("a", "b", "c"), Arrays.asList(1L, 2L, 3L), - Arrays.asList(1L, 2L, 4L), + Arrays.asList(1.9, 2.2, 4.3), "[\"a\",\"b\",\"foo\"]", Arrays.asList("foo", "a"), Arrays.asList(1L, 2L, 7L), Arrays.asList(null, 1L, 2L), - Arrays.asList(1L, 2L, 1L), - Arrays.asList(null, 1L, 2L), + Arrays.asList(1.2, 2.2, 1.0), + Arrays.asList(null, 1.1, 2.2), "[\"a\",\"a\",\"b\"]", Arrays.asList(7L, null), Arrays.asList(1.0, null) @@ -420,13 +419,13 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest expressionVirtualColumn("v1", "array(1,2,3)", ColumnType.STRING), expressionVirtualColumn("v10", "array_concat(array(\"l1\"),array(\"l2\"))", ColumnType.STRING), expressionVirtualColumn("v11", "array_concat(array(\"d1\"),array(\"d2\"))", ColumnType.STRING), - expressionVirtualColumn("v2", "array(1,2,4)", ColumnType.STRING), + expressionVirtualColumn("v2", "array(1.9,2.2,4.3)", ColumnType.STRING), expressionVirtualColumn("v3", "array_append(\"dim3\",'foo')", ColumnType.STRING), expressionVirtualColumn("v4", "array_prepend('foo',array(\"dim2\"))", ColumnType.STRING), expressionVirtualColumn("v5", "array_append(array(1,2),\"l1\")", ColumnType.STRING), expressionVirtualColumn("v6", "array_prepend(\"l2\",array(1,2))", ColumnType.STRING), - expressionVirtualColumn("v7", "array_append(array(1,2),\"d1\")", ColumnType.STRING), - expressionVirtualColumn("v8", "array_prepend(\"d2\",array(1,2))", ColumnType.STRING), + expressionVirtualColumn("v7", "array_append(array(1.2,2.2),\"d1\")", ColumnType.STRING), + expressionVirtualColumn("v8", "array_prepend(\"d2\",array(1.1,2.2))", ColumnType.STRING), expressionVirtualColumn("v9", "array_concat(\"dim2\",\"dim3\")", ColumnType.STRING) ) .columns( diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java index a44995e489f..e9f751c1f80 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java @@ -32,6 +32,8 @@ import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.ArraySqlType; +import org.apache.calcite.sql.type.BasicSqlType; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFactoryImpl; import org.apache.calcite.sql.type.SqlTypeFamily; @@ -136,4 +138,46 @@ public class DruidRexExecutorTest extends InitializedNullHandlingTest ) ); } + + @Test + public void testArrayOfDoublesReduction() + { + DruidRexExecutor rexy = new DruidRexExecutor(PLANNER_CONTEXT); + List reduced = new ArrayList<>(); + BasicSqlType basicSqlType = new BasicSqlType(DruidTypeSystem.INSTANCE, SqlTypeName.DECIMAL); + ArraySqlType arraySqlType = new ArraySqlType(basicSqlType, false); + List elements = ImmutableList.of(BigDecimal.valueOf(50.12), BigDecimal.valueOf(12.1)); + RexNode literal = rexBuilder.makeLiteral(elements, arraySqlType, true); + rexy.reduce(rexBuilder, ImmutableList.of(literal), reduced); + Assert.assertEquals(1, reduced.size()); + Assert.assertEquals( + DruidExpression.fromExpression("array(50.12,12.1)"), + Expressions.toDruidExpression( + PLANNER_CONTEXT, + RowSignature.empty(), + reduced.get(0) + ) + ); + } + + @Test + public void testArrayOfLongsReduction() + { + DruidRexExecutor rexy = new DruidRexExecutor(PLANNER_CONTEXT); + List reduced = new ArrayList<>(); + BasicSqlType basicSqlType = new BasicSqlType(DruidTypeSystem.INSTANCE, SqlTypeName.INTEGER); + ArraySqlType arraySqlType = new ArraySqlType(basicSqlType, false); + List elements = ImmutableList.of(BigDecimal.valueOf(50), BigDecimal.valueOf(12)); + RexNode literal = rexBuilder.makeLiteral(elements, arraySqlType, true); + rexy.reduce(rexBuilder, ImmutableList.of(literal), reduced); + Assert.assertEquals(1, reduced.size()); + Assert.assertEquals( + DruidExpression.fromExpression("array(50,12)"), + Expressions.toDruidExpression( + PLANNER_CONTEXT, + RowSignature.empty(), + reduced.get(0) + ) + ); + } }