Fix classCastException when inputs to union are join (#10950)

* Fix union queries

* Add tests
This commit is contained in:
Abhishek Agarwal 2021-03-09 10:50:26 +05:30 committed by GitHub
parent 756ac6ef30
commit ae620921df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 130 additions and 3 deletions

View File

@ -69,8 +69,7 @@ public class DruidUnionDataSourceRule extends RelOptRule
final DruidRel<?> firstDruidRel = call.rel(1);
final DruidQueryRel secondDruidRel = call.rel(2);
// Can only do UNION ALL of inputs that have compatible schemas (or schema mappings).
return unionRel.all && isUnionCompatible(firstDruidRel, secondDruidRel);
return isCompatible(unionRel, firstDruidRel, secondDruidRel);
}
@Override
@ -111,6 +110,17 @@ public class DruidUnionDataSourceRule extends RelOptRule
}
}
// Can only do UNION ALL of inputs that have compatible schemas (or schema mappings) and right side
// is a simple table scan
public static boolean isCompatible(final Union unionRel, final DruidRel<?> first, final DruidRel<?> second)
{
if (!(second instanceof DruidQueryRel)) {
return false;
}
return unionRel.all && isUnionCompatible(first, second);
}
private static boolean isUnionCompatible(final DruidRel<?> first, final DruidRel<?> second)
{
final Optional<List<String>> columnNames = getColumnNamesIfTableOrUnion(first);

View File

@ -55,7 +55,10 @@ public class DruidUnionRule extends RelOptRule
public boolean matches(RelOptRuleCall call)
{
// Make DruidUnionRule and DruidUnionDataSourceRule mutually exclusive.
return !DruidUnionDataSourceRule.instance().matches(call);
final Union unionRel = call.rel(0);
final DruidRel<?> firstDruidRel = call.rel(1);
final DruidRel<?> secondDruidRel = call.rel(2);
return !DruidUnionDataSourceRule.isCompatible(unionRel, firstDruidRel, secondDruidRel);
}
@Override

View File

@ -4268,6 +4268,120 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
);
}
@Test
public void testUnionAllTwoQueriesLeftQueryIsJoin() throws Exception
{
cannotVectorize();
testQuery(
"(SELECT COUNT(*) FROM foo INNER JOIN lookup.lookyloo ON foo.dim1 = lookyloo.k) UNION ALL SELECT SUM(cnt) FROM foo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(
join(
new TableDataSource(CalciteTests.DATASOURCE1),
new LookupDataSource("lookyloo"),
"j0.",
equalsCondition(DruidExpression.fromColumn("dim1"), DruidExpression.fromColumn("j0.k")),
JoinType.INNER
))
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(aggregators(new CountAggregatorFactory("a0")))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build(),
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(aggregators(new LongSumAggregatorFactory("a0", "cnt")))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{1L}, new Object[]{6L})
);
}
@Test
public void testUnionAllTwoQueriesRightQueryIsJoin() throws Exception
{
cannotVectorize();
testQuery(
"(SELECT SUM(cnt) FROM foo UNION ALL SELECT COUNT(*) FROM foo INNER JOIN lookup.lookyloo ON foo.dim1 = lookyloo.k) ",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(aggregators(new LongSumAggregatorFactory("a0", "cnt")))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build(),
Druids.newTimeseriesQueryBuilder()
.dataSource(
join(
new TableDataSource(CalciteTests.DATASOURCE1),
new LookupDataSource("lookyloo"),
"j0.",
equalsCondition(DruidExpression.fromColumn("dim1"), DruidExpression.fromColumn("j0.k")),
JoinType.INNER
))
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(aggregators(new CountAggregatorFactory("a0")))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{6L}, new Object[]{1L})
);
}
@Test
public void testUnionAllTwoQueriesBothQueriesAreJoin() throws Exception
{
cannotVectorize();
testQuery(
"("
+ "SELECT COUNT(*) FROM foo LEFT JOIN lookup.lookyloo ON foo.dim1 = lookyloo.k "
+ " UNION ALL "
+ "SELECT COUNT(*) FROM foo INNER JOIN lookup.lookyloo ON foo.dim1 = lookyloo.k"
+ ") ",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(
join(
new TableDataSource(CalciteTests.DATASOURCE1),
new LookupDataSource("lookyloo"),
"j0.",
equalsCondition(DruidExpression.fromColumn("dim1"), DruidExpression.fromColumn("j0.k")),
JoinType.LEFT
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(aggregators(new CountAggregatorFactory("a0")))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build(),
Druids.newTimeseriesQueryBuilder()
.dataSource(
join(
new TableDataSource(CalciteTests.DATASOURCE1),
new LookupDataSource("lookyloo"),
"j0.",
equalsCondition(DruidExpression.fromColumn("dim1"), DruidExpression.fromColumn("j0.k")),
JoinType.INNER
))
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(aggregators(new CountAggregatorFactory("a0")))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{6L}, new Object[]{1L})
);
}
@Test
public void testPruneDeadAggregators() throws Exception
{