mirror of https://github.com/apache/druid.git
Custom Calcite Rule to remove redundant references (#16402)
Custom calcite rule mimicking AggregateProjectMergeRule to extend support to expressions. The current calcite rule return null in such cases. In addition, this removes the redundant references.
This commit is contained in:
parent
760e449875
commit
b8dd7478d0
|
@ -102,6 +102,20 @@ public class GroupingAggregatorFactory extends AggregatorFactory
|
|||
)
|
||||
{
|
||||
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
|
||||
Preconditions.checkArgument(!CollectionUtils.isNullOrEmpty(groupings), "Must have a non-empty grouping dimensions");
|
||||
// (Long.SIZE - 1) is just a sanity check. In practice, it will be just few dimensions. This limit
|
||||
// also makes sure that values are always positive.
|
||||
Preconditions.checkArgument(
|
||||
groupings.size() < Long.SIZE,
|
||||
"Number of dimensions %s is more than supported %s",
|
||||
groupings.size(),
|
||||
Long.SIZE - 1
|
||||
);
|
||||
Preconditions.checkArgument(
|
||||
groupings.stream().distinct().count() == groupings.size(),
|
||||
"Encountered same dimension more than once in groupings"
|
||||
);
|
||||
|
||||
this.name = name;
|
||||
this.groupings = groupings;
|
||||
this.keyDimensions = keyDimensions;
|
||||
|
@ -254,15 +268,6 @@ public class GroupingAggregatorFactory extends AggregatorFactory
|
|||
*/
|
||||
private long groupingId(List<String> groupings, @Nullable Set<String> keyDimensions)
|
||||
{
|
||||
Preconditions.checkArgument(!CollectionUtils.isNullOrEmpty(groupings), "Must have a non-empty grouping dimensions");
|
||||
// (Long.SIZE - 1) is just a sanity check. In practice, it will be just few dimensions. This limit
|
||||
// also makes sure that values are always positive.
|
||||
Preconditions.checkArgument(
|
||||
groupings.size() < Long.SIZE,
|
||||
"Number of dimensions %s is more than supported %s",
|
||||
groupings.size(),
|
||||
Long.SIZE - 1
|
||||
);
|
||||
long temp = 0L;
|
||||
for (String groupingDimension : groupings) {
|
||||
temp = temp << 1;
|
||||
|
|
|
@ -131,6 +131,14 @@ public class GroupingAggregatorFactoryTest
|
|||
));
|
||||
makeFactory(new String[Long.SIZE], null);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWithDuplicateGroupings()
|
||||
{
|
||||
exception.expect(IllegalArgumentException.class);
|
||||
exception.expectMessage("Encountered same dimension more than once in groupings");
|
||||
makeFactory(new String[]{"a", "a"}, null);
|
||||
}
|
||||
}
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
|
|
|
@ -90,7 +90,19 @@ public class GroupingSqlAggregator implements SqlAggregator
|
|||
}
|
||||
}
|
||||
}
|
||||
AggregatorFactory factory = new GroupingAggregatorFactory(name, arguments);
|
||||
AggregatorFactory factory;
|
||||
try {
|
||||
factory = new GroupingAggregatorFactory(name, arguments);
|
||||
}
|
||||
catch (Exception e) {
|
||||
plannerContext.setPlanningError(
|
||||
"Initialisation of Grouping Aggregator Factory in case of [%s] threw [%s]",
|
||||
aggregateCall,
|
||||
e.getMessage()
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
return Aggregation.create(factory);
|
||||
}
|
||||
|
||||
|
|
|
@ -66,6 +66,7 @@ import org.apache.druid.sql.calcite.rule.ProjectAggregatePruneUnusedCallRule;
|
|||
import org.apache.druid.sql.calcite.rule.ReverseLookupRule;
|
||||
import org.apache.druid.sql.calcite.rule.RewriteFirstValueLastValueRule;
|
||||
import org.apache.druid.sql.calcite.rule.SortCollapseRule;
|
||||
import org.apache.druid.sql.calcite.rule.logical.DruidAggregateRemoveRedundancyRule;
|
||||
import org.apache.druid.sql.calcite.rule.logical.DruidLogicalRules;
|
||||
import org.apache.druid.sql.calcite.run.EngineFeature;
|
||||
|
||||
|
@ -496,6 +497,7 @@ public class CalciteRulesManager
|
|||
rules.add(FilterJoinExcludePushToChildRule.FILTER_ON_JOIN_EXCLUDE_PUSH_TO_CHILD);
|
||||
rules.add(SortCollapseRule.instance());
|
||||
rules.add(ProjectAggregatePruneUnusedCallRule.instance());
|
||||
rules.add(DruidAggregateRemoveRedundancyRule.instance());
|
||||
|
||||
return rules.build();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,164 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.apache.druid.sql.calcite.rule.logical;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Sets;
|
||||
import org.apache.calcite.plan.RelOptRule;
|
||||
import org.apache.calcite.plan.RelOptRuleCall;
|
||||
import org.apache.calcite.plan.RelOptUtil;
|
||||
import org.apache.calcite.rel.RelNode;
|
||||
import org.apache.calcite.rel.core.Aggregate;
|
||||
import org.apache.calcite.rel.core.Aggregate.Group;
|
||||
import org.apache.calcite.rel.core.AggregateCall;
|
||||
import org.apache.calcite.rel.core.Project;
|
||||
import org.apache.calcite.rel.rules.TransformationRule;
|
||||
import org.apache.calcite.rex.RexInputRef;
|
||||
import org.apache.calcite.rex.RexNode;
|
||||
import org.apache.calcite.tools.RelBuilder;
|
||||
import org.apache.calcite.util.ImmutableBitSet;
|
||||
import org.apache.calcite.util.Util;
|
||||
import org.apache.calcite.util.mapping.Mappings;
|
||||
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||
import org.immutables.value.Value;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Planner rule that recognizes a {@link Aggregate}
|
||||
* on top of a {@link Project} and if possible
|
||||
* aggregate through the project or removes the project.
|
||||
* <p>
|
||||
* This is updated version of {@link org.apache.calcite.rel.rules.AggregateProjectMergeRule}
|
||||
* to be able to handle expressions.
|
||||
*/
|
||||
@Value.Enclosing
|
||||
public class DruidAggregateRemoveRedundancyRule
|
||||
extends RelOptRule
|
||||
implements TransformationRule
|
||||
{
|
||||
|
||||
/**
|
||||
* Creates a DruidAggregateRemoveRedundancyRule.
|
||||
*/
|
||||
private static final DruidAggregateRemoveRedundancyRule INSTANCE = new DruidAggregateRemoveRedundancyRule();
|
||||
|
||||
private DruidAggregateRemoveRedundancyRule()
|
||||
{
|
||||
super(operand(Aggregate.class, operand(Project.class, any())));
|
||||
}
|
||||
|
||||
public static DruidAggregateRemoveRedundancyRule instance()
|
||||
{
|
||||
return INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMatch(RelOptRuleCall call)
|
||||
{
|
||||
final Aggregate aggregate = call.rel(0);
|
||||
final Project project = call.rel(1);
|
||||
RelNode x = apply(call, aggregate, project);
|
||||
if (x != null) {
|
||||
call.transformTo(x);
|
||||
call.getPlanner().prune(aggregate);
|
||||
}
|
||||
}
|
||||
|
||||
public static @Nullable RelNode apply(RelOptRuleCall call, Aggregate aggregate, Project project)
|
||||
{
|
||||
final Set<Integer> interestingFields = RelOptUtil.getAllFields(aggregate);
|
||||
if (interestingFields.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
final Map<Integer, Integer> map = new HashMap<>();
|
||||
final Map<RexNode, Integer> assignedRefForExpr = new HashMap<>();
|
||||
List<RexNode> newRexNodes = new ArrayList<>();
|
||||
for (int source : interestingFields) {
|
||||
final RexNode rex = project.getProjects().get(source);
|
||||
if (!assignedRefForExpr.containsKey(rex)) {
|
||||
RexNode newNode = new RexInputRef(source, rex.getType());
|
||||
assignedRefForExpr.put(rex, newRexNodes.size());
|
||||
newRexNodes.add(newNode);
|
||||
}
|
||||
map.put(source, assignedRefForExpr.get(rex));
|
||||
}
|
||||
|
||||
if (newRexNodes.size() == project.getProjects().size()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
|
||||
ImmutableList<ImmutableBitSet> newGroupingSets = null;
|
||||
if (aggregate.getGroupType() != Group.SIMPLE) {
|
||||
newGroupingSets =
|
||||
ImmutableBitSet.ORDERING.immutableSortedCopy(
|
||||
Sets.newTreeSet(ImmutableBitSet.permute(aggregate.getGroupSets(), map)));
|
||||
}
|
||||
|
||||
final ImmutableList.Builder<AggregateCall> aggCalls = ImmutableList.builder();
|
||||
final int sourceCount = aggregate.getInput().getRowType().getFieldCount();
|
||||
final int targetCount = newRexNodes.size();
|
||||
final Mappings.TargetMapping targetMapping = Mappings.target(map, sourceCount, targetCount);
|
||||
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
|
||||
aggCalls.add(aggregateCall.transform(targetMapping));
|
||||
}
|
||||
|
||||
final RelBuilder relBuilder = call.builder();
|
||||
relBuilder.push(project);
|
||||
relBuilder.project(newRexNodes);
|
||||
|
||||
final Aggregate newAggregate =
|
||||
aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
|
||||
newGroupSet, newGroupingSets, aggCalls.build()
|
||||
);
|
||||
relBuilder.push(newAggregate);
|
||||
|
||||
final List<Integer> newKeys =
|
||||
Util.transform(
|
||||
aggregate.getGroupSet().asList(),
|
||||
key -> Objects.requireNonNull(
|
||||
map.get(key),
|
||||
() -> "no value found for key " + key + " in " + map
|
||||
)
|
||||
);
|
||||
|
||||
// Add a project if the group set is not in the same order or
|
||||
// contains duplicates.
|
||||
if (!newKeys.equals(newGroupSet.asList())) {
|
||||
final List<Integer> posList = new ArrayList<>();
|
||||
for (int newKey : newKeys) {
|
||||
posList.add(newGroupSet.indexOf(newKey));
|
||||
}
|
||||
for (int i = newAggregate.getGroupCount();
|
||||
i < newAggregate.getRowType().getFieldCount(); i++) {
|
||||
posList.add(i);
|
||||
}
|
||||
relBuilder.project(relBuilder.fields(posList));
|
||||
}
|
||||
|
||||
return relBuilder.build();
|
||||
}
|
||||
}
|
|
@ -8788,8 +8788,8 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
)
|
||||
.setDimensions(
|
||||
dimensions(
|
||||
new DefaultDimensionSpec("dim1", "d0", ColumnType.STRING),
|
||||
new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
|
||||
new DefaultDimensionSpec("v0", "d0", ColumnType.LONG),
|
||||
new DefaultDimensionSpec("dim1", "d1", ColumnType.STRING)
|
||||
)
|
||||
)
|
||||
.setAggregatorSpecs(
|
||||
|
@ -8832,9 +8832,9 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("_a1"),
|
||||
and(
|
||||
notNull("d0"),
|
||||
notNull("d1"),
|
||||
equality("a1", 0L, ColumnType.LONG),
|
||||
expressionFilter("\"d1\"")
|
||||
expressionFilter("\"d0\"")
|
||||
)
|
||||
)
|
||||
)
|
||||
|
@ -12938,8 +12938,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG))
|
||||
.setDimensions(
|
||||
dimensions(
|
||||
new DefaultDimensionSpec("v0", "d0", ColumnType.LONG),
|
||||
new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
|
||||
new DefaultDimensionSpec("v0", "d0", ColumnType.LONG)
|
||||
)
|
||||
)
|
||||
.setContext(QUERY_CONTEXT_DEFAULT)
|
||||
|
@ -15680,10 +15679,63 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.build()
|
||||
)
|
||||
).expectedResults(
|
||||
ResultMatchMode.RELAX_NULLS,
|
||||
ImmutableList.of(
|
||||
NullHandling.sqlCompatible() ? ImmutableList.of(
|
||||
new Object[]{null, null, null}
|
||||
) : ImmutableList.of(
|
||||
new Object[]{false, false, ""}
|
||||
)
|
||||
).run();
|
||||
}
|
||||
|
||||
@SqlTestFrameworkConfig.NumMergeBuffers(4)
|
||||
@Test
|
||||
public void testGroupingSetsWithAggrgateCase()
|
||||
{
|
||||
cannotVectorize();
|
||||
msqIncompatible();
|
||||
final Map<String, Object> queryContext = ImmutableMap.of(
|
||||
PlannerConfig.CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT, false,
|
||||
PlannerConfig.CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT, true
|
||||
);
|
||||
testBuilder()
|
||||
.sql(
|
||||
"SELECT\n"
|
||||
+ " TIME_FLOOR(\"__time\", 'PT1H') ,\n"
|
||||
+ " COUNT(DISTINCT \"page\") ,\n"
|
||||
+ " COUNT(DISTINCT CASE WHEN \"channel\" = '#it.wikipedia' THEN \"user\" END), \n"
|
||||
+ " COUNT(DISTINCT \"user\") FILTER (WHERE \"channel\" = '#it.wikipedia'), "
|
||||
+ " COUNT(DISTINCT \"user\") \n"
|
||||
+ "FROM \"wikipedia\"\n"
|
||||
+ "GROUP BY 1"
|
||||
)
|
||||
.queryContext(queryContext)
|
||||
.expectedResults(
|
||||
ImmutableList.of(
|
||||
new Object[]{1442016000000L, 264L, 5L, 5L, 149L},
|
||||
new Object[]{1442019600000L, 1090L, 14L, 14L, 506L},
|
||||
new Object[]{1442023200000L, 1045L, 10L, 10L, 459L},
|
||||
new Object[]{1442026800000L, 766L, 10L, 10L, 427L},
|
||||
new Object[]{1442030400000L, 781L, 6L, 6L, 427L},
|
||||
new Object[]{1442034000000L, 1223L, 10L, 10L, 448L},
|
||||
new Object[]{1442037600000L, 2092L, 13L, 13L, 498L},
|
||||
new Object[]{1442041200000L, 2181L, 21L, 21L, 574L},
|
||||
new Object[]{1442044800000L, 1552L, 36L, 36L, 707L},
|
||||
new Object[]{1442048400000L, 1624L, 44L, 44L, 770L},
|
||||
new Object[]{1442052000000L, 1710L, 37L, 37L, 785L},
|
||||
new Object[]{1442055600000L, 1532L, 40L, 40L, 799L},
|
||||
new Object[]{1442059200000L, 1633L, 45L, 45L, 855L},
|
||||
new Object[]{1442062800000L, 1958L, 44L, 44L, 905L},
|
||||
new Object[]{1442066400000L, 1779L, 48L, 48L, 886L},
|
||||
new Object[]{1442070000000L, 1868L, 37L, 37L, 949L},
|
||||
new Object[]{1442073600000L, 1846L, 50L, 50L, 969L},
|
||||
new Object[]{1442077200000L, 2168L, 38L, 38L, 941L},
|
||||
new Object[]{1442080800000L, 2043L, 40L, 40L, 925L},
|
||||
new Object[]{1442084400000L, 1924L, 32L, 32L, 930L},
|
||||
new Object[]{1442088000000L, 1736L, 31L, 31L, 882L},
|
||||
new Object[]{1442091600000L, 1672L, 40L, 40L, 861L},
|
||||
new Object[]{1442095200000L, 1504L, 28L, 28L, 716L},
|
||||
new Object[]{1442098800000L, 1407L, 20L, 20L, 631L}
|
||||
)
|
||||
).run();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue