mirror of https://github.com/apache/druid.git
Replace CaseFilteredAggregatorRule with Calcite equivalent. (#9113)
AggregateCaseToFilterRule was added to Calcite in https://issues.apache.org/jira/browse/CALCITE-3144, and was originally copied from Druid's CaseFilteredAggregatorRule. So there isn't a good reason to keep using our version.
This commit is contained in:
parent
bdd0d0d8a5
commit
66657012bf
|
@ -30,6 +30,7 @@ import org.apache.calcite.plan.volcano.AbstractConverter;
|
||||||
import org.apache.calcite.rel.RelNode;
|
import org.apache.calcite.rel.RelNode;
|
||||||
import org.apache.calcite.rel.core.RelFactories;
|
import org.apache.calcite.rel.core.RelFactories;
|
||||||
import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider;
|
import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider;
|
||||||
|
import org.apache.calcite.rel.rules.AggregateCaseToFilterRule;
|
||||||
import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule;
|
import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule;
|
||||||
import org.apache.calcite.rel.rules.AggregateJoinTransposeRule;
|
import org.apache.calcite.rel.rules.AggregateJoinTransposeRule;
|
||||||
import org.apache.calcite.rel.rules.AggregateProjectMergeRule;
|
import org.apache.calcite.rel.rules.AggregateProjectMergeRule;
|
||||||
|
@ -71,7 +72,6 @@ import org.apache.calcite.tools.Program;
|
||||||
import org.apache.calcite.tools.Programs;
|
import org.apache.calcite.tools.Programs;
|
||||||
import org.apache.calcite.tools.RelBuilder;
|
import org.apache.calcite.tools.RelBuilder;
|
||||||
import org.apache.druid.sql.calcite.rel.QueryMaker;
|
import org.apache.druid.sql.calcite.rel.QueryMaker;
|
||||||
import org.apache.druid.sql.calcite.rule.CaseFilteredAggregatorRule;
|
|
||||||
import org.apache.druid.sql.calcite.rule.DruidRelToDruidRule;
|
import org.apache.druid.sql.calcite.rule.DruidRelToDruidRule;
|
||||||
import org.apache.druid.sql.calcite.rule.DruidRules;
|
import org.apache.druid.sql.calcite.rule.DruidRules;
|
||||||
import org.apache.druid.sql.calcite.rule.DruidSemiJoinRule;
|
import org.apache.druid.sql.calcite.rule.DruidSemiJoinRule;
|
||||||
|
@ -237,7 +237,7 @@ public class Rules
|
||||||
}
|
}
|
||||||
|
|
||||||
rules.add(SortCollapseRule.instance());
|
rules.add(SortCollapseRule.instance());
|
||||||
rules.add(CaseFilteredAggregatorRule.instance());
|
rules.add(AggregateCaseToFilterRule.INSTANCE);
|
||||||
rules.add(ProjectAggregatePruneUnusedCallRule.instance());
|
rules.add(ProjectAggregatePruneUnusedCallRule.instance());
|
||||||
|
|
||||||
return rules.build();
|
return rules.build();
|
||||||
|
|
|
@ -1,241 +0,0 @@
|
||||||
/*
|
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
* or more contributor license agreements. See the NOTICE file
|
|
||||||
* distributed with this work for additional information
|
|
||||||
* regarding copyright ownership. The ASF licenses this file
|
|
||||||
* to you under the Apache License, Version 2.0 (the
|
|
||||||
* "License"); you may not use this file except in compliance
|
|
||||||
* with the License. You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing,
|
|
||||||
* software distributed under the License is distributed on an
|
|
||||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
* KIND, either express or implied. See the License for the
|
|
||||||
* specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.apache.druid.sql.calcite.rule;
|
|
||||||
|
|
||||||
import com.google.common.collect.ImmutableList;
|
|
||||||
import com.google.common.collect.Iterables;
|
|
||||||
import org.apache.calcite.plan.RelOptRule;
|
|
||||||
import org.apache.calcite.plan.RelOptRuleCall;
|
|
||||||
import org.apache.calcite.rel.RelNode;
|
|
||||||
import org.apache.calcite.rel.core.Aggregate;
|
|
||||||
import org.apache.calcite.rel.core.AggregateCall;
|
|
||||||
import org.apache.calcite.rel.core.Project;
|
|
||||||
import org.apache.calcite.rel.type.RelDataType;
|
|
||||||
import org.apache.calcite.rel.type.RelDataTypeFactory;
|
|
||||||
import org.apache.calcite.rex.RexBuilder;
|
|
||||||
import org.apache.calcite.rex.RexCall;
|
|
||||||
import org.apache.calcite.rex.RexLiteral;
|
|
||||||
import org.apache.calcite.rex.RexNode;
|
|
||||||
import org.apache.calcite.sql.SqlKind;
|
|
||||||
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
|
|
||||||
import org.apache.calcite.sql.type.SqlTypeName;
|
|
||||||
import org.apache.calcite.tools.RelBuilder;
|
|
||||||
import org.apache.druid.sql.calcite.planner.Calcites;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Rule that converts CASE-style filtered aggregation into true filtered aggregations.
|
|
||||||
*/
|
|
||||||
public class CaseFilteredAggregatorRule extends RelOptRule
|
|
||||||
{
|
|
||||||
private static final CaseFilteredAggregatorRule INSTANCE = new CaseFilteredAggregatorRule();
|
|
||||||
|
|
||||||
private CaseFilteredAggregatorRule()
|
|
||||||
{
|
|
||||||
super(operand(Aggregate.class, operand(Project.class, any())));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static CaseFilteredAggregatorRule instance()
|
|
||||||
{
|
|
||||||
return INSTANCE;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean matches(final RelOptRuleCall call)
|
|
||||||
{
|
|
||||||
final Aggregate aggregate = call.rel(0);
|
|
||||||
final Project project = call.rel(1);
|
|
||||||
|
|
||||||
if (aggregate.indicator || aggregate.getGroupSets().size() != 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
|
|
||||||
if (isOneArgAggregateCall(aggregateCall)
|
|
||||||
&& isThreeArgCase(project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList())))) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onMatch(RelOptRuleCall call)
|
|
||||||
{
|
|
||||||
final Aggregate aggregate = call.rel(0);
|
|
||||||
final Project project = call.rel(1);
|
|
||||||
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
|
|
||||||
final List<AggregateCall> newCalls = new ArrayList<>(aggregate.getAggCallList().size());
|
|
||||||
final List<RexNode> newProjects = new ArrayList<>(project.getChildExps());
|
|
||||||
final List<RexNode> newCasts = new ArrayList<>(aggregate.getGroupCount() + aggregate.getAggCallList().size());
|
|
||||||
final RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory();
|
|
||||||
|
|
||||||
for (int fieldNumber : aggregate.getGroupSet()) {
|
|
||||||
newCasts.add(rexBuilder.makeInputRef(project.getChildExps().get(fieldNumber).getType(), fieldNumber));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
|
|
||||||
AggregateCall newCall = null;
|
|
||||||
|
|
||||||
if (isOneArgAggregateCall(aggregateCall)) {
|
|
||||||
final RexNode rexNode = project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList()));
|
|
||||||
|
|
||||||
if (isThreeArgCase(rexNode)) {
|
|
||||||
final RexCall caseCall = (RexCall) rexNode;
|
|
||||||
|
|
||||||
// If one arg is null and the other is not, reverse them and set "flip", which negates the filter.
|
|
||||||
final boolean flip = RexLiteral.isNullLiteral(caseCall.getOperands().get(1))
|
|
||||||
&& !RexLiteral.isNullLiteral(caseCall.getOperands().get(2));
|
|
||||||
final RexNode arg1 = caseCall.getOperands().get(flip ? 2 : 1);
|
|
||||||
final RexNode arg2 = caseCall.getOperands().get(flip ? 1 : 2);
|
|
||||||
|
|
||||||
// Operand 1: Filter
|
|
||||||
final RexNode filter;
|
|
||||||
final RelDataType booleanType = Calcites.createSqlType(typeFactory, SqlTypeName.BOOLEAN);
|
|
||||||
final RexNode filterFromCase = rexBuilder.makeCall(
|
|
||||||
booleanType,
|
|
||||||
flip ? SqlStdOperatorTable.IS_FALSE : SqlStdOperatorTable.IS_TRUE,
|
|
||||||
ImmutableList.of(caseCall.getOperands().get(0))
|
|
||||||
);
|
|
||||||
|
|
||||||
// Combine the CASE filter with an honest-to-goodness SQL FILTER, if the latter is present.
|
|
||||||
if (aggregateCall.filterArg >= 0) {
|
|
||||||
filter = rexBuilder.makeCall(
|
|
||||||
booleanType,
|
|
||||||
SqlStdOperatorTable.AND,
|
|
||||||
ImmutableList.of(project.getProjects().get(aggregateCall.filterArg), filterFromCase)
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
filter = filterFromCase;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (aggregateCall.isDistinct()) {
|
|
||||||
// Just one style supported:
|
|
||||||
// COUNT(DISTINCT CASE WHEN x = 'foo' THEN y END) => COUNT(DISTINCT y) FILTER(WHERE x = 'foo')
|
|
||||||
|
|
||||||
if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT && RexLiteral.isNullLiteral(arg2)) {
|
|
||||||
newProjects.add(arg1);
|
|
||||||
newProjects.add(filter);
|
|
||||||
newCall = AggregateCall.create(
|
|
||||||
SqlStdOperatorTable.COUNT,
|
|
||||||
true,
|
|
||||||
ImmutableList.of(newProjects.size() - 2),
|
|
||||||
newProjects.size() - 1,
|
|
||||||
aggregateCall.getType(),
|
|
||||||
aggregateCall.getName()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Four styles supported:
|
|
||||||
//
|
|
||||||
// A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null)
|
|
||||||
// A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM
|
|
||||||
// B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0); must be SUM
|
|
||||||
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null)
|
|
||||||
|
|
||||||
if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT
|
|
||||||
&& arg1.isA(SqlKind.LITERAL)
|
|
||||||
&& !RexLiteral.isNullLiteral(arg1)
|
|
||||||
&& RexLiteral.isNullLiteral(arg2)) {
|
|
||||||
// Case C
|
|
||||||
newProjects.add(filter);
|
|
||||||
newCall = AggregateCall.create(
|
|
||||||
SqlStdOperatorTable.COUNT,
|
|
||||||
false,
|
|
||||||
ImmutableList.of(),
|
|
||||||
newProjects.size() - 1,
|
|
||||||
aggregateCall.getType(),
|
|
||||||
aggregateCall.getName()
|
|
||||||
);
|
|
||||||
} else if (aggregateCall.getAggregation().getKind() == SqlKind.SUM
|
|
||||||
&& Calcites.isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1
|
|
||||||
&& Calcites.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
|
|
||||||
// Case B
|
|
||||||
newProjects.add(filter);
|
|
||||||
newCall = AggregateCall.create(
|
|
||||||
SqlStdOperatorTable.COUNT,
|
|
||||||
false,
|
|
||||||
ImmutableList.of(),
|
|
||||||
newProjects.size() - 1,
|
|
||||||
Calcites.createSqlType(typeFactory, SqlTypeName.BIGINT),
|
|
||||||
aggregateCall.getName()
|
|
||||||
);
|
|
||||||
} else if (RexLiteral.isNullLiteral(arg2) /* Case A1 */
|
|
||||||
|| (aggregateCall.getAggregation().getKind() == SqlKind.SUM
|
|
||||||
&& Calcites.isIntLiteral(arg2)
|
|
||||||
&& RexLiteral.intValue(arg2) == 0) /* Case A2 */) {
|
|
||||||
newProjects.add(arg1);
|
|
||||||
newProjects.add(filter);
|
|
||||||
newCall = AggregateCall.create(
|
|
||||||
aggregateCall.getAggregation(),
|
|
||||||
false,
|
|
||||||
ImmutableList.of(newProjects.size() - 2),
|
|
||||||
newProjects.size() - 1,
|
|
||||||
aggregateCall.getType(),
|
|
||||||
aggregateCall.getName()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
newCalls.add(newCall == null ? aggregateCall : newCall);
|
|
||||||
|
|
||||||
// Possibly CAST the new aggregator to an appropriate type.
|
|
||||||
final int i = newCasts.size();
|
|
||||||
final RelDataType oldType = aggregate.getRowType().getFieldList().get(i).getType();
|
|
||||||
if (newCall == null) {
|
|
||||||
newCasts.add(rexBuilder.makeInputRef(oldType, i));
|
|
||||||
} else {
|
|
||||||
newCasts.add(rexBuilder.makeCast(oldType, rexBuilder.makeInputRef(newCall.getType(), i)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!newCalls.equals(aggregate.getAggCallList())) {
|
|
||||||
final RelBuilder relBuilder = call
|
|
||||||
.builder()
|
|
||||||
.push(project.getInput())
|
|
||||||
.project(newProjects);
|
|
||||||
|
|
||||||
final RelBuilder.GroupKey groupKey = relBuilder.groupKey(
|
|
||||||
aggregate.getGroupSet(),
|
|
||||||
aggregate.getGroupSets()
|
|
||||||
);
|
|
||||||
|
|
||||||
final RelNode newAggregate = relBuilder.aggregate(groupKey, newCalls).project(newCasts).build();
|
|
||||||
|
|
||||||
call.transformTo(newAggregate);
|
|
||||||
call.getPlanner().setImportance(aggregate, 0.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static boolean isOneArgAggregateCall(final AggregateCall aggregateCall)
|
|
||||||
{
|
|
||||||
return aggregateCall.getArgList().size() == 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static boolean isThreeArgCase(final RexNode rexNode)
|
|
||||||
{
|
|
||||||
return rexNode.getKind() == SqlKind.CASE && ((RexCall) rexNode).getOperands().size() == 3;
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue