fix issue with SQL sum aggregator due to bug with DruidTypeSystem and AggregateRemoveRule (#12880)

* fix issue with SQL sum aggregator due to bug with DruidTypeSystem and AggregateRemoveRule

* fix style

* add comment about using custom sum function
This commit is contained in:
Clint Wylie 2022-08-09 15:17:45 -07:00 committed by GitHub
parent 2855fb6ff8
commit ee41cc770f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 131 additions and 2 deletions

View File

@ -20,8 +20,17 @@
package org.apache.druid.sql.calcite.aggregation.builtin;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.util.Optionality;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
@ -35,10 +44,18 @@ import org.apache.druid.sql.calcite.planner.UnsupportedSQLQueryException;
public class SumSqlAggregator extends SimpleSqlAggregator
{
/**
* We are using a custom SUM function instead of {@link org.apache.calcite.sql.fun.SqlStdOperatorTable#SUM} to
* work around the issue described in https://issues.apache.org/jira/browse/CALCITE-4609. Once we upgrade Calcite
* to 1.27.0+ we can return to using the built-in SUM function, and {@link DruidSumAggFunction and
* {@link DruidSumSplitter} can be removed.
*/
private static final SqlAggFunction DRUID_SUM = new DruidSumAggFunction();
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.SUM;
return DRUID_SUM;
}
@Override
@ -74,4 +91,70 @@ public class SumSqlAggregator extends SimpleSqlAggregator
throw new UnsupportedSQLQueryException("Sum aggregation is not supported for '%s' type", aggregationType);
}
}
/**
* Customized verison of {@link org.apache.calcite.sql.fun.SqlSumAggFunction} with a customized
* implementation of {@link #unwrap(Class)} to provide a customized {@link SqlSplittableAggFunction} that correctly
* honors Druid's type system. The default sum implementation of {@link SqlSplittableAggFunction} assumes that it can
* reduce its output to its input in the case of a single row, which means that it doesn't necessarily reflect the
* output type as if it were run through the SUM function (e.g. INTEGER -> BIGINT)
*/
private static class DruidSumAggFunction extends SqlAggFunction
{
public DruidSumAggFunction()
{
super(
"SUM",
null,
SqlKind.SUM,
ReturnTypes.AGG_SUM,
null,
OperandTypes.NUMERIC,
SqlFunctionCategory.NUMERIC,
false,
false,
Optionality.FORBIDDEN
);
}
@Override
public <T> T unwrap(Class<T> clazz)
{
if (clazz == SqlSplittableAggFunction.class) {
return clazz.cast(DruidSumSplitter.INSTANCE);
}
return super.unwrap(clazz);
}
}
/**
* The default sum implementation of {@link SqlSplittableAggFunction} assumes that it can reduce its output to its
* input in the case of a single row for the {@link #singleton(RexBuilder, RelDataType, AggregateCall)} method, which
* is fine for the default type system where the output type of SUM is the same numeric type as the inputs, but
* Druid SUM always produces DOUBLE or BIGINT, so this is incorrect for
* {@link org.apache.druid.sql.calcite.planner.DruidTypeSystem}.
*/
private static class DruidSumSplitter extends SqlSplittableAggFunction.AbstractSumSplitter
{
public static DruidSumSplitter INSTANCE = new DruidSumSplitter();
@Override
public RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType, AggregateCall aggregateCall)
{
final int arg = aggregateCall.getArgList().get(0);
final RelDataTypeField field = inputRowType.getFieldList().get(arg);
final RexNode inputRef = rexBuilder.makeInputRef(field.getType(), arg);
// if input and output do not aggree, we must cast the input to the output type
if (!aggregateCall.getType().equals(field.getType())) {
return rexBuilder.makeCast(aggregateCall.getType(), inputRef);
}
return inputRef;
}
@Override
protected SqlAggFunction getMergeAggFunctionOfTopSplit()
{
return DRUID_SUM;
}
}
}

View File

@ -14044,4 +14044,50 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
)
);
}
@Test
public void testSubqueryTypeMismatchWithLiterals() throws Exception
{
testQuery(
"SELECT \n"
+ " dim1,\n"
+ " SUM(CASE WHEN sum_l1 = 0 THEN 1 ELSE 0 END) AS outer_l1\n"
+ "from (\n"
+ " select \n"
+ " dim1,\n"
+ " SUM(l1) as sum_l1\n"
+ " from numfoo\n"
+ " group by dim1\n"
+ ")\n"
+ "group by 1",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE3)
.setInterval(querySegmentSpec(Intervals.ETERNITY))
.setGranularity(Granularities.ALL)
.addDimension(new DefaultDimensionSpec("dim1", "_d0", ColumnType.STRING))
.addAggregator(new LongSumAggregatorFactory("a0", "l1"))
.setPostAggregatorSpecs(ImmutableList.of(
expressionPostAgg("p0", "case_searched((\"a0\" == 0),1,0)")
))
.build()
),
useDefault ? ImmutableList.of(
new Object[]{"", 0L},
new Object[]{"1", 1L},
new Object[]{"10.1", 0L},
new Object[]{"2", 1L},
new Object[]{"abc", 1L},
new Object[]{"def", 1L}
) : ImmutableList.of(
// in sql compatible mode, null does not equal 0 so the values which were 1 previously are not in this mode
new Object[]{"", 0L},
new Object[]{"1", 0L},
new Object[]{"10.1", 0L},
new Object[]{"2", 1L},
new Object[]{"abc", 0L},
new Object[]{"def", 0L}
)
);
}
}