mirror of https://github.com/apache/druid.git
fix issue with SQL sum aggregator due to bug with DruidTypeSystem and AggregateRemoveRule (#12880)
* fix issue with SQL sum aggregator due to bug with DruidTypeSystem and AggregateRemoveRule * fix style * add comment about using custom sum function
This commit is contained in:
parent
2855fb6ff8
commit
ee41cc770f
|
@ -20,8 +20,17 @@
|
|||
package org.apache.druid.sql.calcite.aggregation.builtin;
|
||||
|
||||
import org.apache.calcite.rel.core.AggregateCall;
|
||||
import org.apache.calcite.rel.type.RelDataType;
|
||||
import org.apache.calcite.rel.type.RelDataTypeField;
|
||||
import org.apache.calcite.rex.RexBuilder;
|
||||
import org.apache.calcite.rex.RexNode;
|
||||
import org.apache.calcite.sql.SqlAggFunction;
|
||||
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
|
||||
import org.apache.calcite.sql.SqlFunctionCategory;
|
||||
import org.apache.calcite.sql.SqlKind;
|
||||
import org.apache.calcite.sql.SqlSplittableAggFunction;
|
||||
import org.apache.calcite.sql.type.OperandTypes;
|
||||
import org.apache.calcite.sql.type.ReturnTypes;
|
||||
import org.apache.calcite.util.Optionality;
|
||||
import org.apache.druid.math.expr.ExprMacroTable;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
|
||||
|
@ -35,10 +44,18 @@ import org.apache.druid.sql.calcite.planner.UnsupportedSQLQueryException;
|
|||
|
||||
public class SumSqlAggregator extends SimpleSqlAggregator
|
||||
{
|
||||
/**
|
||||
* We are using a custom SUM function instead of {@link org.apache.calcite.sql.fun.SqlStdOperatorTable#SUM} to
|
||||
* work around the issue described in https://issues.apache.org/jira/browse/CALCITE-4609. Once we upgrade Calcite
|
||||
* to 1.27.0+ we can return to using the built-in SUM function, and {@link DruidSumAggFunction and
|
||||
* {@link DruidSumSplitter} can be removed.
|
||||
*/
|
||||
private static final SqlAggFunction DRUID_SUM = new DruidSumAggFunction();
|
||||
|
||||
@Override
|
||||
public SqlAggFunction calciteFunction()
|
||||
{
|
||||
return SqlStdOperatorTable.SUM;
|
||||
return DRUID_SUM;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -74,4 +91,70 @@ public class SumSqlAggregator extends SimpleSqlAggregator
|
|||
throw new UnsupportedSQLQueryException("Sum aggregation is not supported for '%s' type", aggregationType);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Customized verison of {@link org.apache.calcite.sql.fun.SqlSumAggFunction} with a customized
|
||||
* implementation of {@link #unwrap(Class)} to provide a customized {@link SqlSplittableAggFunction} that correctly
|
||||
* honors Druid's type system. The default sum implementation of {@link SqlSplittableAggFunction} assumes that it can
|
||||
* reduce its output to its input in the case of a single row, which means that it doesn't necessarily reflect the
|
||||
* output type as if it were run through the SUM function (e.g. INTEGER -> BIGINT)
|
||||
*/
|
||||
private static class DruidSumAggFunction extends SqlAggFunction
|
||||
{
|
||||
public DruidSumAggFunction()
|
||||
{
|
||||
super(
|
||||
"SUM",
|
||||
null,
|
||||
SqlKind.SUM,
|
||||
ReturnTypes.AGG_SUM,
|
||||
null,
|
||||
OperandTypes.NUMERIC,
|
||||
SqlFunctionCategory.NUMERIC,
|
||||
false,
|
||||
false,
|
||||
Optionality.FORBIDDEN
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> T unwrap(Class<T> clazz)
|
||||
{
|
||||
if (clazz == SqlSplittableAggFunction.class) {
|
||||
return clazz.cast(DruidSumSplitter.INSTANCE);
|
||||
}
|
||||
return super.unwrap(clazz);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The default sum implementation of {@link SqlSplittableAggFunction} assumes that it can reduce its output to its
|
||||
* input in the case of a single row for the {@link #singleton(RexBuilder, RelDataType, AggregateCall)} method, which
|
||||
* is fine for the default type system where the output type of SUM is the same numeric type as the inputs, but
|
||||
* Druid SUM always produces DOUBLE or BIGINT, so this is incorrect for
|
||||
* {@link org.apache.druid.sql.calcite.planner.DruidTypeSystem}.
|
||||
*/
|
||||
private static class DruidSumSplitter extends SqlSplittableAggFunction.AbstractSumSplitter
|
||||
{
|
||||
public static DruidSumSplitter INSTANCE = new DruidSumSplitter();
|
||||
|
||||
@Override
|
||||
public RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType, AggregateCall aggregateCall)
|
||||
{
|
||||
final int arg = aggregateCall.getArgList().get(0);
|
||||
final RelDataTypeField field = inputRowType.getFieldList().get(arg);
|
||||
final RexNode inputRef = rexBuilder.makeInputRef(field.getType(), arg);
|
||||
// if input and output do not aggree, we must cast the input to the output type
|
||||
if (!aggregateCall.getType().equals(field.getType())) {
|
||||
return rexBuilder.makeCast(aggregateCall.getType(), inputRef);
|
||||
}
|
||||
return inputRef;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected SqlAggFunction getMergeAggFunctionOfTopSplit()
|
||||
{
|
||||
return DRUID_SUM;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14044,4 +14044,50 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSubqueryTypeMismatchWithLiterals() throws Exception
|
||||
{
|
||||
testQuery(
|
||||
"SELECT \n"
|
||||
+ " dim1,\n"
|
||||
+ " SUM(CASE WHEN sum_l1 = 0 THEN 1 ELSE 0 END) AS outer_l1\n"
|
||||
+ "from (\n"
|
||||
+ " select \n"
|
||||
+ " dim1,\n"
|
||||
+ " SUM(l1) as sum_l1\n"
|
||||
+ " from numfoo\n"
|
||||
+ " group by dim1\n"
|
||||
+ ")\n"
|
||||
+ "group by 1",
|
||||
ImmutableList.of(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(CalciteTests.DATASOURCE3)
|
||||
.setInterval(querySegmentSpec(Intervals.ETERNITY))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.addDimension(new DefaultDimensionSpec("dim1", "_d0", ColumnType.STRING))
|
||||
.addAggregator(new LongSumAggregatorFactory("a0", "l1"))
|
||||
.setPostAggregatorSpecs(ImmutableList.of(
|
||||
expressionPostAgg("p0", "case_searched((\"a0\" == 0),1,0)")
|
||||
))
|
||||
.build()
|
||||
),
|
||||
useDefault ? ImmutableList.of(
|
||||
new Object[]{"", 0L},
|
||||
new Object[]{"1", 1L},
|
||||
new Object[]{"10.1", 0L},
|
||||
new Object[]{"2", 1L},
|
||||
new Object[]{"abc", 1L},
|
||||
new Object[]{"def", 1L}
|
||||
) : ImmutableList.of(
|
||||
// in sql compatible mode, null does not equal 0 so the values which were 1 previously are not in this mode
|
||||
new Object[]{"", 0L},
|
||||
new Object[]{"1", 0L},
|
||||
new Object[]{"10.1", 0L},
|
||||
new Object[]{"2", 1L},
|
||||
new Object[]{"abc", 0L},
|
||||
new Object[]{"def", 0L}
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue