diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java index a06198e4265..50bdf80771a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java @@ -73,12 +73,13 @@ public class GroupByRules if (call.filterArg >= 0) { // AGG(xxx) FILTER(WHERE yyy) - if (project == null) { - // We need some kind of projection to support filtered aggregations. - return null; - } - final RexNode expression = project.getProjects().get(call.filterArg); + final RexNode expression = Expressions.fromFieldAccess( + rexBuilder.getTypeFactory(), + rowSignature, + project, + call.filterArg); + final DimFilter nonOptimizedFilter = Expressions.toFilter( plannerContext, rowSignature, diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java index 661f5cd6ec9..c90a961d96a 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java @@ -32,6 +32,7 @@ import org.apache.druid.query.LookupDataSource; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; +import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.extraction.SubstringDimExtractionFn; import org.apache.druid.query.groupby.GroupByQuery; @@ -1923,4 +1924,26 @@ public class CalciteSelectQueryTest extends BaseCalciteQueryTest ) ); } + + @Test + public void testAggregateFilterInTheAbsenceOfProjection() + { + cannotVectorize(); + testQuery( + "select count(1) filter (where __time > date '2023-01-01') " + + " from druid.foo where 'a' = 'b'", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(InlineDataSource.fromIterable( + ImmutableList.of(), + RowSignature.builder().add("$f1", ColumnType.LONG).build())) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(aggregators( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), expressionFilter("\"$f1\"")))) + .context(QUERY_CONTEXT_DEFAULT) + .build()), + ImmutableList.of(new Object[] {0L})); + } }