SQL: Fix CASE-filtered aggregations with GROUP BY. (#4943)

This commit is contained in:
Gian Merlino 2017-10-12 15:40:43 -07:00 committed by Jonathan Wei
parent 32f36beaae
commit 57a4038379
2 changed files with 38 additions and 3 deletions

View File

@ -87,9 +87,13 @@ public class CaseFilteredAggregatorRule extends RelOptRule
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
final List<AggregateCall> newCalls = new ArrayList<>(aggregate.getAggCallList().size());
final List<RexNode> newProjects = new ArrayList<>(project.getChildExps());
final List<RexNode> newCasts = new ArrayList<>(aggregate.getAggCallList().size());
final List<RexNode> newCasts = new ArrayList<>(aggregate.getGroupCount() + aggregate.getAggCallList().size());
final RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory();
for (int fieldNumber : aggregate.getGroupSet()) {
newCasts.add(rexBuilder.makeInputRef(project.getChildExps().get(fieldNumber).getType(), fieldNumber));
}
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
AggregateCall newCall = null;
@ -197,7 +201,6 @@ public class CaseFilteredAggregatorRule extends RelOptRule
final RelBuilder.GroupKey groupKey = relBuilder.groupKey(
aggregate.getGroupSet(),
aggregate.indicator,
aggregate.getGroupSets()
);

View File

@ -2158,6 +2158,38 @@ public class CalciteQueryTest
);
}
@Test
public void testCaseFilteredAggregationWithGroupBy() throws Exception
{
testQuery(
"SELECT\n"
+ " cnt,\n"
+ " SUM(CASE WHEN dim1 <> '1' THEN 1 ELSE 0 END) + SUM(cnt)\n"
+ "FROM druid.foo\n"
+ "GROUP BY cnt",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(QSS(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(DIMS(new DefaultDimensionSpec("cnt", "d0", ValueType.LONG)))
.setAggregatorSpecs(AGGS(
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
NOT(SELECTOR("dim1", "1", null))
),
new LongSumAggregatorFactory("a1", "cnt")
))
.setPostAggregatorSpecs(ImmutableList.of(EXPRESSION_POST_AGG("p0", "(\"a0\" + \"a1\")")))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{1L, 11L}
)
);
}
@Test
@Ignore // https://issues.apache.org/jira/browse/CALCITE-1910
public void testFilteredAggregationWithNotIn() throws Exception
@ -3405,7 +3437,7 @@ public class CalciteQueryTest
.build()
),
ImmutableList.of(
new Object[] {978393600000L, "def", 1L}
new Object[]{978393600000L, "def", 1L}
)
);
}