From 57a40383792e466b6c40672431d3326c6e4b0d0a Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Thu, 12 Oct 2017 15:40:43 -0700 Subject: [PATCH] SQL: Fix CASE-filtered aggregations with GROUP BY. (#4943) --- .../rule/CaseFilteredAggregatorRule.java | 7 ++-- .../druid/sql/calcite/CalciteQueryTest.java | 34 ++++++++++++++++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/CaseFilteredAggregatorRule.java b/sql/src/main/java/io/druid/sql/calcite/rule/CaseFilteredAggregatorRule.java index 4e11736dded..d51fee71d0d 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rule/CaseFilteredAggregatorRule.java +++ b/sql/src/main/java/io/druid/sql/calcite/rule/CaseFilteredAggregatorRule.java @@ -87,9 +87,13 @@ public class CaseFilteredAggregatorRule extends RelOptRule final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); final List newCalls = new ArrayList<>(aggregate.getAggCallList().size()); final List newProjects = new ArrayList<>(project.getChildExps()); - final List newCasts = new ArrayList<>(aggregate.getAggCallList().size()); + final List newCasts = new ArrayList<>(aggregate.getGroupCount() + aggregate.getAggCallList().size()); final RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory(); + for (int fieldNumber : aggregate.getGroupSet()) { + newCasts.add(rexBuilder.makeInputRef(project.getChildExps().get(fieldNumber).getType(), fieldNumber)); + } + for (AggregateCall aggregateCall : aggregate.getAggCallList()) { AggregateCall newCall = null; @@ -197,7 +201,6 @@ public class CaseFilteredAggregatorRule extends RelOptRule final RelBuilder.GroupKey groupKey = relBuilder.groupKey( aggregate.getGroupSet(), - aggregate.indicator, aggregate.getGroupSets() ); diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java index 04848f5fdcb..63e62fc019c 100644 --- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java @@ -2158,6 +2158,38 @@ public class CalciteQueryTest ); } + @Test + public void testCaseFilteredAggregationWithGroupBy() throws Exception + { + testQuery( + "SELECT\n" + + " cnt,\n" + + " SUM(CASE WHEN dim1 <> '1' THEN 1 ELSE 0 END) + SUM(cnt)\n" + + "FROM druid.foo\n" + + "GROUP BY cnt", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(QSS(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(DIMS(new DefaultDimensionSpec("cnt", "d0", ValueType.LONG))) + .setAggregatorSpecs(AGGS( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + NOT(SELECTOR("dim1", "1", null)) + ), + new LongSumAggregatorFactory("a1", "cnt") + )) + .setPostAggregatorSpecs(ImmutableList.of(EXPRESSION_POST_AGG("p0", "(\"a0\" + \"a1\")"))) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{1L, 11L} + ) + ); + } + @Test @Ignore // https://issues.apache.org/jira/browse/CALCITE-1910 public void testFilteredAggregationWithNotIn() throws Exception @@ -3405,7 +3437,7 @@ public class CalciteQueryTest .build() ), ImmutableList.of( - new Object[] {978393600000L, "def", 1L} + new Object[]{978393600000L, "def", 1L} ) ); }