Restore usage of filtered SUM (#17378)

This commit is contained in:
Zoltan Haindrich 2024-12-12 10:30:42 +01:00 committed by GitHub
parent 05c3cbce08
commit 1a38434d8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 593 additions and 48 deletions

View File

@ -625,6 +625,15 @@ public class QueryContext
); );
} }
public boolean isExtendedFilteredSumRewrite()
{
return getBoolean(
QueryContexts.EXTENDED_FILTERED_SUM_REWRITE_ENABLED,
QueryContexts.DEFAULT_EXTENDED_FILTERED_SUM_REWRITE_ENABLED
);
}
public QueryResourceId getQueryResourceId() public QueryResourceId getQueryResourceId()
{ {
return new QueryResourceId(getString(QueryContexts.QUERY_RESOURCE_ID)); return new QueryResourceId(getString(QueryContexts.QUERY_RESOURCE_ID));

View File

@ -89,6 +89,22 @@ public class QueryContexts
public static final String UNCOVERED_INTERVALS_LIMIT_KEY = "uncoveredIntervalsLimit"; public static final String UNCOVERED_INTERVALS_LIMIT_KEY = "uncoveredIntervalsLimit";
public static final String MIN_TOP_N_THRESHOLD = "minTopNThreshold"; public static final String MIN_TOP_N_THRESHOLD = "minTopNThreshold";
public static final String CATALOG_VALIDATION_ENABLED = "catalogValidationEnabled"; public static final String CATALOG_VALIDATION_ENABLED = "catalogValidationEnabled";
/**
* Context parameter to enable/disable the extended filtered sum rewrite logic.
*
* Controls the rewrite of:
* <pre>
* SUM(CASE WHEN COND THEN COL1 ELSE 0 END)
* to
* SUM(COL1) FILTER (COND)
* </pre>
* managed by {@link DruidAggregateCaseToFilterRule}. Defaults to true for performance,
* but may produce incorrect results when the condition never matches (expected 0).
* This is for testing and can be removed once a correct and high-performance rewrite
* is implemented.
*/
public static final String EXTENDED_FILTERED_SUM_REWRITE_ENABLED = "extendedFilteredSumRewrite";
// projection context keys // projection context keys
public static final String NO_PROJECTIONS = "noProjections"; public static final String NO_PROJECTIONS = "noProjections";
@ -139,6 +155,7 @@ public class QueryContexts
public static final boolean DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING = false; public static final boolean DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING = false;
public static final boolean DEFAULT_CATALOG_VALIDATION_ENABLED = true; public static final boolean DEFAULT_CATALOG_VALIDATION_ENABLED = true;
public static final boolean DEFAULT_USE_NESTED_FOR_UNKNOWN_TYPE_IN_SUBQUERY = false; public static final boolean DEFAULT_USE_NESTED_FOR_UNKNOWN_TYPE_IN_SUBQUERY = false;
public static final boolean DEFAULT_EXTENDED_FILTERED_SUM_REWRITE_ENABLED = true;
@SuppressWarnings("unused") // Used by Jackson serialization @SuppressWarnings("unused") // Used by Jackson serialization

View File

@ -394,6 +394,17 @@ public class QueryContextTest
); );
} }
@Test
public void testExtendedFilteredSumRewrite()
{
assertTrue(QueryContext.empty().isExtendedFilteredSumRewrite());
assertFalse(
QueryContext
.of(ImmutableMap.of(QueryContexts.EXTENDED_FILTERED_SUM_REWRITE_ENABLED, false))
.isExtendedFilteredSumRewrite()
);
}
// This test is a bit silly. It is retained because another test uses the // This test is a bit silly. It is retained because another test uses the
// LegacyContextQuery test. // LegacyContextQuery test.
@Test @Test

View File

