From 9cc3015ddd6e57c3604e4c566f771f620791fa7c Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Thu, 19 Jan 2017 19:51:14 -0800 Subject: [PATCH] Bypass Calcite's SemiJoinRule and use our own. (#3843) This simplifies DruidSemiJoin, which no longer needs to add aggregation back in. It also allows some more kinds of queries to plan properly, like the one added in "testTopNFilterJoin". --- .../io/druid/sql/calcite/planner/Rules.java | 2 - .../druid/sql/calcite/rel/DruidSemiJoin.java | 101 +++----------- .../sql/calcite/rule/DruidSemiJoinRule.java | 124 +++++++++++++++--- .../druid/sql/calcite/CalciteQueryTest.java | 116 +++++++++++++++- 4 files changed, 237 insertions(+), 106 deletions(-) diff --git a/sql/src/main/java/io/druid/sql/calcite/planner/Rules.java b/sql/src/main/java/io/druid/sql/calcite/planner/Rules.java index 22b5505b0f7..fe40265cb45 100644 --- a/sql/src/main/java/io/druid/sql/calcite/planner/Rules.java +++ b/sql/src/main/java/io/druid/sql/calcite/planner/Rules.java @@ -53,7 +53,6 @@ import org.apache.calcite.rel.rules.ProjectToWindowRule; import org.apache.calcite.rel.rules.ProjectWindowTransposeRule; import org.apache.calcite.rel.rules.PruneEmptyRules; import org.apache.calcite.rel.rules.ReduceExpressionsRule; -import org.apache.calcite.rel.rules.SemiJoinRule; import org.apache.calcite.rel.rules.SortJoinTransposeRule; import org.apache.calcite.rel.rules.SortProjectTransposeRule; import org.apache.calcite.rel.rules.SortRemoveRule; @@ -124,7 +123,6 @@ public class Rules FilterJoinRule.JOIN, AbstractConverter.ExpandConversionRule.INSTANCE, JoinCommuteRule.INSTANCE, - SemiJoinRule.INSTANCE, AggregateRemoveRule.INSTANCE, UnionToDistinctRule.INSTANCE, ProjectRemoveRule.INSTANCE, diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java b/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java index 2e2a18543fd..26f8265ae54 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java +++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java @@ -22,20 +22,17 @@ package io.druid.sql.calcite.rel; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import io.druid.java.util.common.ISE; -import io.druid.java.util.common.Pair; import io.druid.java.util.common.guava.Accumulator; import io.druid.java.util.common.guava.Sequence; import io.druid.java.util.common.guava.Sequences; import io.druid.query.QueryDataSource; import io.druid.query.ResourceLimitExceededException; -import io.druid.query.dimension.DimensionSpec; import io.druid.query.filter.AndDimFilter; import io.druid.query.filter.BoundDimFilter; import io.druid.query.filter.DimFilter; import io.druid.query.filter.OrDimFilter; -import io.druid.sql.calcite.aggregation.Aggregation; import io.druid.sql.calcite.expression.RowExtraction; +import io.druid.sql.calcite.planner.PlannerConfig; import io.druid.sql.calcite.table.RowSignature; import org.apache.calcite.interpreter.BindableConvention; import org.apache.calcite.plan.RelOptCluster; @@ -43,20 +40,16 @@ import org.apache.calcite.plan.RelOptCost; import org.apache.calcite.plan.RelOptPlanner; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelWriter; -import org.apache.calcite.rel.core.SemiJoin; import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexNode; import java.util.List; import java.util.Set; public class DruidSemiJoin extends DruidRel { - private final SemiJoin semiJoin; - private final DruidRel left; - private final DruidRel right; - private final RexNode condition; + private final DruidRel left; + private final DruidRel right; private final List leftRowExtractions; private final List rightKeys; private final int maxSemiJoinRowsInMemory; @@ -64,38 +57,31 @@ public class DruidSemiJoin extends DruidRel private DruidSemiJoin( final RelOptCluster cluster, final RelTraitSet traitSet, - final SemiJoin semiJoin, final DruidRel left, final DruidRel right, - final RexNode condition, final List leftRowExtractions, final List rightKeys, final int maxSemiJoinRowsInMemory ) { super(cluster, traitSet, left.getQueryMaker()); - this.semiJoin = semiJoin; this.left = left; this.right = right; - this.condition = condition; this.leftRowExtractions = ImmutableList.copyOf(leftRowExtractions); this.rightKeys = ImmutableList.copyOf(rightKeys); this.maxSemiJoinRowsInMemory = maxSemiJoinRowsInMemory; } public static DruidSemiJoin from( - final SemiJoin semiJoin, final DruidRel left, final DruidRel right, - final int maxSemiJoinRowsInMemory + final List leftKeys, + final List rightKeys, + final PlannerConfig plannerConfig ) { - if (semiJoin.getLeftKeys().size() != semiJoin.getRightKeys().size()) { - throw new ISE("WTF?! SemiJoin with different left/right key count?"); - } - final ImmutableList.Builder listBuilder = ImmutableList.builder(); - for (Integer key : semiJoin.getLeftKeys()) { + for (Integer key : leftKeys) { final RowExtraction rex = RowExtraction.fromQueryBuilder(left.getQueryBuilder(), key); if (rex == null) { // Can't figure out what to filter the left-hand side on... @@ -105,15 +91,13 @@ public class DruidSemiJoin extends DruidRel } return new DruidSemiJoin( - semiJoin.getCluster(), - semiJoin.getTraitSet(), - semiJoin, + left.getCluster(), + left.getTraitSet(), left, right, - semiJoin.getCondition(), listBuilder.build(), - semiJoin.getRightKeys(), - maxSemiJoinRowsInMemory + rightKeys, + plannerConfig.getMaxSemiJoinRowsInMemory() ); } @@ -135,10 +119,8 @@ public class DruidSemiJoin extends DruidRel return new DruidSemiJoin( getCluster(), getTraitSet().plusAll(newQueryBuilder.getRelTraits()), - semiJoin, left.withQueryBuilder(newQueryBuilder), right, - condition, leftRowExtractions, rightKeys, maxSemiJoinRowsInMemory @@ -158,10 +140,8 @@ public class DruidSemiJoin extends DruidRel return new DruidSemiJoin( getCluster(), getTraitSet().replace(BindableConvention.INSTANCE), - semiJoin, left, right, - condition, leftRowExtractions, rightKeys, maxSemiJoinRowsInMemory @@ -174,10 +154,8 @@ public class DruidSemiJoin extends DruidRel return new DruidSemiJoin( getCluster(), getTraitSet().replace(DruidConvention.instance()), - semiJoin, left, right, - condition, leftRowExtractions, rightKeys, maxSemiJoinRowsInMemory @@ -210,12 +188,11 @@ public class DruidSemiJoin extends DruidRel @Override public RelWriter explainTerms(RelWriter pw) { - final Pair> rightQueryBuilderWithGrouping = getRightQueryBuilderWithGrouping(); return pw .item("leftRowExtractions", leftRowExtractions) .item("leftQuery", left.getQueryBuilder()) - .item("rightKeysAdjusted", rightQueryBuilderWithGrouping.rhs) - .item("rightQuery", rightQueryBuilderWithGrouping.lhs); + .item("rightKeys", rightKeys) + .item("rightQuery", right.getQueryBuilder()); } @Override @@ -224,71 +201,25 @@ public class DruidSemiJoin extends DruidRel return right.computeSelfCost(planner, mq).plus(left.computeSelfCost(planner, mq).multiplyBy(50)); } - private Pair> getRightQueryBuilderWithGrouping() - { - if (right.getQueryBuilder().getGrouping() != null) { - return Pair.of(right.getQueryBuilder(), rightKeys); - } else { - // Add grouping on the join key to limit resultset from data nodes. - final List dimensionSpecs = Lists.newArrayList(); - final List rowTypes = Lists.newArrayList(); - final List rowOrder = Lists.newArrayList(); - final List rightKeysAdjusted = Lists.newArrayList(); - - int counter = 0; - for (final int key : rightKeys) { - final String keyDimensionOutputName = "v" + key; - final RowExtraction rex = RowExtraction.fromQueryBuilder(right.getQueryBuilder(), key); - if (rex == null) { - throw new ISE("WTF?! Can't find dimensionSpec to group on!"); - } - - final DimensionSpec dimensionSpec = rex.toDimensionSpec(left.getSourceRowSignature(), keyDimensionOutputName); - if (dimensionSpec == null) { - throw new ISE("WTF?! Can't translate row expression to dimensionSpec: %s", rex); - } - - dimensionSpecs.add(dimensionSpec); - rowTypes.add(right.getQueryBuilder().getRowType().getFieldList().get(key).getType()); - rowOrder.add(dimensionSpec.getOutputName()); - rightKeysAdjusted.add(counter++); - } - - final DruidQueryBuilder newQueryBuilder = right - .getQueryBuilder() - .withGrouping( - Grouping.create(dimensionSpecs, ImmutableList.of()), - getCluster().getTypeFactory().createStructType(rowTypes, rowOrder), - rowOrder - ); - - return Pair.of(newQueryBuilder, rightKeysAdjusted); - } - } - /** * Returns a copy of the left rel with the filter applied from the right-hand side. This is an expensive operation * since it actually executes the right-hand side query. */ private DruidRel getLeftRelWithFilter() { - final Pair> pair = getRightQueryBuilderWithGrouping(); - final DruidRel rightRelAdjusted = right.withQueryBuilder(pair.lhs); - final List rightKeysAdjusted = pair.rhs; - // Build list of acceptable values from right side. final Set> valuess = Sets.newHashSet(); final List filters = Lists.newArrayList(); - rightRelAdjusted.runQuery().accumulate( + right.runQuery().accumulate( null, new Accumulator() { @Override public Object accumulate(final Object dummyValue, final Object[] row) { - final List values = Lists.newArrayListWithCapacity(rightKeysAdjusted.size()); + final List values = Lists.newArrayListWithCapacity(rightKeys.size()); - for (int i : rightKeysAdjusted) { + for (int i : rightKeys) { final Object value = row[i]; final String stringValue = value != null ? String.valueOf(value) : ""; values.add(stringValue); diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java b/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java index 4de14903ea0..8478498da91 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java +++ b/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java @@ -19,24 +19,73 @@ package io.druid.sql.calcite.rule; +import com.google.common.base.Predicate; +import io.druid.query.dimension.DimensionSpec; import io.druid.sql.calcite.planner.PlannerConfig; import io.druid.sql.calcite.rel.DruidRel; import io.druid.sql.calcite.rel.DruidSemiJoin; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; -import org.apache.calcite.rel.core.SemiJoin; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinInfo; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.ImmutableBitSet; +import java.util.ArrayList; +import java.util.List; + +/** + * Planner rule adapted from Calcite 1.11.0's SemiJoinRule. + * + * This rule identifies a JOIN where the right-hand side is being used like a filter. Requirements are: + * + * 1) Right-hand side is grouping on the join key + * 2) No fields from the right-hand side are selected + * 3) Join is INNER (right-hand side acting as filter) or LEFT (right-hand side can be ignored) + * + * This is used instead of Calcite's built in rule because that rule's un-doing of aggregation is unproductive (we'd + * just want to add it back again). Also, this rule operates on DruidRels. + */ public class DruidSemiJoinRule extends RelOptRule { + private static final Predicate IS_LEFT_OR_INNER = + new Predicate() + { + public boolean apply(Join join) + { + final JoinRelType joinType = join.getJoinType(); + return joinType == JoinRelType.LEFT || joinType == JoinRelType.INNER; + } + }; + + private static final Predicate IS_GROUP_BY = + new Predicate() + { + public boolean apply(DruidRel druidRel) + { + return druidRel.getQueryBuilder().getGrouping() != null; + } + }; + private final PlannerConfig plannerConfig; - public DruidSemiJoinRule(final PlannerConfig plannerConfig) + private DruidSemiJoinRule(final PlannerConfig plannerConfig) { super( operand( - SemiJoin.class, - operand(DruidRel.class, none()), - operand(DruidRel.class, none()) + Project.class, + operand( + Join.class, + null, + IS_LEFT_OR_INNER, + some( + operand(DruidRel.class, any()), + operand(DruidRel.class, null, IS_GROUP_BY, any()) + ) + ) ) ); this.plannerConfig = plannerConfig; @@ -50,23 +99,66 @@ public class DruidSemiJoinRule extends RelOptRule @Override public void onMatch(RelOptRuleCall call) { - final SemiJoin semiJoin = call.rel(0); - final DruidRel left = call.rel(1); - final DruidRel right = call.rel(2); - final DruidSemiJoin druidSemiJoin = DruidSemiJoin.from( - semiJoin, - left, - right, - plannerConfig.getMaxSemiJoinRowsInMemory() - ); + final Project project = call.rel(0); + final Join join = call.rel(1); + final DruidRel left = call.rel(2); + final DruidRel right = call.rel(3); + + final ImmutableBitSet bits = + RelOptUtil.InputFinder.bits(project.getProjects(), null); + final ImmutableBitSet rightBits = + ImmutableBitSet.range( + left.getRowType().getFieldCount(), + join.getRowType().getFieldCount() + ); + + if (bits.intersects(rightBits)) { + return; + } + + final JoinInfo joinInfo = join.analyzeCondition(); + final List rightDimsOut = new ArrayList<>(); + for (DimensionSpec dimensionSpec : right.getQueryBuilder().getGrouping().getDimensions()) { + rightDimsOut.add(right.getOutputRowSignature().getRowOrder().indexOf(dimensionSpec.getOutputName())); + } + + if (!joinInfo.isEqui() || !joinInfo.rightSet().equals(ImmutableBitSet.of(rightDimsOut))) { + // Rule requires that aggregate key to be the same as the join key. + // By the way, neither a super-set nor a sub-set would work. + return; + } + + final RelBuilder relBuilder = call.builder(); + + if (join.getJoinType() == JoinRelType.LEFT) { + // Join can be eliminated since the right-hand side cannot have any effect (nothing is being selected, + // and LEFT means even if there is no match, a left-hand row will still be included). + relBuilder.push(left); + } else { + final DruidSemiJoin druidSemiJoin = DruidSemiJoin.from( + left, + right, + joinInfo.leftKeys, + joinInfo.rightKeys, + plannerConfig + ); + + if (druidSemiJoin == null) { + return; + } - if (druidSemiJoin != null) { // Check maxQueryCount. if (plannerConfig.getMaxQueryCount() > 0 && druidSemiJoin.getQueryCount() > plannerConfig.getMaxQueryCount()) { return; } - call.transformTo(druidSemiJoin); + relBuilder.push(druidSemiJoin); } + + call.transformTo( + relBuilder + .project(project.getProjects(), project.getRowType().getFieldNames()) + .build() + ); } } diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java index d59416e6c64..b8bc35bf6ef 100644 --- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java @@ -1750,6 +1750,116 @@ public class CalciteQueryTest ); } + @Test + public void testTopNFilterJoin() throws Exception + { + // Filters on top N values of some dimension by using an inner join. + testQuery( + "SELECT t1.dim1, SUM(t1.cnt)\n" + + "FROM druid.foo t1\n" + + " INNER JOIN (\n" + + " SELECT\n" + + " SUM(cnt) AS sum_cnt,\n" + + " dim2\n" + + " FROM druid.foo\n" + + " GROUP BY dim2\n" + + " ORDER BY 1 DESC\n" + + " LIMIT 2\n" + + ") t2 ON (t1.dim2 = t2.dim2)\n" + + "GROUP BY t1.dim1\n" + + "ORDER BY 1\n", + ImmutableList.of( + new TopNQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(QSS(Filtration.eternity())) + .granularity(QueryGranularities.ALL) + .dimension(new DefaultDimensionSpec("dim2", "d0")) + .aggregators(AGGS(new LongSumAggregatorFactory("a0", "cnt"))) + .metric(new NumericTopNMetricSpec("a0")) + .threshold(2) + .build(), + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(QSS(Filtration.eternity())) + .setGranularity(QueryGranularities.ALL) + .setDimFilter(IN("dim2", ImmutableList.of("", "a"), null)) + .setDimensions(DIMS(new DefaultDimensionSpec("dim1", "d0"))) + .setAggregatorSpecs(AGGS(new LongSumAggregatorFactory("a0", "cnt"))) + .setLimitSpec( + new DefaultLimitSpec( + ImmutableList.of( + new OrderByColumnSpec( + "d0", + OrderByColumnSpec.Direction.ASCENDING, + StringComparators.LEXICOGRAPHIC + ) + ), + Integer.MAX_VALUE + ) + ) + .build() + ), + ImmutableList.of( + new Object[]{"", 1L}, + new Object[]{"1", 1L}, + new Object[]{"10.1", 1L}, + new Object[]{"2", 1L}, + new Object[]{"abc", 1L} + ) + ); + } + + @Test + public void testRemovableLeftJoin() throws Exception + { + // LEFT JOIN where the right-hand side can be ignored. + + testQuery( + "SELECT t1.dim1, SUM(t1.cnt)\n" + + "FROM druid.foo t1\n" + + " LEFT JOIN (\n" + + " SELECT\n" + + " SUM(cnt) AS sum_cnt,\n" + + " dim2\n" + + " FROM druid.foo\n" + + " GROUP BY dim2\n" + + " ORDER BY 1 DESC\n" + + " LIMIT 2\n" + + ") t2 ON (t1.dim2 = t2.dim2)\n" + + "GROUP BY t1.dim1\n" + + "ORDER BY 1\n", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(QSS(Filtration.eternity())) + .setGranularity(QueryGranularities.ALL) + .setDimensions(DIMS(new DefaultDimensionSpec("dim1", "d0"))) + .setAggregatorSpecs(AGGS(new LongSumAggregatorFactory("a0", "cnt"))) + .setLimitSpec( + new DefaultLimitSpec( + ImmutableList.of( + new OrderByColumnSpec( + "d0", + OrderByColumnSpec.Direction.ASCENDING, + StringComparators.LEXICOGRAPHIC + ) + ), + Integer.MAX_VALUE + ) + ) + .build() + ), + ImmutableList.of( + new Object[]{"", 1L}, + new Object[]{"1", 1L}, + new Object[]{"10.1", 1L}, + new Object[]{"2", 1L}, + new Object[]{"abc", 1L}, + new Object[]{"def", 1L} + ) + ); + } + @Test public void testExactCountDistinctOfSemiJoinResult() throws Exception { @@ -1770,7 +1880,7 @@ public class CalciteQueryTest .setDimFilter(NOT(SELECTOR("dim1", "", null))) .setDimensions(DIMS(new ExtractionDimensionSpec( "dim1", - "v0", + "d0", new SubstringDimExtractionFn(0, 1) ))) .build(), @@ -2707,7 +2817,7 @@ public class CalciteQueryTest .setInterval(QSS(Filtration.eternity())) .setGranularity(QueryGranularities.ALL) .setDimFilter(NOT(SELECTOR("dim1", "", null))) - .setDimensions(DIMS(new DefaultDimensionSpec("dim1", "v0"))) + .setDimensions(DIMS(new DefaultDimensionSpec("dim1", "d0"))) .build(), GroupByQuery.builder() .setDataSource(CalciteTests.DATASOURCE1) @@ -2830,7 +2940,7 @@ public class CalciteQueryTest .setGranularity(QueryGranularities.ALL) .setDimFilter(NOT(SELECTOR("dim1", "", null))) .setDimensions( - DIMS(new ExtractionDimensionSpec("dim1", "v0", new SubstringDimExtractionFn(0, 1))) + DIMS(new ExtractionDimensionSpec("dim1", "d0", new SubstringDimExtractionFn(0, 1))) ) .build(), GroupByQuery.builder()