From ee41cc770f3a2b666e380d25ffbc5469a42a1d36 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 9 Aug 2022 15:17:45 -0700 Subject: [PATCH] fix issue with SQL sum aggregator due to bug with DruidTypeSystem and AggregateRemoveRule (#12880) * fix issue with SQL sum aggregator due to bug with DruidTypeSystem and AggregateRemoveRule * fix style * add comment about using custom sum function --- .../aggregation/builtin/SumSqlAggregator.java | 87 ++++++++++++++++++- .../druid/sql/calcite/CalciteQueryTest.java | 46 ++++++++++ 2 files changed, 131 insertions(+), 2 deletions(-) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SumSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SumSqlAggregator.java index cd7a13d9357..f4dcad3ed59 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SumSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SumSqlAggregator.java @@ -20,8 +20,17 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlSplittableAggFunction; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.util.Optionality; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; @@ -35,10 +44,18 @@ import org.apache.druid.sql.calcite.planner.UnsupportedSQLQueryException; public class SumSqlAggregator extends SimpleSqlAggregator { + /** + * We are using a custom SUM function instead of {@link org.apache.calcite.sql.fun.SqlStdOperatorTable#SUM} to + * work around the issue described in https://issues.apache.org/jira/browse/CALCITE-4609. Once we upgrade Calcite + * to 1.27.0+ we can return to using the built-in SUM function, and {@link DruidSumAggFunction and + * {@link DruidSumSplitter} can be removed. + */ + private static final SqlAggFunction DRUID_SUM = new DruidSumAggFunction(); + @Override public SqlAggFunction calciteFunction() { - return SqlStdOperatorTable.SUM; + return DRUID_SUM; } @Override @@ -74,4 +91,70 @@ public class SumSqlAggregator extends SimpleSqlAggregator throw new UnsupportedSQLQueryException("Sum aggregation is not supported for '%s' type", aggregationType); } } + + /** + * Customized verison of {@link org.apache.calcite.sql.fun.SqlSumAggFunction} with a customized + * implementation of {@link #unwrap(Class)} to provide a customized {@link SqlSplittableAggFunction} that correctly + * honors Druid's type system. The default sum implementation of {@link SqlSplittableAggFunction} assumes that it can + * reduce its output to its input in the case of a single row, which means that it doesn't necessarily reflect the + * output type as if it were run through the SUM function (e.g. INTEGER -> BIGINT) + */ + private static class DruidSumAggFunction extends SqlAggFunction + { + public DruidSumAggFunction() + { + super( + "SUM", + null, + SqlKind.SUM, + ReturnTypes.AGG_SUM, + null, + OperandTypes.NUMERIC, + SqlFunctionCategory.NUMERIC, + false, + false, + Optionality.FORBIDDEN + ); + } + + @Override + public T unwrap(Class clazz) + { + if (clazz == SqlSplittableAggFunction.class) { + return clazz.cast(DruidSumSplitter.INSTANCE); + } + return super.unwrap(clazz); + } + } + + /** + * The default sum implementation of {@link SqlSplittableAggFunction} assumes that it can reduce its output to its + * input in the case of a single row for the {@link #singleton(RexBuilder, RelDataType, AggregateCall)} method, which + * is fine for the default type system where the output type of SUM is the same numeric type as the inputs, but + * Druid SUM always produces DOUBLE or BIGINT, so this is incorrect for + * {@link org.apache.druid.sql.calcite.planner.DruidTypeSystem}. + */ + private static class DruidSumSplitter extends SqlSplittableAggFunction.AbstractSumSplitter + { + public static DruidSumSplitter INSTANCE = new DruidSumSplitter(); + + @Override + public RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType, AggregateCall aggregateCall) + { + final int arg = aggregateCall.getArgList().get(0); + final RelDataTypeField field = inputRowType.getFieldList().get(arg); + final RexNode inputRef = rexBuilder.makeInputRef(field.getType(), arg); + // if input and output do not aggree, we must cast the input to the output type + if (!aggregateCall.getType().equals(field.getType())) { + return rexBuilder.makeCast(aggregateCall.getType(), inputRef); + } + return inputRef; + } + + @Override + protected SqlAggFunction getMergeAggFunctionOfTopSplit() + { + return DRUID_SUM; + } + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 610590237ac..39e5dcd799b 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -14044,4 +14044,50 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ) ); } + + @Test + public void testSubqueryTypeMismatchWithLiterals() throws Exception + { + testQuery( + "SELECT \n" + + " dim1,\n" + + " SUM(CASE WHEN sum_l1 = 0 THEN 1 ELSE 0 END) AS outer_l1\n" + + "from (\n" + + " select \n" + + " dim1,\n" + + " SUM(l1) as sum_l1\n" + + " from numfoo\n" + + " group by dim1\n" + + ")\n" + + "group by 1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(querySegmentSpec(Intervals.ETERNITY)) + .setGranularity(Granularities.ALL) + .addDimension(new DefaultDimensionSpec("dim1", "_d0", ColumnType.STRING)) + .addAggregator(new LongSumAggregatorFactory("a0", "l1")) + .setPostAggregatorSpecs(ImmutableList.of( + expressionPostAgg("p0", "case_searched((\"a0\" == 0),1,0)") + )) + .build() + ), + useDefault ? ImmutableList.of( + new Object[]{"", 0L}, + new Object[]{"1", 1L}, + new Object[]{"10.1", 0L}, + new Object[]{"2", 1L}, + new Object[]{"abc", 1L}, + new Object[]{"def", 1L} + ) : ImmutableList.of( + // in sql compatible mode, null does not equal 0 so the values which were 1 previously are not in this mode + new Object[]{"", 0L}, + new Object[]{"1", 0L}, + new Object[]{"10.1", 0L}, + new Object[]{"2", 1L}, + new Object[]{"abc", 0L}, + new Object[]{"def", 0L} + ) + ); + } }