Replace CaseFilteredAggregatorRule with Calcite equivalent. (#9113)

AggregateCaseToFilterRule was added to Calcite in https://issues.apache.org/jira/browse/CALCITE-3144,
and was originally copied from Druid's CaseFilteredAggregatorRule. So there isn't a good reason to
keep using our version.
This commit is contained in:
Gian Merlino 2020-01-04 19:11:18 -08:00 committed by Fangjin Yang
parent bdd0d0d8a5
commit 66657012bf
2 changed files with 2 additions and 243 deletions

View File

@ -30,6 +30,7 @@ import org.apache.calcite.plan.volcano.AbstractConverter;
import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider; import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider;
import org.apache.calcite.rel.rules.AggregateCaseToFilterRule;
import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule; import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule;
import org.apache.calcite.rel.rules.AggregateJoinTransposeRule; import org.apache.calcite.rel.rules.AggregateJoinTransposeRule;
import org.apache.calcite.rel.rules.AggregateProjectMergeRule; import org.apache.calcite.rel.rules.AggregateProjectMergeRule;
@ -71,7 +72,6 @@ import org.apache.calcite.tools.Program;
import org.apache.calcite.tools.Programs; import org.apache.calcite.tools.Programs;
import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.tools.RelBuilder;
import org.apache.druid.sql.calcite.rel.QueryMaker; import org.apache.druid.sql.calcite.rel.QueryMaker;
import org.apache.druid.sql.calcite.rule.CaseFilteredAggregatorRule;
import org.apache.druid.sql.calcite.rule.DruidRelToDruidRule; import org.apache.druid.sql.calcite.rule.DruidRelToDruidRule;
import org.apache.druid.sql.calcite.rule.DruidRules; import org.apache.druid.sql.calcite.rule.DruidRules;
import org.apache.druid.sql.calcite.rule.DruidSemiJoinRule; import org.apache.druid.sql.calcite.rule.DruidSemiJoinRule;
@ -237,7 +237,7 @@ public class Rules
} }
rules.add(SortCollapseRule.instance()); rules.add(SortCollapseRule.instance());
rules.add(CaseFilteredAggregatorRule.instance()); rules.add(AggregateCaseToFilterRule.INSTANCE);
rules.add(ProjectAggregatePruneUnusedCallRule.instance()); rules.add(ProjectAggregatePruneUnusedCallRule.instance());
return rules.build(); return rules.build();

View File

