Bypass Calcite's SemiJoinRule and use our own. (#3843)

This simplifies DruidSemiJoin, which no longer needs to add aggregation back
in. It also allows some more kinds of queries to plan properly, like the one
added in "testTopNFilterJoin".
This commit is contained in:
Gian Merlino 2017-01-19 19:51:14 -08:00 committed by Jonathan Wei
parent d51f5e058d
commit 9cc3015ddd
4 changed files with 237 additions and 106 deletions

View File

@ -53,7 +53,6 @@ import org.apache.calcite.rel.rules.ProjectToWindowRule;
import org.apache.calcite.rel.rules.ProjectWindowTransposeRule;
import org.apache.calcite.rel.rules.PruneEmptyRules;
import org.apache.calcite.rel.rules.ReduceExpressionsRule;
import org.apache.calcite.rel.rules.SemiJoinRule;
import org.apache.calcite.rel.rules.SortJoinTransposeRule;
import org.apache.calcite.rel.rules.SortProjectTransposeRule;
import org.apache.calcite.rel.rules.SortRemoveRule;
@ -124,7 +123,6 @@ public class Rules
FilterJoinRule.JOIN,
AbstractConverter.ExpandConversionRule.INSTANCE,
JoinCommuteRule.INSTANCE,
SemiJoinRule.INSTANCE,
AggregateRemoveRule.INSTANCE,
UnionToDistinctRule.INSTANCE,
ProjectRemoveRule.INSTANCE,

View File

@ -22,20 +22,17 @@ package io.druid.sql.calcite.rel;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import io.druid.java.util.common.ISE;
import io.druid.java.util.common.Pair;
import io.druid.java.util.common.guava.Accumulator;
import io.druid.java.util.common.guava.Sequence;
import io.druid.java.util.common.guava.Sequences;
import io.druid.query.QueryDataSource;
import io.druid.query.ResourceLimitExceededException;
import io.druid.query.dimension.DimensionSpec;
import io.druid.query.filter.AndDimFilter;
import io.druid.query.filter.BoundDimFilter;
import io.druid.query.filter.DimFilter;
import io.druid.query.filter.OrDimFilter;
import io.druid.sql.calcite.aggregation.Aggregation;
import io.druid.sql.calcite.expression.RowExtraction;
import io.druid.sql.calcite.planner.PlannerConfig;
import io.druid.sql.calcite.table.RowSignature;
import org.apache.calcite.interpreter.BindableConvention;
import org.apache.calcite.plan.RelOptCluster;
@ -43,20 +40,16 @@ import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.core.SemiJoin;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;
import java.util.List;
import java.util.Set;
public class DruidSemiJoin extends DruidRel<DruidSemiJoin>
{
private final SemiJoin semiJoin;
private final DruidRel left;
private final DruidRel right;
private final RexNode condition;
private final DruidRel<?> left;
private final DruidRel<?> right;
private final List<RowExtraction> leftRowExtractions;
private final List<Integer> rightKeys;
private final int maxSemiJoinRowsInMemory;
@ -64,38 +57,31 @@ public class DruidSemiJoin extends DruidRel<DruidSemiJoin>
private DruidSemiJoin(
final RelOptCluster cluster,
final RelTraitSet traitSet,
final SemiJoin semiJoin,
final DruidRel left,
final DruidRel right,
final RexNode condition,
final List<RowExtraction> leftRowExtractions,
final List<Integer> rightKeys,
final int maxSemiJoinRowsInMemory
)
{
super(cluster, traitSet, left.getQueryMaker());
this.semiJoin = semiJoin;
this.left = left;
this.right = right;
this.condition = condition;
this.leftRowExtractions = ImmutableList.copyOf(leftRowExtractions);
this.rightKeys = ImmutableList.copyOf(rightKeys);
this.maxSemiJoinRowsInMemory = maxSemiJoinRowsInMemory;
}
public static DruidSemiJoin from(
final SemiJoin semiJoin,
final DruidRel left,
final DruidRel right,
final int maxSemiJoinRowsInMemory
final List<Integer> leftKeys,
final List<Integer> rightKeys,
final PlannerConfig plannerConfig
)
{
if (semiJoin.getLeftKeys().size() != semiJoin.getRightKeys().size()) {
throw new ISE("WTF?! SemiJoin with different left/right key count?");
}
final ImmutableList.Builder<RowExtraction> listBuilder = ImmutableList.builder();
for (Integer key : semiJoin.getLeftKeys()) {
for (Integer key : leftKeys) {
final RowExtraction rex = RowExtraction.fromQueryBuilder(left.getQueryBuilder(), key);
if (rex == null) {
// Can't figure out what to filter the left-hand side on...
@ -105,15 +91,13 @@ public class DruidSemiJoin extends DruidRel<DruidSemiJoin>
}
return new DruidSemiJoin(
semiJoin.getCluster(),
semiJoin.getTraitSet(),
semiJoin,
left.getCluster(),
left.getTraitSet(),
left,
right,
semiJoin.getCondition(),
listBuilder.build(),
semiJoin.getRightKeys(),
maxSemiJoinRowsInMemory
rightKeys,
plannerConfig.getMaxSemiJoinRowsInMemory()
);
}
@ -135,10 +119,8 @@ public class DruidSemiJoin extends DruidRel<DruidSemiJoin>
return new DruidSemiJoin(
getCluster(),
getTraitSet().plusAll(newQueryBuilder.getRelTraits()),
semiJoin,
left.withQueryBuilder(newQueryBuilder),
right,
condition,
leftRowExtractions,
rightKeys,
maxSemiJoinRowsInMemory
@ -158,10 +140,8 @@ public class DruidSemiJoin extends DruidRel<DruidSemiJoin>
return new DruidSemiJoin(
getCluster(),
getTraitSet().replace(BindableConvention.INSTANCE),
semiJoin,
left,
right,
condition,
leftRowExtractions,
rightKeys,
maxSemiJoinRowsInMemory
@ -174,10 +154,8 @@ public class DruidSemiJoin extends DruidRel<DruidSemiJoin>
return new DruidSemiJoin(
getCluster(),
getTraitSet().replace(DruidConvention.instance()),
semiJoin,
left,
right,
condition,
leftRowExtractions,
rightKeys,
maxSemiJoinRowsInMemory
@ -210,12 +188,11 @@ public class DruidSemiJoin extends DruidRel<DruidSemiJoin>
@Override
public RelWriter explainTerms(RelWriter pw)
{
final Pair<DruidQueryBuilder, List<Integer>> rightQueryBuilderWithGrouping = getRightQueryBuilderWithGrouping();
return pw
.item("leftRowExtractions", leftRowExtractions)
.item("leftQuery", left.getQueryBuilder())
.item("rightKeysAdjusted", rightQueryBuilderWithGrouping.rhs)
.item("rightQuery", rightQueryBuilderWithGrouping.lhs);
.item("rightKeys", rightKeys)
.item("rightQuery", right.getQueryBuilder());
}
@Override
@ -224,71 +201,25 @@ public class DruidSemiJoin extends DruidRel<DruidSemiJoin>
return right.computeSelfCost(planner, mq).plus(left.computeSelfCost(planner, mq).multiplyBy(50));
}
private Pair<DruidQueryBuilder, List<Integer>> getRightQueryBuilderWithGrouping()
{
if (right.getQueryBuilder().getGrouping() != null) {
return Pair.of(right.getQueryBuilder(), rightKeys);
} else {
// Add grouping on the join key to limit resultset from data nodes.
final List<DimensionSpec> dimensionSpecs = Lists.newArrayList();
final List<RelDataType> rowTypes = Lists.newArrayList();
final List<String> rowOrder = Lists.newArrayList();
final List<Integer> rightKeysAdjusted = Lists.newArrayList();
int counter = 0;
for (final int key : rightKeys) {
final String keyDimensionOutputName = "v" + key;
final RowExtraction rex = RowExtraction.fromQueryBuilder(right.getQueryBuilder(), key);
if (rex == null) {
throw new ISE("WTF?! Can't find dimensionSpec to group on!");
}
final DimensionSpec dimensionSpec = rex.toDimensionSpec(left.getSourceRowSignature(), keyDimensionOutputName);
if (dimensionSpec == null) {
throw new ISE("WTF?! Can't translate row expression to dimensionSpec: %s", rex);
}
dimensionSpecs.add(dimensionSpec);
rowTypes.add(right.getQueryBuilder().getRowType().getFieldList().get(key).getType());
rowOrder.add(dimensionSpec.getOutputName());
rightKeysAdjusted.add(counter++);
}
final DruidQueryBuilder newQueryBuilder = right
.getQueryBuilder()
.withGrouping(
Grouping.create(dimensionSpecs, ImmutableList.<Aggregation>of()),
getCluster().getTypeFactory().createStructType(rowTypes, rowOrder),
rowOrder
);
return Pair.of(newQueryBuilder, rightKeysAdjusted);
}
}
/**
* Returns a copy of the left rel with the filter applied from the right-hand side. This is an expensive operation
* since it actually executes the right-hand side query.
*/
private DruidRel<?> getLeftRelWithFilter()
{
final Pair<DruidQueryBuilder, List<Integer>> pair = getRightQueryBuilderWithGrouping();
final DruidRel<?> rightRelAdjusted = right.withQueryBuilder(pair.lhs);
final List<Integer> rightKeysAdjusted = pair.rhs;
// Build list of acceptable values from right side.
final Set<List<String>> valuess = Sets.newHashSet();
final List<DimFilter> filters = Lists.newArrayList();
rightRelAdjusted.runQuery().accumulate(
right.runQuery().accumulate(
null,
new Accumulator<Object, Object[]>()
{
@Override
public Object accumulate(final Object dummyValue, final Object[] row)
{
final List<String> values = Lists.newArrayListWithCapacity(rightKeysAdjusted.size());
final List<String> values = Lists.newArrayListWithCapacity(rightKeys.size());
for (int i : rightKeysAdjusted) {
for (int i : rightKeys) {
final Object value = row[i];
final String stringValue = value != null ? String.valueOf(value) : "";
values.add(stringValue);

View File

@ -19,24 +19,73 @@
package io.druid.sql.calcite.rule;
import com.google.common.base.Predicate;
import io.druid.query.dimension.DimensionSpec;
import io.druid.sql.calcite.planner.PlannerConfig;
import io.druid.sql.calcite.rel.DruidRel;
import io.druid.sql.calcite.rel.DruidSemiJoin;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.core.SemiJoin;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import java.util.ArrayList;
import java.util.List;
/**
* Planner rule adapted from Calcite 1.11.0's SemiJoinRule.
*
* This rule identifies a JOIN where the right-hand side is being used like a filter. Requirements are:
*
* 1) Right-hand side is grouping on the join key
* 2) No fields from the right-hand side are selected
* 3) Join is INNER (right-hand side acting as filter) or LEFT (right-hand side can be ignored)
*
* This is used instead of Calcite's built-in rule because that rule's undoing of aggregation is unproductive (we'd
* just want to add it back again). Also, this rule operates on DruidRels.
*/
public class DruidSemiJoinRule extends RelOptRule
{
/**
 * Accepts only joins whose type is LEFT or INNER; any other join type is rejected. Used as the operand predicate
 * for the Join matched by this rule.
 */
private static final Predicate<Join> IS_LEFT_OR_INNER =
    new Predicate<Join>()
    {
      @Override
      public boolean apply(Join join)
      {
        final JoinRelType joinType = join.getJoinType();
        return joinType == JoinRelType.LEFT || joinType == JoinRelType.INNER;
      }
    };
/**
 * Accepts only DruidRels whose query builder carries a grouping (i.e. {@code getGrouping()} is non-null). Used as
 * the operand predicate for the right-hand input of the join matched by this rule.
 */
private static final Predicate<DruidRel> IS_GROUP_BY =
    new Predicate<DruidRel>()
    {
      @Override
      public boolean apply(DruidRel druidRel)
      {
        return druidRel.getQueryBuilder().getGrouping() != null;
      }
    };
private final PlannerConfig plannerConfig;
public DruidSemiJoinRule(final PlannerConfig plannerConfig)
private DruidSemiJoinRule(final PlannerConfig plannerConfig)
{
super(
operand(
SemiJoin.class,
operand(DruidRel.class, none()),
operand(DruidRel.class, none())
Project.class,
operand(
Join.class,
null,
IS_LEFT_OR_INNER,
some(
operand(DruidRel.class, any()),
operand(DruidRel.class, null, IS_GROUP_BY, any())
)
)
)
);
this.plannerConfig = plannerConfig;
@ -50,23 +99,66 @@ public class DruidSemiJoinRule extends RelOptRule
@Override
public void onMatch(RelOptRuleCall call)
{
final SemiJoin semiJoin = call.rel(0);
final DruidRel left = call.rel(1);
final DruidRel right = call.rel(2);
final DruidSemiJoin druidSemiJoin = DruidSemiJoin.from(
semiJoin,
left,
right,
plannerConfig.getMaxSemiJoinRowsInMemory()
);
final Project project = call.rel(0);
final Join join = call.rel(1);
final DruidRel left = call.rel(2);
final DruidRel right = call.rel(3);
final ImmutableBitSet bits =
RelOptUtil.InputFinder.bits(project.getProjects(), null);
final ImmutableBitSet rightBits =
ImmutableBitSet.range(
left.getRowType().getFieldCount(),
join.getRowType().getFieldCount()
);
if (bits.intersects(rightBits)) {
return;
}
final JoinInfo joinInfo = join.analyzeCondition();
final List<Integer> rightDimsOut = new ArrayList<>();
for (DimensionSpec dimensionSpec : right.getQueryBuilder().getGrouping().getDimensions()) {
rightDimsOut.add(right.getOutputRowSignature().getRowOrder().indexOf(dimensionSpec.getOutputName()));
}
if (!joinInfo.isEqui() || !joinInfo.rightSet().equals(ImmutableBitSet.of(rightDimsOut))) {
// Rule requires the aggregate key to be the same as the join key.
// By the way, neither a super-set nor a sub-set would work.
return;
}
final RelBuilder relBuilder = call.builder();
if (join.getJoinType() == JoinRelType.LEFT) {
// Join can be eliminated since the right-hand side cannot have any effect (nothing is being selected,
// and LEFT means even if there is no match, a left-hand row will still be included).
relBuilder.push(left);
} else {
final DruidSemiJoin druidSemiJoin = DruidSemiJoin.from(
left,
right,
joinInfo.leftKeys,
joinInfo.rightKeys,
plannerConfig
);
if (druidSemiJoin == null) {
return;
}
if (druidSemiJoin != null) {
// Check maxQueryCount.
if (plannerConfig.getMaxQueryCount() > 0 && druidSemiJoin.getQueryCount() > plannerConfig.getMaxQueryCount()) {
return;
}
call.transformTo(druidSemiJoin);
relBuilder.push(druidSemiJoin);
}
call.transformTo(
relBuilder
.project(project.getProjects(), project.getRowType().getFieldNames())
.build()
);
}
}

View File

@ -1750,6 +1750,116 @@ public class CalciteQueryTest
);
}
@Test
public void testTopNFilterJoin() throws Exception
{
  // An INNER JOIN against a grouped, ordered and limited subquery acts as a filter on the outer query:
  // only rows whose dim2 matches one of the subquery's dim2 values contribute to the outer aggregation.
  final String sql =
      "SELECT t1.dim1, SUM(t1.cnt)\n"
      + "FROM druid.foo t1\n"
      + "  INNER JOIN (\n"
      + "  SELECT\n"
      + "    SUM(cnt) AS sum_cnt,\n"
      + "    dim2\n"
      + "  FROM druid.foo\n"
      + "  GROUP BY dim2\n"
      + "  ORDER BY 1 DESC\n"
      + "  LIMIT 2\n"
      + ") t2 ON (t1.dim2 = t2.dim2)\n"
      + "GROUP BY t1.dim1\n"
      + "ORDER BY 1\n";

  // Expected first native query: the subquery, planned as a TopN over dim2 ordered by the cnt sum, threshold 2.
  final Query subqueryTopN =
      new TopNQueryBuilder()
          .dataSource(CalciteTests.DATASOURCE1)
          .intervals(QSS(Filtration.eternity()))
          .granularity(QueryGranularities.ALL)
          .dimension(new DefaultDimensionSpec("dim2", "d0"))
          .aggregators(AGGS(new LongSumAggregatorFactory("a0", "cnt")))
          .metric(new NumericTopNMetricSpec("a0"))
          .threshold(2)
          .build();

  // Expected second native query: the outer GROUP BY on dim1, restricted by an IN filter on dim2.
  final Query outerGroupBy =
      GroupByQuery.builder()
                  .setDataSource(CalciteTests.DATASOURCE1)
                  .setInterval(QSS(Filtration.eternity()))
                  .setGranularity(QueryGranularities.ALL)
                  .setDimFilter(IN("dim2", ImmutableList.of("", "a"), null))
                  .setDimensions(DIMS(new DefaultDimensionSpec("dim1", "d0")))
                  .setAggregatorSpecs(AGGS(new LongSumAggregatorFactory("a0", "cnt")))
                  .setLimitSpec(
                      new DefaultLimitSpec(
                          ImmutableList.of(
                              new OrderByColumnSpec(
                                  "d0",
                                  OrderByColumnSpec.Direction.ASCENDING,
                                  StringComparators.LEXICOGRAPHIC
                              )
                          ),
                          Integer.MAX_VALUE
                      )
                  )
                  .build();

  // Expected result rows, ordered by dim1 ascending.
  final ImmutableList<Object[]> expectedRows = ImmutableList.of(
      new Object[]{"", 1L},
      new Object[]{"1", 1L},
      new Object[]{"10.1", 1L},
      new Object[]{"2", 1L},
      new Object[]{"abc", 1L}
  );

  testQuery(sql, ImmutableList.<Query>of(subqueryTopN, outerGroupBy), expectedRows);
}
@Test
public void testRemovableLeftJoin() throws Exception
{
  // A LEFT JOIN whose right-hand side contributes no selected columns: the expected plan contains only the
  // outer GROUP BY, with no query for the subquery and no filter derived from it.
  final String sql =
      "SELECT t1.dim1, SUM(t1.cnt)\n"
      + "FROM druid.foo t1\n"
      + "  LEFT JOIN (\n"
      + "  SELECT\n"
      + "    SUM(cnt) AS sum_cnt,\n"
      + "    dim2\n"
      + "  FROM druid.foo\n"
      + "  GROUP BY dim2\n"
      + "  ORDER BY 1 DESC\n"
      + "  LIMIT 2\n"
      + ") t2 ON (t1.dim2 = t2.dim2)\n"
      + "GROUP BY t1.dim1\n"
      + "ORDER BY 1\n";

  // The single expected native query: an unfiltered GROUP BY on dim1 summing cnt.
  final Query outerGroupBy =
      GroupByQuery.builder()
                  .setDataSource(CalciteTests.DATASOURCE1)
                  .setInterval(QSS(Filtration.eternity()))
                  .setGranularity(QueryGranularities.ALL)
                  .setDimensions(DIMS(new DefaultDimensionSpec("dim1", "d0")))
                  .setAggregatorSpecs(AGGS(new LongSumAggregatorFactory("a0", "cnt")))
                  .setLimitSpec(
                      new DefaultLimitSpec(
                          ImmutableList.of(
                              new OrderByColumnSpec(
                                  "d0",
                                  OrderByColumnSpec.Direction.ASCENDING,
                                  StringComparators.LEXICOGRAPHIC
                              )
                          ),
                          Integer.MAX_VALUE
                      )
                  )
                  .build();

  // Expected result rows, ordered by dim1 ascending.
  final ImmutableList<Object[]> expectedRows = ImmutableList.of(
      new Object[]{"", 1L},
      new Object[]{"1", 1L},
      new Object[]{"10.1", 1L},
      new Object[]{"2", 1L},
      new Object[]{"abc", 1L},
      new Object[]{"def", 1L}
  );

  testQuery(sql, ImmutableList.<Query>of(outerGroupBy), expectedRows);
}
@Test
public void testExactCountDistinctOfSemiJoinResult() throws Exception
{
@ -1770,7 +1880,7 @@ public class CalciteQueryTest
.setDimFilter(NOT(SELECTOR("dim1", "", null)))
.setDimensions(DIMS(new ExtractionDimensionSpec(
"dim1",
"v0",
"d0",
new SubstringDimExtractionFn(0, 1)
)))
.build(),
@ -2707,7 +2817,7 @@ public class CalciteQueryTest
.setInterval(QSS(Filtration.eternity()))
.setGranularity(QueryGranularities.ALL)
.setDimFilter(NOT(SELECTOR("dim1", "", null)))
.setDimensions(DIMS(new DefaultDimensionSpec("dim1", "v0")))
.setDimensions(DIMS(new DefaultDimensionSpec("dim1", "d0")))
.build(),
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
@ -2830,7 +2940,7 @@ public class CalciteQueryTest
.setGranularity(QueryGranularities.ALL)
.setDimFilter(NOT(SELECTOR("dim1", "", null)))
.setDimensions(
DIMS(new ExtractionDimensionSpec("dim1", "v0", new SubstringDimExtractionFn(0, 1)))
DIMS(new ExtractionDimensionSpec("dim1", "d0", new SubstringDimExtractionFn(0, 1)))
)
.build(),
GroupByQuery.builder()