From b8dd7478d079fdce36d3ed94a86749c6b8162833 Mon Sep 17 00:00:00 2001 From: Sree Charan Manamala <155449160+sreemanamala@users.noreply.github.com> Date: Tue, 14 May 2024 10:08:05 +0530 Subject: [PATCH] Custom Calcite Rule to remove redundant references (#16402) Custom Calcite rule mimicking AggregateProjectMergeRule to extend support to expressions. The current Calcite rule returns null in such cases. In addition, this removes the redundant references. --- .../GroupingAggregatorFactory.java | 23 ++- .../GroupingAggregatorFactoryTest.java | 8 + .../builtin/GroupingSqlAggregator.java | 14 +- .../calcite/planner/CalciteRulesManager.java | 2 + .../DruidAggregateRemoveRedundancyRule.java | 164 ++++++++++++++++++ .../druid/sql/calcite/CalciteQueryTest.java | 74 ++++++-- 6 files changed, 264 insertions(+), 21 deletions(-) create mode 100644 sql/src/main/java/org/apache/druid/sql/calcite/rule/logical/DruidAggregateRemoveRedundancyRule.java diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java index 8f8f7be4a14..e87c23951db 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java @@ -102,6 +102,20 @@ public class GroupingAggregatorFactory extends AggregatorFactory ) { Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name"); + Preconditions.checkArgument(!CollectionUtils.isNullOrEmpty(groupings), "Must have a non-empty grouping dimensions"); + // (Long.SIZE - 1) is just a sanity check. In practice, it will be just few dimensions. This limit + // also makes sure that values are always positive. 
+ Preconditions.checkArgument( + groupings.size() < Long.SIZE, + "Number of dimensions %s is more than supported %s", + groupings.size(), + Long.SIZE - 1 + ); + Preconditions.checkArgument( + groupings.stream().distinct().count() == groupings.size(), + "Encountered same dimension more than once in groupings" + ); + this.name = name; this.groupings = groupings; this.keyDimensions = keyDimensions; @@ -254,15 +268,6 @@ public class GroupingAggregatorFactory extends AggregatorFactory */ private long groupingId(List groupings, @Nullable Set keyDimensions) { - Preconditions.checkArgument(!CollectionUtils.isNullOrEmpty(groupings), "Must have a non-empty grouping dimensions"); - // (Long.SIZE - 1) is just a sanity check. In practice, it will be just few dimensions. This limit - // also makes sure that values are always positive. - Preconditions.checkArgument( - groupings.size() < Long.SIZE, - "Number of dimensions %s is more than supported %s", - groupings.size(), - Long.SIZE - 1 - ); long temp = 0L; for (String groupingDimension : groupings) { temp = temp << 1; diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java index 2be56bd0f0e..c9772ab9534 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java @@ -131,6 +131,14 @@ public class GroupingAggregatorFactoryTest )); makeFactory(new String[Long.SIZE], null); } + + @Test + public void testWithDuplicateGroupings() + { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Encountered same dimension more than once in groupings"); + makeFactory(new String[]{"a", "a"}, null); + } } @RunWith(Parameterized.class) diff --git 
a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java index 7c123dab927..1209ee30eaf 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java @@ -90,7 +90,19 @@ public class GroupingSqlAggregator implements SqlAggregator } } } - AggregatorFactory factory = new GroupingAggregatorFactory(name, arguments); + AggregatorFactory factory; + try { + factory = new GroupingAggregatorFactory(name, arguments); + } + catch (Exception e) { + plannerContext.setPlanningError( + "Initialisation of Grouping Aggregator Factory in case of [%s] threw [%s]", + aggregateCall, + e.getMessage() + ); + return null; + } + return Aggregation.create(factory); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java index 4326f63340d..829c44b18c6 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java @@ -66,6 +66,7 @@ import org.apache.druid.sql.calcite.rule.ProjectAggregatePruneUnusedCallRule; import org.apache.druid.sql.calcite.rule.ReverseLookupRule; import org.apache.druid.sql.calcite.rule.RewriteFirstValueLastValueRule; import org.apache.druid.sql.calcite.rule.SortCollapseRule; +import org.apache.druid.sql.calcite.rule.logical.DruidAggregateRemoveRedundancyRule; import org.apache.druid.sql.calcite.rule.logical.DruidLogicalRules; import org.apache.druid.sql.calcite.run.EngineFeature; @@ -496,6 +497,7 @@ public class CalciteRulesManager rules.add(FilterJoinExcludePushToChildRule.FILTER_ON_JOIN_EXCLUDE_PUSH_TO_CHILD); rules.add(SortCollapseRule.instance()); 
rules.add(ProjectAggregatePruneUnusedCallRule.instance()); + rules.add(DruidAggregateRemoveRedundancyRule.instance()); return rules.build(); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/logical/DruidAggregateRemoveRedundancyRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/logical/DruidAggregateRemoveRedundancyRule.java new file mode 100644 index 00000000000..1ef91dcb6ba --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/logical/DruidAggregateRemoveRedundancyRule.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.sql.calcite.rule.logical; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.Aggregate.Group; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.rules.TransformationRule; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.Util; +import org.apache.calcite.util.mapping.Mappings; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.immutables.value.Value; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +/** + * Planner rule that recognizes an {@link Aggregate} + * on top of a {@link Project} and, if possible, + * aggregates through the project or removes the project. + *

+ * This is an updated version of {@link org.apache.calcite.rel.rules.AggregateProjectMergeRule} + * that is able to handle expressions. + */ +@Value.Enclosing +public class DruidAggregateRemoveRedundancyRule + extends RelOptRule + implements TransformationRule +{ + + /** + * The singleton instance of this rule, returned by {@link #instance()}. + */ + private static final DruidAggregateRemoveRedundancyRule INSTANCE = new DruidAggregateRemoveRedundancyRule(); + + private DruidAggregateRemoveRedundancyRule() + { + super(operand(Aggregate.class, operand(Project.class, any()))); + } + + public static DruidAggregateRemoveRedundancyRule instance() + { + return INSTANCE; + } + + @Override + public void onMatch(RelOptRuleCall call) + { + final Aggregate aggregate = call.rel(0); + final Project project = call.rel(1); + RelNode x = apply(call, aggregate, project); + if (x != null) { + call.transformTo(x); + call.getPlanner().prune(aggregate); + } + } + + public static @Nullable RelNode apply(RelOptRuleCall call, Aggregate aggregate, Project project) + { + final Set interestingFields = RelOptUtil.getAllFields(aggregate); + if (interestingFields.isEmpty()) { + return null; + } + final Map map = new HashMap<>(); + final Map assignedRefForExpr = new HashMap<>(); + List newRexNodes = new ArrayList<>(); + for (int source : interestingFields) { + final RexNode rex = project.getProjects().get(source); + if (!assignedRefForExpr.containsKey(rex)) { + RexNode newNode = new RexInputRef(source, rex.getType()); + assignedRefForExpr.put(rex, newRexNodes.size()); + newRexNodes.add(newNode); + } + map.put(source, assignedRefForExpr.get(rex)); + } + + if (newRexNodes.size() == project.getProjects().size()) { + return null; + } + + final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map); + ImmutableList newGroupingSets = null; + if (aggregate.getGroupType() != Group.SIMPLE) { + newGroupingSets = + ImmutableBitSet.ORDERING.immutableSortedCopy( 
Sets.newTreeSet(ImmutableBitSet.permute(aggregate.getGroupSets(), map))); + } + + final ImmutableList.Builder aggCalls = ImmutableList.builder(); + final int sourceCount = aggregate.getInput().getRowType().getFieldCount(); + final int targetCount = newRexNodes.size(); + final Mappings.TargetMapping targetMapping = Mappings.target(map, sourceCount, targetCount); + for (AggregateCall aggregateCall : aggregate.getAggCallList()) { + aggCalls.add(aggregateCall.transform(targetMapping)); + } + + final RelBuilder relBuilder = call.builder(); + relBuilder.push(project); + relBuilder.project(newRexNodes); + + final Aggregate newAggregate = + aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), + newGroupSet, newGroupingSets, aggCalls.build() + ); + relBuilder.push(newAggregate); + + final List newKeys = + Util.transform( + aggregate.getGroupSet().asList(), + key -> Objects.requireNonNull( + map.get(key), + () -> "no value found for key " + key + " in " + map + ) + ); + + // Add a project if the group set is not in the same order or + // contains duplicates. 
+ if (!newKeys.equals(newGroupSet.asList())) { + final List posList = new ArrayList<>(); + for (int newKey : newKeys) { + posList.add(newGroupSet.indexOf(newKey)); + } + for (int i = newAggregate.getGroupCount(); + i < newAggregate.getRowType().getFieldCount(); i++) { + posList.add(i); + } + relBuilder.project(relBuilder.fields(posList)); + } + + return relBuilder.build(); + } +} diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index b5a32f4301c..9b302534ef6 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -8788,8 +8788,8 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ) .setDimensions( dimensions( - new DefaultDimensionSpec("dim1", "d0", ColumnType.STRING), - new DefaultDimensionSpec("v0", "d1", ColumnType.LONG) + new DefaultDimensionSpec("v0", "d0", ColumnType.LONG), + new DefaultDimensionSpec("dim1", "d1", ColumnType.STRING) ) ) .setAggregatorSpecs( @@ -8832,9 +8832,9 @@ public class CalciteQueryTest extends BaseCalciteQueryTest new FilteredAggregatorFactory( new CountAggregatorFactory("_a1"), and( - notNull("d0"), + notNull("d1"), equality("a1", 0L, ColumnType.LONG), - expressionFilter("\"d1\"") + expressionFilter("\"d0\"") ) ) ) @@ -12938,8 +12938,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG)) .setDimensions( dimensions( - new DefaultDimensionSpec("v0", "d0", ColumnType.LONG), - new DefaultDimensionSpec("v0", "d1", ColumnType.LONG) + new DefaultDimensionSpec("v0", "d0", ColumnType.LONG) ) ) .setContext(QUERY_CONTEXT_DEFAULT) @@ -15680,10 +15679,63 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .build() ) ).expectedResults( - ResultMatchMode.RELAX_NULLS, - ImmutableList.of( - new Object[]{null, null, null} - ) - ); + 
NullHandling.sqlCompatible() ? ImmutableList.of( + new Object[]{null, null, null} + ) : ImmutableList.of( + new Object[]{false, false, ""} + ) + ).run(); + } + + @SqlTestFrameworkConfig.NumMergeBuffers(4) + @Test + public void testGroupingSetsWithAggrgateCase() + { + cannotVectorize(); + msqIncompatible(); + final Map queryContext = ImmutableMap.of( + PlannerConfig.CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT, false, + PlannerConfig.CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT, true + ); + testBuilder() + .sql( + "SELECT\n" + + " TIME_FLOOR(\"__time\", 'PT1H') ,\n" + + " COUNT(DISTINCT \"page\") ,\n" + + " COUNT(DISTINCT CASE WHEN \"channel\" = '#it.wikipedia' THEN \"user\" END), \n" + + " COUNT(DISTINCT \"user\") FILTER (WHERE \"channel\" = '#it.wikipedia'), " + + " COUNT(DISTINCT \"user\") \n" + + "FROM \"wikipedia\"\n" + + "GROUP BY 1" + ) + .queryContext(queryContext) + .expectedResults( + ImmutableList.of( + new Object[]{1442016000000L, 264L, 5L, 5L, 149L}, + new Object[]{1442019600000L, 1090L, 14L, 14L, 506L}, + new Object[]{1442023200000L, 1045L, 10L, 10L, 459L}, + new Object[]{1442026800000L, 766L, 10L, 10L, 427L}, + new Object[]{1442030400000L, 781L, 6L, 6L, 427L}, + new Object[]{1442034000000L, 1223L, 10L, 10L, 448L}, + new Object[]{1442037600000L, 2092L, 13L, 13L, 498L}, + new Object[]{1442041200000L, 2181L, 21L, 21L, 574L}, + new Object[]{1442044800000L, 1552L, 36L, 36L, 707L}, + new Object[]{1442048400000L, 1624L, 44L, 44L, 770L}, + new Object[]{1442052000000L, 1710L, 37L, 37L, 785L}, + new Object[]{1442055600000L, 1532L, 40L, 40L, 799L}, + new Object[]{1442059200000L, 1633L, 45L, 45L, 855L}, + new Object[]{1442062800000L, 1958L, 44L, 44L, 905L}, + new Object[]{1442066400000L, 1779L, 48L, 48L, 886L}, + new Object[]{1442070000000L, 1868L, 37L, 37L, 949L}, + new Object[]{1442073600000L, 1846L, 50L, 50L, 969L}, + new Object[]{1442077200000L, 2168L, 38L, 38L, 941L}, + new Object[]{1442080800000L, 2043L, 40L, 40L, 925L}, + new Object[]{1442084400000L, 1924L, 32L, 
32L, 930L}, + new Object[]{1442088000000L, 1736L, 31L, 31L, 882L}, + new Object[]{1442091600000L, 1672L, 40L, 40L, 861L}, + new Object[]{1442095200000L, 1504L, 28L, 28L, 716L}, + new Object[]{1442098800000L, 1407L, 20L, 20L, 631L} + ) + ).run(); } }