@ -1,241 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.rule;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.druid.sql.calcite.planner.Calcites;
import java.util.ArrayList;
import java.util.List;
/**
* Rule that converts CASE-style filtered aggregation into true filtered aggregations.
*/
public class CaseFilteredAggregatorRule extends RelOptRule
{
private static final CaseFilteredAggregatorRule INSTANCE = new CaseFilteredAggregatorRule();
private CaseFilteredAggregatorRule()
{
super(operand(Aggregate.class, operand(Project.class, any())));
}
public static CaseFilteredAggregatorRule instance()
{
return INSTANCE;
}
@Override
public boolean matches(final RelOptRuleCall call)
{
final Aggregate aggregate = call.rel(0);
final Project project = call.rel(1);
if (aggregate.indicator || aggregate.getGroupSets().size() != 1) {
return false;
}
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
if (isOneArgAggregateCall(aggregateCall)
&& isThreeArgCase(project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList())))) {
return true;
}
}
return false;
}
@Override
public void onMatch(RelOptRuleCall call)
{
final Aggregate aggregate = call.rel(0);
final Project project = call.rel(1);
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
final List<AggregateCall> newCalls = new ArrayList<>(aggregate.getAggCallList().size());
final List<RexNode> newProjects = new ArrayList<>(project.getChildExps());
final List<RexNode> newCasts = new ArrayList<>(aggregate.getGroupCount() + aggregate.getAggCallList().size());
final RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory();
for (int fieldNumber : aggregate.getGroupSet()) {
newCasts.add(rexBuilder.makeInputRef(project.getChildExps().get(fieldNumber).getType(), fieldNumber));
}
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
AggregateCall newCall = null;
if (isOneArgAggregateCall(aggregateCall)) {
final RexNode rexNode = project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList()));
if (isThreeArgCase(rexNode)) {
final RexCall caseCall = (RexCall) rexNode;
// If one arg is null and the other is not, reverse them and set "flip", which negates the filter.
final boolean flip = RexLiteral.isNullLiteral(caseCall.getOperands().get(1))
&& !RexLiteral.isNullLiteral(caseCall.getOperands().get(2));
final RexNode arg1 = caseCall.getOperands().get(flip ? 2 : 1);
final RexNode arg2 = caseCall.getOperands().get(flip ? 1 : 2);
// Operand 1: Filter
final RexNode filter;
final RelDataType booleanType = Calcites.createSqlType(typeFactory, SqlTypeName.BOOLEAN);
final RexNode filterFromCase = rexBuilder.makeCall(
booleanType,
flip ? SqlStdOperatorTable.IS_FALSE : SqlStdOperatorTable.IS_TRUE,
ImmutableList.of(caseCall.getOperands().get(0))
);
// Combine the CASE filter with an honest-to-goodness SQL FILTER, if the latter is present.
if (aggregateCall.filterArg >= 0) {
filter = rexBuilder.makeCall(
booleanType,
SqlStdOperatorTable.AND,
ImmutableList.of(project.getProjects().get(aggregateCall.filterArg), filterFromCase)
);
} else {
filter = filterFromCase;
}
if (aggregateCall.isDistinct()) {
// Just one style supported:
// COUNT(DISTINCT CASE WHEN x = 'foo' THEN y END) => COUNT(DISTINCT y) FILTER(WHERE x = 'foo')
if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT && RexLiteral.isNullLiteral(arg2)) {
newProjects.add(arg1);
newProjects.add(filter);
newCall = AggregateCall.create(
SqlStdOperatorTable.COUNT,
true,
ImmutableList.of(newProjects.size() - 2),
newProjects.size() - 1,
aggregateCall.getType(),
aggregateCall.getName()
);
}
} else {
// Four styles supported:
//
// A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null)
// A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM
// B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0); must be SUM
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null)
if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT
&& arg1.isA(SqlKind.LITERAL)
&& !RexLiteral.isNullLiteral(arg1)
&& RexLiteral.isNullLiteral(arg2)) {
// Case C
newProjects.add(filter);
newCall = AggregateCall.create(
SqlStdOperatorTable.COUNT,
false,
ImmutableList.of(),
newProjects.size() - 1,
aggregateCall.getType(),
aggregateCall.getName()
);
} else if (aggregateCall.getAggregation().getKind() == SqlKind.SUM
&& Calcites.isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1
&& Calcites.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
// Case B
newProjects.add(filter);
newCall = AggregateCall.create(
SqlStdOperatorTable.COUNT,
false,
ImmutableList.of(),
newProjects.size() - 1,
Calcites.createSqlType(typeFactory, SqlTypeName.BIGINT),
aggregateCall.getName()
);
} else if (RexLiteral.isNullLiteral(arg2) /* Case A1 */
|| (aggregateCall.getAggregation().getKind() == SqlKind.SUM
&& Calcites.isIntLiteral(arg2)
&& RexLiteral.intValue(arg2) == 0) /* Case A2 */) {
newProjects.add(arg1);
newProjects.add(filter);
newCall = AggregateCall.create(
aggregateCall.getAggregation(),
false,
ImmutableList.of(newProjects.size() - 2),
newProjects.size() - 1,
aggregateCall.getType(),
aggregateCall.getName()
);
}
}
}
}
newCalls.add(newCall == null ? aggregateCall : newCall);
// Possibly CAST the new aggregator to an appropriate type.
final int i = newCasts.size();
final RelDataType oldType = aggregate.getRowType().getFieldList().get(i).getType();
if (newCall == null) {
newCasts.add(rexBuilder.makeInputRef(oldType, i));
} else {
newCasts.add(rexBuilder.makeCast(oldType, rexBuilder.makeInputRef(newCall.getType(), i)));
}
}
if (!newCalls.equals(aggregate.getAggCallList())) {
final RelBuilder relBuilder = call
.builder()
.push(project.getInput())
.project(newProjects);
final RelBuilder.GroupKey groupKey = relBuilder.groupKey(
aggregate.getGroupSet(),
aggregate.getGroupSets()
);
final RelNode newAggregate = relBuilder.aggregate(groupKey, newCalls).project(newCasts).build();
call.transformTo(newAggregate);
call.getPlanner().setImportance(aggregate, 0.0);
}
}
private static boolean isOneArgAggregateCall(final AggregateCall aggregateCall)
{
return aggregateCall.getArgList().size() == 1;
}
private static boolean isThreeArgCase(final RexNode rexNode)
{
return rexNode.getKind() == SqlKind.CASE && ((RexCall) rexNode).getOperands().size() == 3;
}
}