SQL: Fix CASE-filtered aggregations with GROUP BY. (#4943)

This commit is contained in:
Gian Merlino 2017-10-12 15:40:43 -07:00 committed by Jonathan Wei
parent 32f36beaae
commit 57a4038379
2 changed files with 38 additions and 3 deletions

View File

@ -87,9 +87,13 @@ public class CaseFilteredAggregatorRule extends RelOptRule
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
final List<AggregateCall> newCalls = new ArrayList<>(aggregate.getAggCallList().size()); final List<AggregateCall> newCalls = new ArrayList<>(aggregate.getAggCallList().size());
final List<RexNode> newProjects = new ArrayList<>(project.getChildExps()); final List<RexNode> newProjects = new ArrayList<>(project.getChildExps());
final List<RexNode> newCasts = new ArrayList<>(aggregate.getAggCallList().size()); final List<RexNode> newCasts = new ArrayList<>(aggregate.getGroupCount() + aggregate.getAggCallList().size());
final RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory(); final RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory();
for (int fieldNumber : aggregate.getGroupSet()) {
newCasts.add(rexBuilder.makeInputRef(project.getChildExps().get(fieldNumber).getType(), fieldNumber));
}
for (AggregateCall aggregateCall : aggregate.getAggCallList()) { for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
AggregateCall newCall = null; AggregateCall newCall = null;
@ -197,7 +201,6 @@ public class CaseFilteredAggregatorRule extends RelOptRule
final RelBuilder.GroupKey groupKey = relBuilder.groupKey( final RelBuilder.GroupKey groupKey = relBuilder.groupKey(
aggregate.getGroupSet(), aggregate.getGroupSet(),
aggregate.indicator,
aggregate.getGroupSets() aggregate.getGroupSets()
); );

View File

@ -2158,6 +2158,38 @@ public class CalciteQueryTest
); );
} }
@Test
public void testCaseFilteredAggregationWithGroupBy() throws Exception
{
testQuery(
"SELECT\n"
+ " cnt,\n"
+ " SUM(CASE WHEN dim1 <> '1' THEN 1 ELSE 0 END) + SUM(cnt)\n"
+ "FROM druid.foo\n"
+ "GROUP BY cnt",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(QSS(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(DIMS(new DefaultDimensionSpec("cnt", "d0", ValueType.LONG)))
.setAggregatorSpecs(AGGS(
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
NOT(SELECTOR("dim1", "1", null))
),
new LongSumAggregatorFactory("a1", "cnt")
))
.setPostAggregatorSpecs(ImmutableList.of(EXPRESSION_POST_AGG("p0", "(\"a0\" + \"a1\")")))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{1L, 11L}
)
);
}
@Test @Test
@Ignore // https://issues.apache.org/jira/browse/CALCITE-1910 @Ignore // https://issues.apache.org/jira/browse/CALCITE-1910
public void testFilteredAggregationWithNotIn() throws Exception public void testFilteredAggregationWithNotIn() throws Exception