diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRule.java index 49e77ce166e..fe9c0d47f02 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRule.java @@ -69,8 +69,7 @@ public class DruidUnionDataSourceRule extends RelOptRule final DruidRel firstDruidRel = call.rel(1); final DruidQueryRel secondDruidRel = call.rel(2); - // Can only do UNION ALL of inputs that have compatible schemas (or schema mappings). - return unionRel.all && isUnionCompatible(firstDruidRel, secondDruidRel); + return isCompatible(unionRel, firstDruidRel, secondDruidRel); } @Override @@ -111,6 +110,17 @@ public class DruidUnionDataSourceRule extends RelOptRule } } + // Can only do UNION ALL of inputs that have compatible schemas (or schema mappings) and right side + // is a simple table scan + public static boolean isCompatible(final Union unionRel, final DruidRel first, final DruidRel second) + { + if (!(second instanceof DruidQueryRel)) { + return false; + } + + return unionRel.all && isUnionCompatible(first, second); + } + private static boolean isUnionCompatible(final DruidRel first, final DruidRel second) { final Optional> columnNames = getColumnNamesIfTableOrUnion(first); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java index e97ed2b8c72..5863415abc8 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java @@ -55,7 +55,10 @@ public class DruidUnionRule extends RelOptRule public boolean matches(RelOptRuleCall call) { // Make DruidUnionRule and DruidUnionDataSourceRule mutually exclusive. - return !DruidUnionDataSourceRule.instance().matches(call); + final Union unionRel = call.rel(0); + final DruidRel firstDruidRel = call.rel(1); + final DruidRel secondDruidRel = call.rel(2); + return !DruidUnionDataSourceRule.isCompatible(unionRel, firstDruidRel, secondDruidRel); } @Override diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index c4279f22838..88fd1dcceb4 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -4268,6 +4268,120 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ); } + @Test + public void testUnionAllTwoQueriesLeftQueryIsJoin() throws Exception + { + cannotVectorize(); + + testQuery( + "(SELECT COUNT(*) FROM foo INNER JOIN lookup.lookyloo ON foo.dim1 = lookyloo.k) UNION ALL SELECT SUM(cnt) FROM foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new LookupDataSource("lookyloo"), + "j0.", + equalsCondition(DruidExpression.fromColumn("dim1"), DruidExpression.fromColumn("j0.k")), + JoinType.INNER + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build(), + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of(new Object[]{1L}, new Object[]{6L}) + ); + } + + @Test + public void testUnionAllTwoQueriesRightQueryIsJoin() throws Exception + { + cannotVectorize(); + + testQuery( + "(SELECT SUM(cnt) FROM foo UNION ALL SELECT COUNT(*) FROM foo INNER JOIN lookup.lookyloo ON foo.dim1 = lookyloo.k) ", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build(), + Druids.newTimeseriesQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new LookupDataSource("lookyloo"), + "j0.", + equalsCondition(DruidExpression.fromColumn("dim1"), DruidExpression.fromColumn("j0.k")), + JoinType.INNER + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of(new Object[]{6L}, new Object[]{1L}) + ); + } + + @Test + public void testUnionAllTwoQueriesBothQueriesAreJoin() throws Exception + { + cannotVectorize(); + + testQuery( + "(" + + "SELECT COUNT(*) FROM foo LEFT JOIN lookup.lookyloo ON foo.dim1 = lookyloo.k " + + " UNION ALL " + + "SELECT COUNT(*) FROM foo INNER JOIN lookup.lookyloo ON foo.dim1 = lookyloo.k" + + ") ", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new LookupDataSource("lookyloo"), + "j0.", + equalsCondition(DruidExpression.fromColumn("dim1"), DruidExpression.fromColumn("j0.k")), + JoinType.LEFT + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build(), + Druids.newTimeseriesQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new LookupDataSource("lookyloo"), + "j0.", + equalsCondition(DruidExpression.fromColumn("dim1"), DruidExpression.fromColumn("j0.k")), + JoinType.INNER + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of(new Object[]{6L}, new Object[]{1L}) + ); + } + @Test public void testPruneDeadAggregators() throws Exception {