mirror of https://github.com/apache/druid.git
SQL: Fix CASE-filtered aggregations with GROUP BY. (#4943)
This commit is contained in:
parent
32f36beaae
commit
57a4038379
|
@ -87,9 +87,13 @@ public class CaseFilteredAggregatorRule extends RelOptRule
|
||||||
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
|
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
|
||||||
final List<AggregateCall> newCalls = new ArrayList<>(aggregate.getAggCallList().size());
|
final List<AggregateCall> newCalls = new ArrayList<>(aggregate.getAggCallList().size());
|
||||||
final List<RexNode> newProjects = new ArrayList<>(project.getChildExps());
|
final List<RexNode> newProjects = new ArrayList<>(project.getChildExps());
|
||||||
final List<RexNode> newCasts = new ArrayList<>(aggregate.getAggCallList().size());
|
final List<RexNode> newCasts = new ArrayList<>(aggregate.getGroupCount() + aggregate.getAggCallList().size());
|
||||||
final RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory();
|
final RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory();
|
||||||
|
|
||||||
|
for (int fieldNumber : aggregate.getGroupSet()) {
|
||||||
|
newCasts.add(rexBuilder.makeInputRef(project.getChildExps().get(fieldNumber).getType(), fieldNumber));
|
||||||
|
}
|
||||||
|
|
||||||
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
|
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
|
||||||
AggregateCall newCall = null;
|
AggregateCall newCall = null;
|
||||||
|
|
||||||
|
@ -197,7 +201,6 @@ public class CaseFilteredAggregatorRule extends RelOptRule
|
||||||
|
|
||||||
final RelBuilder.GroupKey groupKey = relBuilder.groupKey(
|
final RelBuilder.GroupKey groupKey = relBuilder.groupKey(
|
||||||
aggregate.getGroupSet(),
|
aggregate.getGroupSet(),
|
||||||
aggregate.indicator,
|
|
||||||
aggregate.getGroupSets()
|
aggregate.getGroupSets()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -2158,6 +2158,38 @@ public class CalciteQueryTest
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCaseFilteredAggregationWithGroupBy() throws Exception
|
||||||
|
{
|
||||||
|
testQuery(
|
||||||
|
"SELECT\n"
|
||||||
|
+ " cnt,\n"
|
||||||
|
+ " SUM(CASE WHEN dim1 <> '1' THEN 1 ELSE 0 END) + SUM(cnt)\n"
|
||||||
|
+ "FROM druid.foo\n"
|
||||||
|
+ "GROUP BY cnt",
|
||||||
|
ImmutableList.of(
|
||||||
|
GroupByQuery.builder()
|
||||||
|
.setDataSource(CalciteTests.DATASOURCE1)
|
||||||
|
.setInterval(QSS(Filtration.eternity()))
|
||||||
|
.setGranularity(Granularities.ALL)
|
||||||
|
.setDimensions(DIMS(new DefaultDimensionSpec("cnt", "d0", ValueType.LONG)))
|
||||||
|
.setAggregatorSpecs(AGGS(
|
||||||
|
new FilteredAggregatorFactory(
|
||||||
|
new CountAggregatorFactory("a0"),
|
||||||
|
NOT(SELECTOR("dim1", "1", null))
|
||||||
|
),
|
||||||
|
new LongSumAggregatorFactory("a1", "cnt")
|
||||||
|
))
|
||||||
|
.setPostAggregatorSpecs(ImmutableList.of(EXPRESSION_POST_AGG("p0", "(\"a0\" + \"a1\")")))
|
||||||
|
.setContext(QUERY_CONTEXT_DEFAULT)
|
||||||
|
.build()
|
||||||
|
),
|
||||||
|
ImmutableList.of(
|
||||||
|
new Object[]{1L, 11L}
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore // https://issues.apache.org/jira/browse/CALCITE-1910
|
@Ignore // https://issues.apache.org/jira/browse/CALCITE-1910
|
||||||
public void testFilteredAggregationWithNotIn() throws Exception
|
public void testFilteredAggregationWithNotIn() throws Exception
|
||||||
|
|
Loading…
Reference in New Issue