mirror of
https://github.com/apache/druid.git
synced 2025-02-17 07:25:02 +00:00
Remove DruidAggregateCaseToFilterRule (#14940)
The issue that made this custom rule necessary has been fixed upstream in https://issues.apache.org/jira/browse/CALCITE-3763, and that fix was picked up as part of the Calcite upgrade.
This commit is contained in:
parent
6ee0b06e38
commit
23308c050d
@ -56,7 +56,6 @@ import org.apache.druid.sql.calcite.rule.ExtensionCalciteRuleProvider;
|
|||||||
import org.apache.druid.sql.calcite.rule.FilterJoinExcludePushToChildRule;
|
import org.apache.druid.sql.calcite.rule.FilterJoinExcludePushToChildRule;
|
||||||
import org.apache.druid.sql.calcite.rule.ProjectAggregatePruneUnusedCallRule;
|
import org.apache.druid.sql.calcite.rule.ProjectAggregatePruneUnusedCallRule;
|
||||||
import org.apache.druid.sql.calcite.rule.SortCollapseRule;
|
import org.apache.druid.sql.calcite.rule.SortCollapseRule;
|
||||||
import org.apache.druid.sql.calcite.rule.logical.DruidAggregateCaseToFilterRule;
|
|
||||||
import org.apache.druid.sql.calcite.rule.logical.DruidLogicalRules;
|
import org.apache.druid.sql.calcite.rule.logical.DruidLogicalRules;
|
||||||
import org.apache.druid.sql.calcite.run.EngineFeature;
|
import org.apache.druid.sql.calcite.run.EngineFeature;
|
||||||
|
|
||||||
@ -318,7 +317,6 @@ public class CalciteRulesManager
|
|||||||
final ImmutableList.Builder<RelOptRule> retVal = ImmutableList
|
final ImmutableList.Builder<RelOptRule> retVal = ImmutableList
|
||||||
.<RelOptRule>builder()
|
.<RelOptRule>builder()
|
||||||
.addAll(baseRuleSet(plannerContext))
|
.addAll(baseRuleSet(plannerContext))
|
||||||
.add(DruidAggregateCaseToFilterRule.INSTANCE)
|
|
||||||
.add(new DruidLogicalRules(plannerContext).rules().toArray(new RelOptRule[0]));
|
.add(new DruidLogicalRules(plannerContext).rules().toArray(new RelOptRule[0]));
|
||||||
return retVal.build();
|
return retVal.build();
|
||||||
}
|
}
|
||||||
|
@ -1,339 +0,0 @@
|
|||||||
/*
|
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
* or more contributor license agreements. See the NOTICE file
|
|
||||||
* distributed with this work for additional information
|
|
||||||
* regarding copyright ownership. The ASF licenses this file
|
|
||||||
* to you under the Apache License, Version 2.0 (the
|
|
||||||
* "License"); you may not use this file except in compliance
|
|
||||||
* with the License. You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing,
|
|
||||||
* software distributed under the License is distributed on an
|
|
||||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
* KIND, either express or implied. See the License for the
|
|
||||||
* specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.apache.druid.sql.calcite.rule.logical;
|
|
||||||
|
|
||||||
import com.google.common.collect.ImmutableList;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlPostfixOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A copy of {@link org.apache.calcite.rel.rules.AggregateCaseToFilterRule} except that it fixes a bug to eliminate
|
|
||||||
* left-over projects for converted aggregates to filter-aggregates. The elimination of left-over projects is necessary
|
|
||||||
* with the new planning since it determines the cost of the plan and hence determines which plan is going to get picked
|
|
||||||
* as the cheapest one.
|
|
||||||
* This fix will also be contributed upstream to Calcite project, and we can remove this rule once the fix is a part of
|
|
||||||
* the Calcite version we use.
|
|
||||||
*/
|
|
||||||
public class DruidAggregateCaseToFilterRule extends RelOptRule
|
|
||||||
{
|
|
||||||
// Shared rule instance, constructed with the logical RelBuilder factory and a null description.
public static final DruidAggregateCaseToFilterRule INSTANCE =
    new DruidAggregateCaseToFilterRule(RelFactories.LOGICAL_BUILDER, null);
|
|
||||||
|
|
||||||
/**
 * Creates a DruidAggregateCaseToFilterRule that matches an {@link Aggregate} whose input is a {@link Project}.
 *
 * @param relBuilderFactory builder factory passed through to the {@link RelOptRule} constructor
 * @param description       rule description passed through to the {@link RelOptRule} constructor; the shared
 *                          {@code INSTANCE} passes {@code null}
 */
protected DruidAggregateCaseToFilterRule(
    RelBuilderFactory relBuilderFactory,
    String description
)
{
  // Operand tree: Aggregate (rel 0) on top of a Project (rel 1) with any inputs.
  super(operand(Aggregate.class, operand(Project.class, any())),
        relBuilderFactory, description
  );
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean matches(final RelOptRuleCall call)
|
|
||||||
{
|
|
||||||
final Aggregate aggregate = call.rel(0);
|
|
||||||
final Project project = call.rel(1);
|
|
||||||
|
|
||||||
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
|
|
||||||
final int singleArg = soleArgument(aggregateCall);
|
|
||||||
if (singleArg >= 0
|
|
||||||
&& isThreeArgCase(project.getProjects().get(singleArg))) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onMatch(RelOptRuleCall call)
|
|
||||||
{
|
|
||||||
final Aggregate aggregate = call.rel(0);
|
|
||||||
final Project project = call.rel(1);
|
|
||||||
final List<AggregateCall> newCalls =
|
|
||||||
new ArrayList<>(aggregate.getAggCallList().size());
|
|
||||||
List<RexNode> newProjects;
|
|
||||||
|
|
||||||
// TODO : fix grouping columns
|
|
||||||
Set<Integer> groupUsedFields = new HashSet<>();
|
|
||||||
for (int fieldNumber : aggregate.getGroupSet()) {
|
|
||||||
groupUsedFields.add(fieldNumber);
|
|
||||||
}
|
|
||||||
|
|
||||||
List<RexNode> updatedProjects = new ArrayList<>();
|
|
||||||
for (int i = 0; i < project.getProjects().size(); i++) {
|
|
||||||
if (groupUsedFields.contains(i)) {
|
|
||||||
updatedProjects.add(project.getProjects().get(i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
newProjects = updatedProjects;
|
|
||||||
|
|
||||||
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
|
|
||||||
AggregateCall newCall =
|
|
||||||
transform(aggregateCall, project, newProjects);
|
|
||||||
|
|
||||||
// Possibly CAST the new aggregator to an appropriate type.
|
|
||||||
newCalls.add(newCall);
|
|
||||||
}
|
|
||||||
final RelBuilder relBuilder = call.builder()
|
|
||||||
.push(project.getInput())
|
|
||||||
.project(newProjects);
|
|
||||||
|
|
||||||
final RelBuilder.GroupKey groupKey =
|
|
||||||
relBuilder.groupKey(
|
|
||||||
aggregate.getGroupSet(),
|
|
||||||
aggregate.getGroupSets()
|
|
||||||
);
|
|
||||||
|
|
||||||
relBuilder.aggregate(groupKey, newCalls)
|
|
||||||
.convert(aggregate.getRowType(), false);
|
|
||||||
|
|
||||||
call.transformTo(relBuilder.build());
|
|
||||||
call.getPlanner().prune(aggregate);
|
|
||||||
}
|
|
||||||
|
|
||||||
private AggregateCall transform(AggregateCall aggregateCall, Project project, List<RexNode> newProjects)
|
|
||||||
{
|
|
||||||
final int singleArg = soleArgument(aggregateCall);
|
|
||||||
if (singleArg < 0) {
|
|
||||||
Set<Integer> newFields = new HashSet<>();
|
|
||||||
for (int fieldNumber : aggregateCall.getArgList()) {
|
|
||||||
newProjects.add(project.getProjects().get(fieldNumber));
|
|
||||||
newFields.add(newProjects.size() - 1);
|
|
||||||
}
|
|
||||||
int newFilterArg = -1;
|
|
||||||
if (aggregateCall.hasFilter()) {
|
|
||||||
newProjects.add(project.getProjects().get(aggregateCall.filterArg));
|
|
||||||
newFilterArg = newProjects.size() - 1;
|
|
||||||
}
|
|
||||||
return AggregateCall.create(aggregateCall.getAggregation(),
|
|
||||||
aggregateCall.isDistinct(),
|
|
||||||
aggregateCall.isApproximate(),
|
|
||||||
aggregateCall.ignoreNulls(),
|
|
||||||
new ArrayList<>(newFields),
|
|
||||||
newFilterArg,
|
|
||||||
aggregateCall.getCollation(),
|
|
||||||
aggregateCall.getType(),
|
|
||||||
aggregateCall.getName()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
final RexNode rexNode = project.getProjects().get(singleArg);
|
|
||||||
if (!isThreeArgCase(rexNode)) {
|
|
||||||
newProjects.add(rexNode);
|
|
||||||
int callArg = newProjects.size() - 1;
|
|
||||||
int newFilterArg = -1;
|
|
||||||
if (aggregateCall.hasFilter()) {
|
|
||||||
newProjects.add(project.getProjects().get(aggregateCall.filterArg));
|
|
||||||
newFilterArg = newProjects.size() - 1;
|
|
||||||
}
|
|
||||||
return AggregateCall.create(aggregateCall.getAggregation(),
|
|
||||||
aggregateCall.isDistinct(),
|
|
||||||
aggregateCall.isApproximate(),
|
|
||||||
aggregateCall.ignoreNulls(),
|
|
||||||
ImmutableList.of(callArg),
|
|
||||||
newFilterArg,
|
|
||||||
aggregateCall.getCollation(),
|
|
||||||
aggregateCall.getType(),
|
|
||||||
aggregateCall.getName()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
final RelOptCluster cluster = project.getCluster();
|
|
||||||
final RexBuilder rexBuilder = cluster.getRexBuilder();
|
|
||||||
final RexCall caseCall = (RexCall) rexNode;
|
|
||||||
|
|
||||||
// If one arg is null and the other is not, reverse them and set "flip",
|
|
||||||
// which negates the filter.
|
|
||||||
final boolean flip = RexLiteral.isNullLiteral(caseCall.operands.get(1))
|
|
||||||
&& !RexLiteral.isNullLiteral(caseCall.operands.get(2));
|
|
||||||
final RexNode arg1 = caseCall.operands.get(flip ? 2 : 1);
|
|
||||||
final RexNode arg2 = caseCall.operands.get(flip ? 1 : 2);
|
|
||||||
|
|
||||||
// Operand 1: Filter
|
|
||||||
final SqlPostfixOperator op =
|
|
||||||
flip ? SqlStdOperatorTable.IS_FALSE : SqlStdOperatorTable.IS_TRUE;
|
|
||||||
final RexNode filterFromCase =
|
|
||||||
rexBuilder.makeCall(op, caseCall.operands.get(0));
|
|
||||||
|
|
||||||
// Combine the CASE filter with an honest-to-goodness SQL FILTER, if the
|
|
||||||
// latter is present.
|
|
||||||
final RexNode filter;
|
|
||||||
if (aggregateCall.filterArg >= 0) {
|
|
||||||
filter = rexBuilder.makeCall(SqlStdOperatorTable.AND,
|
|
||||||
project.getProjects().get(aggregateCall.filterArg), filterFromCase
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
filter = filterFromCase;
|
|
||||||
}
|
|
||||||
|
|
||||||
final SqlKind kind = aggregateCall.getAggregation().getKind();
|
|
||||||
if (aggregateCall.isDistinct()) {
|
|
||||||
// Just one style supported:
|
|
||||||
// COUNT(DISTINCT CASE WHEN x = 'foo' THEN y END)
|
|
||||||
// =>
|
|
||||||
// COUNT(DISTINCT y) FILTER(WHERE x = 'foo')
|
|
||||||
|
|
||||||
if (kind == SqlKind.COUNT
|
|
||||||
&& RexLiteral.isNullLiteral(arg2)) {
|
|
||||||
newProjects.add(arg1);
|
|
||||||
newProjects.add(filter);
|
|
||||||
return AggregateCall.create(SqlStdOperatorTable.COUNT, true, false,
|
|
||||||
false, ImmutableList.of(newProjects.size() - 2),
|
|
||||||
newProjects.size() - 1, RelCollations.EMPTY,
|
|
||||||
aggregateCall.getType(), aggregateCall.getName()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
newProjects.add(rexNode);
|
|
||||||
int callArg = newProjects.size() - 1;
|
|
||||||
int newFilterArg = -1;
|
|
||||||
if (aggregateCall.hasFilter()) {
|
|
||||||
newProjects.add(project.getProjects().get(aggregateCall.filterArg));
|
|
||||||
newFilterArg = newProjects.size() - 1;
|
|
||||||
}
|
|
||||||
return AggregateCall.create(aggregateCall.getAggregation(),
|
|
||||||
aggregateCall.isDistinct(),
|
|
||||||
aggregateCall.isApproximate(),
|
|
||||||
aggregateCall.ignoreNulls(),
|
|
||||||
ImmutableList.of(callArg),
|
|
||||||
newFilterArg,
|
|
||||||
aggregateCall.getCollation(),
|
|
||||||
aggregateCall.getType(),
|
|
||||||
aggregateCall.getName()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Four styles supported:
|
|
||||||
//
|
|
||||||
// A1: AGG(CASE WHEN x = 'foo' THEN cnt END)
|
|
||||||
// => operands (x = 'foo', cnt, null)
|
|
||||||
// A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END)
|
|
||||||
// => operands (x = 'foo', cnt, 0); must be SUM
|
|
||||||
// B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END)
|
|
||||||
// => operands (x = 'foo', 1, 0); must be SUM
|
|
||||||
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END)
|
|
||||||
// => operands (x = 'foo', 'dummy', null)
|
|
||||||
|
|
||||||
if (kind == SqlKind.COUNT // Case C
|
|
||||||
&& arg1.isA(SqlKind.LITERAL)
|
|
||||||
&& !RexLiteral.isNullLiteral(arg1)
|
|
||||||
&& RexLiteral.isNullLiteral(arg2)) {
|
|
||||||
newProjects.add(filter);
|
|
||||||
return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false,
|
|
||||||
false, ImmutableList.of(), newProjects.size() - 1,
|
|
||||||
RelCollations.EMPTY, aggregateCall.getType(),
|
|
||||||
aggregateCall.getName()
|
|
||||||
);
|
|
||||||
} else if (kind == SqlKind.SUM // Case B
|
|
||||||
&& isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1
|
|
||||||
&& isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
|
|
||||||
|
|
||||||
newProjects.add(filter);
|
|
||||||
final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
|
|
||||||
final RelDataType dataType =
|
|
||||||
typeFactory.createTypeWithNullability(
|
|
||||||
typeFactory.createSqlType(SqlTypeName.BIGINT), false);
|
|
||||||
return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false,
|
|
||||||
false, ImmutableList.of(), newProjects.size() - 1,
|
|
||||||
RelCollations.EMPTY, dataType, aggregateCall.getName()
|
|
||||||
);
|
|
||||||
} else if ((RexLiteral.isNullLiteral(arg2) // Case A1
|
|
||||||
&& aggregateCall.getAggregation().allowsFilter())
|
|
||||||
|| (kind == SqlKind.SUM // Case A2
|
|
||||||
&& isIntLiteral(arg2)
|
|
||||||
&& RexLiteral.intValue(arg2) == 0)) {
|
|
||||||
newProjects.add(arg1);
|
|
||||||
newProjects.add(filter);
|
|
||||||
return AggregateCall.create(aggregateCall.getAggregation(), false,
|
|
||||||
false, false, ImmutableList.of(newProjects.size() - 2),
|
|
||||||
newProjects.size() - 1, RelCollations.EMPTY,
|
|
||||||
aggregateCall.getType(), aggregateCall.getName()
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
newProjects.add(rexNode);
|
|
||||||
int callArg = newProjects.size() - 1;
|
|
||||||
int newFilterArg = -1;
|
|
||||||
if (aggregateCall.hasFilter()) {
|
|
||||||
newProjects.add(project.getProjects().get(aggregateCall.filterArg));
|
|
||||||
newFilterArg = newProjects.size() - 1;
|
|
||||||
}
|
|
||||||
return AggregateCall.create(aggregateCall.getAggregation(),
|
|
||||||
aggregateCall.isDistinct(),
|
|
||||||
aggregateCall.isApproximate(),
|
|
||||||
aggregateCall.ignoreNulls(),
|
|
||||||
ImmutableList.of(callArg),
|
|
||||||
newFilterArg,
|
|
||||||
aggregateCall.getCollation(),
|
|
||||||
aggregateCall.getType(),
|
|
||||||
aggregateCall.getName()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the argument, if an aggregate call has a single argument,
|
|
||||||
* otherwise -1.
|
|
||||||
*/
|
|
||||||
private static int soleArgument(AggregateCall aggregateCall)
|
|
||||||
{
|
|
||||||
return aggregateCall.getArgList().size() == 1
|
|
||||||
? aggregateCall.getArgList().get(0)
|
|
||||||
: -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static boolean isThreeArgCase(final RexNode rexNode)
|
|
||||||
{
|
|
||||||
return rexNode.getKind() == SqlKind.CASE
|
|
||||||
&& ((RexCall) rexNode).operands.size() == 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static boolean isIntLiteral(final RexNode rexNode)
|
|
||||||
{
|
|
||||||
return rexNode instanceof RexLiteral
|
|
||||||
&& SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName());
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
x
Reference in New Issue
Block a user