@ -54,6 +54,7 @@ import org.apache.druid.sql.calcite.external.ExternalTableScanRule;
import org.apache.druid.sql.calcite.rule.AggregatePullUpLookupRule; import org.apache.druid.sql.calcite.rule.AggregatePullUpLookupRule;
import org.apache.druid.sql.calcite.rule.CaseToCoalesceRule; import org.apache.druid.sql.calcite.rule.CaseToCoalesceRule;
import org.apache.druid.sql.calcite.rule.CoalesceLookupRule; import org.apache.druid.sql.calcite.rule.CoalesceLookupRule;
import org.apache.druid.sql.calcite.rule.DruidAggregateCaseToFilterRule;
import org.apache.druid.sql.calcite.rule.DruidLogicalValuesRule; import org.apache.druid.sql.calcite.rule.DruidLogicalValuesRule;
import org.apache.druid.sql.calcite.rule.DruidRelToDruidRule; import org.apache.druid.sql.calcite.rule.DruidRelToDruidRule;
import org.apache.druid.sql.calcite.rule.DruidRules; import org.apache.druid.sql.calcite.rule.DruidRules;
@ -119,7 +120,6 @@ public class CalciteRulesManager
CoreRules.FILTER_PROJECT_TRANSPOSE, CoreRules.FILTER_PROJECT_TRANSPOSE,
CoreRules.JOIN_PUSH_EXPRESSIONS, CoreRules.JOIN_PUSH_EXPRESSIONS,
CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT, CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT,
CoreRules.AGGREGATE_CASE_TO_FILTER,
CoreRules.FILTER_AGGREGATE_TRANSPOSE, CoreRules.FILTER_AGGREGATE_TRANSPOSE,
CoreRules.PROJECT_WINDOW_TRANSPOSE, CoreRules.PROJECT_WINDOW_TRANSPOSE,
CoreRules.MATCH, CoreRules.MATCH,
@ -495,6 +495,7 @@ public class CalciteRulesManager
rules.addAll(BASE_RULES); rules.addAll(BASE_RULES);
rules.addAll(ABSTRACT_RULES); rules.addAll(ABSTRACT_RULES);
rules.addAll(ABSTRACT_RELATIONAL_RULES); rules.addAll(ABSTRACT_RELATIONAL_RULES);
rules.add(new DruidAggregateCaseToFilterRule(plannerContext.queryContext().isExtendedFilteredSumRewrite()));
rules.addAll(configurableRuleSet(plannerContext)); rules.addAll(configurableRuleSet(plannerContext));
if (plannerContext.getJoinAlgorithm().requiresSubquery()) { if (plannerContext.getJoinAlgorithm().requiresSubquery()) {

View File

@ -0,0 +1,349 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.rule;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.rules.AggregateCaseToFilterRule;
import org.apache.calcite.rel.rules.SubstitutionRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlPostfixOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.checkerframework.checker.nullness.qual.Nullable;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
/**
* Druid extension of {@link AggregateCaseToFilterRule}.
*
* Turning on extendedFilteredSumRewrite enables rewrites of:
* <pre>
* SUM(CASE WHEN COND THEN COL1 ELSE 0 END)
* </pre>
* to:
* <pre>
* SUM(COL1) FILTER (WHERE COND)
* </pre>
* <p>
* This rewrite improves performance but introduces a known inconsistency when
* the condition never matches, as the expected result (0) is replaced with `null`.
* <p>
* Example behavior:
* <pre>
* +-----------------+--------------+----------+------+--------------+
* | input row count | cond matches | valueCol | orig | filtered-SUM |
* +-----------------+--------------+----------+------+--------------+
* | 0 | * | * | null | null |
* | >0 | none | * | 0 | null |
* | >0 | all | null | null | null |
* | >0 | N>0 | 1 | N | N |
* +-----------------+--------------+----------+------+--------------+
* </pre>
*/
public class DruidAggregateCaseToFilterRule extends RelOptRule implements SubstitutionRule
{
private boolean extendedFilteredSumRewrite;
public DruidAggregateCaseToFilterRule(boolean extendedFilteredSumRewrite)
{
super(operand(Aggregate.class, operand(Project.class, any())));
this.extendedFilteredSumRewrite = extendedFilteredSumRewrite;
}
@Override
public boolean matches(final RelOptRuleCall call)
{
final Aggregate aggregate = call.rel(0);
final Project project = call.rel(1);
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
final int singleArg = soleArgument(aggregateCall);
if (singleArg >= 0
&& isThreeArgCase(project.getProjects().get(singleArg))) {
return true;
}
}
return false;
}
@Override
public void onMatch(RelOptRuleCall call)
{
final Aggregate aggregate = call.rel(0);
final Project project = call.rel(1);
final List<AggregateCall> newCalls = new ArrayList<>(aggregate.getAggCallList().size());
final List<RexNode> newProjects = new ArrayList<>(project.getProjects());
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
AggregateCall newCall = transform(aggregateCall, project, newProjects);
if (newCall == null) {
newCalls.add(aggregateCall);
} else {
newCalls.add(newCall);
}
}
if (newCalls.equals(aggregate.getAggCallList())) {
return;
}
final RelBuilder relBuilder = call.builder()
.push(project.getInput())
.project(newProjects);
final RelBuilder.GroupKey groupKey = relBuilder.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets());
relBuilder.aggregate(groupKey, newCalls)
.convert(aggregate.getRowType(), false);
call.transformTo(relBuilder.build());
call.getPlanner().prune(aggregate);
}
private @Nullable AggregateCall transform(AggregateCall call,
Project project, List<RexNode> newProjects)
{
final int singleArg = soleArgument(call);
if (singleArg < 0) {
return null;
}
final RexNode rexNode = project.getProjects().get(singleArg);
if (!isThreeArgCase(rexNode)) {
return null;
}
final RelOptCluster cluster = project.getCluster();
final RexBuilder rexBuilder = cluster.getRexBuilder();
final RexCall caseCall = (RexCall) rexNode;
// If one arg is null and the other is not, reverse them and set "flip",
// which negates the filter.
final boolean flip = RexLiteral.isNullLiteral(caseCall.operands.get(1))
&& !RexLiteral.isNullLiteral(caseCall.operands.get(2));
final RexNode arg1 = caseCall.operands.get(flip ? 2 : 1);
final RexNode arg2 = caseCall.operands.get(flip ? 1 : 2);
// Operand 1: Filter
final SqlPostfixOperator op = flip ? SqlStdOperatorTable.IS_NOT_TRUE : SqlStdOperatorTable.IS_TRUE;
final RexNode filterFromCase = rexBuilder.makeCall(op, caseCall.operands.get(0));
// Combine the CASE filter with an honest-to-goodness SQL FILTER, if the
// latter is present.
final RexNode filter;
if (call.filterArg >= 0) {
filter = rexBuilder.makeCall(
SqlStdOperatorTable.AND,
project.getProjects().get(call.filterArg),
filterFromCase
);
} else {
filter = filterFromCase;
}
RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
final SqlKind kind = call.getAggregation().getKind();
if (call.isDistinct()) {
// Just one style supported:
// COUNT(DISTINCT CASE WHEN x = 'foo' THEN y END)
// =>
// COUNT(DISTINCT y) FILTER(WHERE x = 'foo')
if (kind == SqlKind.COUNT
&& RexLiteral.isNullLiteral(arg2)) {
newProjects.add(arg1);
newProjects.add(filter);
return AggregateCall.create(
SqlStdOperatorTable.COUNT,
true,
false,
false,
call.rexList,
ImmutableList.of(newProjects.size() - 2),
newProjects.size() - 1,
null,
RelCollations.EMPTY,
call.getType(),
call.getName()
);
}
return null;
}
// Four styles supported:
//
// A1: AGG(CASE WHEN x = 'foo' THEN expr END)
// => AGG(expr) FILTER (x = 'foo')
// A2: SUM0(CASE WHEN x = 'foo' THEN cnt ELSE 0 END)
// => SUM0(cnt) FILTER (x = 'foo')
// B: SUM0(CASE WHEN x = 'foo' THEN 1 ELSE 0 END)
// => COUNT() FILTER (x = 'foo')
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END)
// => COUNT() FILTER (x = 'foo')
if (kind == SqlKind.COUNT // Case C
&& arg1.isA(SqlKind.LITERAL)
&& !RexLiteral.isNullLiteral(arg1)
&& RexLiteral.isNullLiteral(arg2)) {
newProjects.add(filter);
return AggregateCall.create(
SqlStdOperatorTable.COUNT,
false,
false,
false,
call.rexList, ImmutableList.of(), newProjects.size() - 1, null,
RelCollations.EMPTY, call.getType(),
call.getName());
} else if (kind == SqlKind.SUM0 // Case B
&& isIntLiteral(arg1, BigDecimal.ONE)
&& isIntLiteral(arg2, BigDecimal.ZERO)) {
newProjects.add(filter);
final RelDataType dataType = typeFactory
.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), false);
return AggregateCall.create(
SqlStdOperatorTable.COUNT,
false,
false,
false,
call.rexList,
ImmutableList.of(),
newProjects.size() - 1,
null,
RelCollations.EMPTY,
dataType,
call.getName()
);
} else if ((RexLiteral.isNullLiteral(arg2) // Case A1
&& call.getAggregation().allowsFilter())
|| (kind == SqlKind.SUM0 // Case A2
&& isIntLiteral(arg2, BigDecimal.ZERO))) {
newProjects.add(arg1);
newProjects.add(filter);
return AggregateCall.create(
call.getAggregation(),
false,
false,
false,
call.rexList,
ImmutableList.of(newProjects.size() - 2),
newProjects.size() - 1,
null,
RelCollations.EMPTY,
call.getType(),
call.getName()
);
}
// Rewrites
// D1: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END)
// => SUM0(cnt) FILTER (x = 'foo')
// D2: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END)
// => COUNT() FILTER (x = 'foo')
//
// https://issues.apache.org/jira/browse/CALCITE-5953
// have restricted this rewrite as in case there are no rows it may not be equvivalent;
// however it may have some performance impact in Druid
if (extendedFilteredSumRewrite &&
kind == SqlKind.SUM && isIntLiteral(arg2, BigDecimal.ZERO)) {
if (isIntLiteral(arg1, BigDecimal.ONE)) { // D2
newProjects.add(filter);
final RelDataType dataType = typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.BIGINT), false
);
return AggregateCall.create(
SqlStdOperatorTable.COUNT,
false,
false,
false,
call.rexList,
ImmutableList.of(),
newProjects.size() - 1,
null,
RelCollations.EMPTY,
dataType,
call.getName()
);
} else { // D1
newProjects.add(arg1);
newProjects.add(filter);
RelDataType newType = typeFactory.createTypeWithNullability(call.getType(), true);
return AggregateCall.create(
call.getAggregation(),
false,
false,
false,
call.rexList,
ImmutableList.of(newProjects.size() - 2),
newProjects.size() - 1,
null,
RelCollations.EMPTY,
newType,
call.getName()
);
}
}
return null;
}
/**
* Returns the argument, if an aggregate call has a single argument, otherwise
* -1.
*/
private static int soleArgument(AggregateCall aggregateCall)
{
return aggregateCall.getArgList().size() == 1
? aggregateCall.getArgList().get(0)
: -1;
}
private static boolean isThreeArgCase(final RexNode rexNode)
{
return rexNode.getKind() == SqlKind.CASE
&& ((RexCall) rexNode).operands.size() == 3;
}
private static boolean isIntLiteral(RexNode rexNode, BigDecimal value)
{
return rexNode instanceof RexLiteral
&& SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName())
&& value.equals(((RexLiteral) rexNode).getValueAs(BigDecimal.class));
}
}

View File

@ -5188,7 +5188,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test @Test
public void testFilteredAggregations() public void testFilteredAggregations()
{ {
cannotVectorizeUnlessFallback();
Druids.TimeseriesQueryBuilder builder = Druids.TimeseriesQueryBuilder builder =
Druids.newTimeseriesQueryBuilder() Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1) .dataSource(CalciteTests.DATASOURCE1)
@ -5196,18 +5195,9 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.granularity(Granularities.ALL) .granularity(Granularities.ALL)
.context(QUERY_CONTEXT_DEFAULT); .context(QUERY_CONTEXT_DEFAULT);
if (NullHandling.sqlCompatible()) { if (NullHandling.sqlCompatible()) {
cannotVectorizeUnlessFallback();
builder = builder.virtualColumns( builder = builder.virtualColumns(
expressionVirtualColumn("v0", "substring(\"dim1\", 0, 1)", ColumnType.STRING), expressionVirtualColumn("v0", "substring(\"dim1\", 0, 1)", ColumnType.STRING)
expressionVirtualColumn(
"v1",
"case_searched((\"dim1\" != '1'),1,0)",
ColumnType.LONG
),
expressionVirtualColumn(
"v2",
"case_searched((\"dim1\" != '1'),\"cnt\",0)",
ColumnType.LONG
)
) )
.aggregators( .aggregators(
aggregators( aggregators(
@ -5234,7 +5224,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new CountAggregatorFactory("a4"), new CountAggregatorFactory("a4"),
not(equality("dim1", "1", ColumnType.STRING)) not(equality("dim1", "1", ColumnType.STRING))
), ),
new LongSumAggregatorFactory("a5", "v1"), new FilteredAggregatorFactory(
new CountAggregatorFactory("a5"),
not(equality("dim1", "1", ColumnType.STRING))
),
new FilteredAggregatorFactory( new FilteredAggregatorFactory(
new LongSumAggregatorFactory("a6", "cnt"), new LongSumAggregatorFactory("a6", "cnt"),
equality("dim2", "a", ColumnType.STRING) equality("dim2", "a", ColumnType.STRING)
@ -5246,7 +5239,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
not(equality("dim1", "1", ColumnType.STRING)) not(equality("dim1", "1", ColumnType.STRING))
) )
), ),
new LongSumAggregatorFactory("a8", "v2"), new FilteredAggregatorFactory(
new LongSumAggregatorFactory("a8", "cnt"),
not(equality("dim1", "1", ColumnType.STRING))
),
new FilteredAggregatorFactory( new FilteredAggregatorFactory(
new LongMaxAggregatorFactory("a9", "cnt"), new LongMaxAggregatorFactory("a9", "cnt"),
not(equality("dim1", "1", ColumnType.STRING)) not(equality("dim1", "1", ColumnType.STRING))
@ -5272,16 +5268,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
); );
} else { } else {
builder = builder.virtualColumns( builder = builder.virtualColumns(
expressionVirtualColumn( )
"v0",
"case_searched((\"dim1\" != '1'),1,0)",
ColumnType.LONG
),
expressionVirtualColumn(
"v1",
"case_searched((\"dim1\" != '1'),\"cnt\",0)",
ColumnType.LONG
))
.aggregators( .aggregators(
aggregators( aggregators(
new FilteredAggregatorFactory( new FilteredAggregatorFactory(
@ -5307,7 +5294,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new CountAggregatorFactory("a4"), new CountAggregatorFactory("a4"),
not(equality("dim1", "1", ColumnType.STRING)) not(equality("dim1", "1", ColumnType.STRING))
), ),
new LongSumAggregatorFactory("a5", "v0"), new FilteredAggregatorFactory(
new CountAggregatorFactory("a5"),
not(equality("dim1", "1", ColumnType.STRING))
),
new FilteredAggregatorFactory( new FilteredAggregatorFactory(
new LongSumAggregatorFactory("a6", "cnt"), new LongSumAggregatorFactory("a6", "cnt"),
equality("dim2", "a", ColumnType.STRING) equality("dim2", "a", ColumnType.STRING)
@ -5319,7 +5309,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
not(equality("dim1", "1", ColumnType.STRING)) not(equality("dim1", "1", ColumnType.STRING))
) )
), ),
new LongSumAggregatorFactory("a8", "v1"), new FilteredAggregatorFactory(
new LongSumAggregatorFactory("a8", "cnt"),
not(equality("dim1", "1", ColumnType.STRING))
),
new FilteredAggregatorFactory( new FilteredAggregatorFactory(
new LongMaxAggregatorFactory("a9", "cnt"), new LongMaxAggregatorFactory("a9", "cnt"),
not(equality("dim1", "1", ColumnType.STRING)) not(equality("dim1", "1", ColumnType.STRING))
@ -5373,7 +5366,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test @Test
public void testCaseFilteredAggregationWithGroupBy() public void testCaseFilteredAggregationWithGroupBy()
{ {
cannotVectorizeUnlessFallback();
testQuery( testQuery(
"SELECT\n" "SELECT\n"
+ " cnt,\n" + " cnt,\n"
@ -5386,15 +5378,11 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setInterval(querySegmentSpec(Filtration.eternity())) .setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL) .setGranularity(Granularities.ALL)
.setDimensions(dimensions(new DefaultDimensionSpec("cnt", "d0", ColumnType.LONG))) .setDimensions(dimensions(new DefaultDimensionSpec("cnt", "d0", ColumnType.LONG)))
.setVirtualColumns(
expressionVirtualColumn(
"v0",
"case_searched((\"dim1\" != '1'),1,0)",
ColumnType.LONG
)
)
.setAggregatorSpecs(aggregators( .setAggregatorSpecs(aggregators(
new LongSumAggregatorFactory("a0", "v0"), new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
not(equality("dim1", "1", ColumnType.STRING))
),
new LongSumAggregatorFactory("a1", "cnt") new LongSumAggregatorFactory("a1", "cnt")
)) ))
.setPostAggregatorSpecs( .setPostAggregatorSpecs(
@ -5409,6 +5397,52 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
); );
} }
@Test
public void testCaseFilteredAggregationWithGroupRewriteToSum()
{
testBuilder()
.sql(
"SELECT\n"
+ " cnt,\n"
+ " SUM(CASE WHEN dim1 <> '1' THEN 2 ELSE 0 END) + SUM(cnt)\n"
+ "FROM druid.foo\n"
+ "GROUP BY cnt"
)
.expectedQueries(
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(dimensions(new DefaultDimensionSpec("cnt", "d0", ColumnType.LONG)))
.setVirtualColumns(
expressionVirtualColumn("v0", "2", ColumnType.LONG)
)
.setAggregatorSpecs(
aggregators(
new FilteredAggregatorFactory(
new LongSumAggregatorFactory("a0", "v0"),
not(equality("dim1", "1", ColumnType.STRING))
),
new LongSumAggregatorFactory("a1", "cnt")
)
)
.setPostAggregatorSpecs(
expressionPostAgg("p0", "(\"a0\" + \"a1\")", ColumnType.LONG)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
)
)
.expectedResults(
ImmutableList.of(
new Object[] {1L, 16L}
)
)
.run();
}
@Test @Test
public void testFilteredAggregationWithNotIn() public void testFilteredAggregationWithNotIn()
{ {
@ -9479,7 +9513,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test @Test
public void testQueryWithSelectProjectAndIdentityProjectDoesNotRename() public void testQueryWithSelectProjectAndIdentityProjectDoesNotRename()
{ {
cannotVectorizeUnlessFallback();
msqIncompatible(); msqIncompatible();
testQuery( testQuery(
PLANNER_CONFIG_NO_HLL.withOverrides( PLANNER_CONFIG_NO_HLL.withOverrides(
@ -9506,25 +9539,30 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
"v0", "v0",
"((\"__time\" >= 947005200000) && (\"__time\" < 1641402000000))", "((\"__time\" >= 947005200000) && (\"__time\" < 1641402000000))",
ColumnType.LONG ColumnType.LONG
),
expressionVirtualColumn(
"v1",
"case_searched(((\"__time\" >= 947005200000) && (\"__time\" < 1641402000000)),1,0)",
ColumnType.LONG
) )
) )
.setDimensions( .setDimensions(
dimensions( dimensions(
new DefaultDimensionSpec("dim1", "d0", ColumnType.STRING), new DefaultDimensionSpec("v0", "d0", ColumnType.LONG),
new DefaultDimensionSpec("v0", "d1", ColumnType.LONG) new DefaultDimensionSpec("dim1", "d1", ColumnType.STRING)
) )
) )
.setAggregatorSpecs( .setAggregatorSpecs(
aggregators( aggregators(
new LongSumAggregatorFactory("a0", "v1"), new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
range(
"__time",
ColumnType.LONG,
timestamp("2000-01-04T17:00:00"),
timestamp("2022-01-05T17:00:00"),
false,
true
)
),
new GroupingAggregatorFactory( new GroupingAggregatorFactory(
"a1", "a1",
ImmutableList.of("dim1", "v0") ImmutableList.of("v0", "dim1")
) )
) )
) )
@ -9549,9 +9587,9 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new FilteredAggregatorFactory( new FilteredAggregatorFactory(
new CountAggregatorFactory("_a1"), new CountAggregatorFactory("_a1"),
and( and(
notNull("d0"), notNull("d1"),
equality("a1", 0L, ColumnType.LONG), equality("a1", 0L, ColumnType.LONG),
expressionFilter("\"d1\"") expressionFilter("\"d0\"")
) )
) )
) )

View File

@ -0,0 +1,120 @@
!use druidtest://?numMergeBuffers=3
!set outputformat mysql
-- empty input
SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is null,
SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1) FROM numfoo where l1 < -1;
+--------+--------+--------+--------+--------+
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
+--------+--------+--------+--------+--------+
| 0 | 0 | true | | |
+--------+--------+--------+--------+--------+
(1 row)
!ok
-- 0=-1,0
SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is null,
SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1) FROM numfoo where l1 < 3;
+--------+--------+--------+--------+--------+
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
+--------+--------+--------+--------+--------+
| 0 | 1 | false | | |
+--------+--------+--------+--------+--------+
(1 row)
!ok
-- 0=0,0
SELECT COUNT(1)FILTER(WHERE l1=0),COUNT(1)FILTER(WHERE l1!=0),MIN(l2) is null,
SUM(CASE WHEN l1 = 0 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=0) FROM numfoo where l1 < 3;
+--------+--------+--------+--------+--------+
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
+--------+--------+--------+--------+--------+
| 1 | 0 | false | 0 | 0 |
+--------+--------+--------+--------+--------+
(1 row)
!ok
-- 7=7,null
SELECT COUNT(1)FILTER(WHERE l1=7),COUNT(1)FILTER(WHERE l1!=7),MIN(l2) is null,
SUM(CASE WHEN l1 = 7 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=7) FROM numfoo where 0 < l1 and l1 < 10;
+--------+--------+--------+--------+--------+
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
+--------+--------+--------+--------+--------+
| 1 | 0 | true | | |
+--------+--------+--------+--------+--------+
(1 row)
!ok
LogicalProject(EXPR$0=[$0], EXPR$1=[$1], EXPR$2=[IS NULL($2)], EXPR$3=[$3], EXPR$4=[$4])
LogicalAggregate(group=[{}], EXPR$0=[COUNT() FILTER $0], EXPR$1=[COUNT() FILTER $1], agg#2=[MIN($2)], EXPR$3=[SUM($3)], EXPR$4=[SUM($2) FILTER $0])
LogicalProject($f1=[IS TRUE(=($0, 7))], $f2=[IS TRUE(<>($0, 7))], l2=[$1], $f4=[CASE(=($0, 7), $1, 0:BIGINT)])
LogicalFilter(condition=[SEARCH($0, Sarg[(0..10)])])
LogicalProject(l1=[$11], l2=[$12])
LogicalTableScan(table=[[druid, numfoo]])
!druidPlan
!set extendedFilteredSumRewrite false
!use druidtest://?numMergeBuffers=3
-- empty input
SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is null,
SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1) FROM numfoo where l1 < -1;
+--------+--------+--------+--------+--------+
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
+--------+--------+--------+--------+--------+
| 0 | 0 | true | | |
+--------+--------+--------+--------+--------+
(1 row)
!ok
-- 0=-1,0
SELECT COUNT(1)FILTER(WHERE l1=-1),COUNT(1)FILTER(WHERE l1!=-1),MIN(l2) is null,
SUM(CASE WHEN l1 = -1 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=-1) FROM numfoo where l1 < 3;
+--------+--------+--------+--------+--------+
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
+--------+--------+--------+--------+--------+
| 0 | 1 | false | 0 | |
+--------+--------+--------+--------+--------+
(1 row)
!ok
-- 0=0,0
SELECT COUNT(1)FILTER(WHERE l1=0),COUNT(1)FILTER(WHERE l1!=0),MIN(l2) is null,
SUM(CASE WHEN l1 = 0 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=0) FROM numfoo where l1 < 3;
+--------+--------+--------+--------+--------+
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
+--------+--------+--------+--------+--------+
| 1 | 0 | false | 0 | 0 |
+--------+--------+--------+--------+--------+
(1 row)
!ok
-- 7=7,null
SELECT COUNT(1)FILTER(WHERE l1=7),COUNT(1)FILTER(WHERE l1!=7),MIN(l2) is null,
SUM(CASE WHEN l1 = 7 THEN l2 ELSE 0 END),SUM(l2) FILTER(WHERE l1=7) FROM numfoo where 0 < l1 and l1 < 10;
+--------+--------+--------+--------+--------+
| EXPR$0 | EXPR$1 | EXPR$2 | EXPR$3 | EXPR$4 |
+--------+--------+--------+--------+--------+
| 1 | 0 | true | | |
+--------+--------+--------+--------+--------+
(1 row)
!ok
LogicalProject(EXPR$0=[$0], EXPR$1=[$1], EXPR$2=[IS NULL($2)], EXPR$3=[$3], EXPR$4=[$4])
LogicalAggregate(group=[{}], EXPR$0=[COUNT() FILTER $0], EXPR$1=[COUNT() FILTER $1], agg#2=[MIN($2)], EXPR$3=[SUM($3)], EXPR$4=[SUM($2) FILTER $0])
LogicalProject($f1=[IS TRUE(=($0, 7))], $f2=[IS TRUE(<>($0, 7))], l2=[$1], $f4=[CASE(=($0, 7), $1, 0:BIGINT)])
LogicalFilter(condition=[SEARCH($0, Sarg[(0..10)])])
LogicalProject(l1=[$11], l2=[$12])
LogicalTableScan(table=[[druid, numfoo]])
!druidPlan