Updating plans when using joins with unnest on the left (#15075)

* Updating plans when using joins with unnest on the left

* Correcting segment map function for hashJoin

* The changes done here are not yet reflected in MSQ, so these tests might not run in MSQ

* native tests

* Self joins with unnest data source

* Making this pass

* Addressing comments by adding explanation and new test
Author: Soumyava, 2023-10-06 19:23:12 -07:00 (committed by GitHub)
Parent: f9439970c9
Commit: 57ab8e13dc
4 changed files with 454 additions and 15 deletions

JoinDataSource.java

@@ -476,10 +476,25 @@ public class JoinDataSource implements DataSource
.orElse(null)
)
);
final Function<SegmentReference, SegmentReference> baseMapFn;
// A join data source is not concrete,
// and isConcrete() of an unnest datasource delegates to its base.
// Hence, in the case of a Join -> Unnest -> Join,
// if we just checked isConcrete() on the left,
// the segment map function for the unnest would never get called.
// We therefore delegate to the segment map function of the left
// only when it is not a JoinDataSource.
if (left instanceof JoinDataSource) {
baseMapFn = Function.identity();
} else {
baseMapFn = left.createSegmentMapFunction(
query,
cpuTimeAccumulator
);
}
return baseSegment ->
new HashJoinSegment(
baseSegment,
baseMapFn.apply(baseSegment),
baseFilterToUse,
GuavaUtils.firstNonNull(clausesToUse, ImmutableList.of()),
joinFilterPreAnalysis
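
To make the Join -> Unnest -> Join shape concrete, here is a minimal sketch of such a datasource tree; the table names, prefixes, and join conditions are illustrative only, while the create() signatures follow the tests further down:

// Hypothetical tree: an outer join whose left side is an unnest over an inner join.
DataSource innerJoin = JoinDataSource.create(
    new TableDataSource("fact"),   // hypothetical table
    new TableDataSource("dim"),    // hypothetical table
    "j0.",
    "\"k\" == \"j0.k\"",
    JoinType.INNER,
    null,
    ExprMacroTable.nil(),
    null
);
DataSource unnested = UnnestDataSource.create(
    innerJoin,
    new ExpressionVirtualColumn("_j0.unnest", "\"dim3\"", ColumnType.STRING, ExprMacroTable.nil()),
    null
);
DataSource outerJoin = JoinDataSource.create(
    unnested,                      // left is an UnnestDataSource, not a JoinDataSource
    new TableDataSource("dim2"),   // hypothetical table
    "__j0.",
    "\"_j0.unnest\" == \"__j0.d\"",
    JoinType.INNER,
    null,
    ExprMacroTable.nil(),
    null
);
// When outerJoin builds its segment map function, its left (the unnest) is not a
// JoinDataSource, so the branch above calls unnested.createSegmentMapFunction(...),
// which wraps the inner join's segment map function instead of skipping it.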
@@ -501,7 +516,21 @@ public class JoinDataSource implements DataSource
DimFilter currentDimFilter = null;
final List<PreJoinableClause> preJoinableClauses = new ArrayList<>();
while (current instanceof JoinDataSource) {
// There can be queries like
// Join of Unnest of Join of Unnest of Filter,
// so these checks need to be ORed together to reach the base.
// This method is called to get the analysis for the join data source.
// Since the analysis of an UnnestDS or FilteredDS always delegates to its base,
// to obtain the base data source underneath a Join
// we also iterate through the bases of any FilteredDS and UnnestDS in its path,
// the base of which can be a concrete data source.
// This also means that adding a new datasource type
// will need an instanceof check here.
// Future work should look into whether flattenJoin
// can be refactored to omit these instanceof checks.
while (current instanceof JoinDataSource || current instanceof UnnestDataSource || current instanceof FilteredDataSource) {
if (current instanceof JoinDataSource) {
final JoinDataSource joinDataSource = (JoinDataSource) current;
current = joinDataSource.getLeft();
currentDimFilter = validateLeftFilter(current, joinDataSource.getLeftFilter());
@@ -513,6 +542,13 @@ public class JoinDataSource implements DataSource
joinDataSource.getConditionAnalysis()
)
);
} else if (current instanceof UnnestDataSource) {
final UnnestDataSource unnestDataSource = (UnnestDataSource) current;
current = unnestDataSource.getBase();
} else {
final FilteredDataSource filteredDataSource = (FilteredDataSource) current;
current = filteredDataSource.getBase();
}
}
// Join clauses were added in the order we saw them while traversing down, but we need to apply them in the
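
As a sketch of what this loop handles, consider the "Join of Unnest of Join of Unnest of Filter" shape from the comment above; the datasource names and conditions below are hypothetical, with create() signatures as in the tests:

// Built bottom-up; the analysis walks it top-down, one instanceof branch per layer.
DataSource filtered = FilteredDataSource.create(
    new TableDataSource("base"),   // hypothetical concrete table
    TrueDimFilter.instance()
);
DataSource unnest1 = UnnestDataSource.create(
    filtered,
    new ExpressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING, ExprMacroTable.nil()),
    null
);
DataSource join1 = JoinDataSource.create(
    unnest1,
    new TableDataSource("dim1"),   // hypothetical table
    "_j0.",
    "\"j0.unnest\" == \"_j0.d\"",
    JoinType.LEFT,
    null,
    ExprMacroTable.nil(),
    null
);
DataSource unnest2 = UnnestDataSource.create(
    join1,
    new ExpressionVirtualColumn("__j0.unnest", "\"dim3\"", ColumnType.STRING, ExprMacroTable.nil()),
    null
);
DataSource join2 = JoinDataSource.create(
    unnest2,
    new TableDataSource("dim2"),   // hypothetical table
    "___j0.",
    "\"__j0.unnest\" == \"___j0.d\"",
    JoinType.LEFT,
    null,
    ExprMacroTable.nil(),
    null
);
// The loop peels join2 -> unnest2 -> join1 -> unnest1 -> filtered, so
// join2.getAnalysis().getBaseDataSource() resolves to the "base" table.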

JoinDataSourceTest.java

@@ -29,11 +29,14 @@ import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.filter.TrueDimFilter;
import org.apache.druid.query.planning.DataSourceAnalysis;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.join.JoinConditionAnalysis;
import org.apache.druid.segment.join.JoinType;
import org.apache.druid.segment.join.JoinableFactoryWrapper;
import org.apache.druid.segment.join.NoopJoinableFactory;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.easymock.Mock;
import org.junit.Assert;
import org.junit.Rule;
@@ -433,6 +436,51 @@ public class JoinDataSourceTest
Assert.assertFalse(Arrays.equals(cacheKey1, cacheKey2));
}
@Test
public void testGetAnalysisWithUnnestDS()
{
JoinDataSource dataSource = JoinDataSource.create(
UnnestDataSource.create(
new TableDataSource("table1"),
new ExpressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING, ExprMacroTable.nil()),
null
),
new TableDataSource("table2"),
"j.",
"x == \"j.x\"",
JoinType.LEFT,
null,
ExprMacroTable.nil(),
null
);
DataSourceAnalysis analysis = dataSource.getAnalysis();
Assert.assertEquals("table1", analysis.getBaseDataSource().getTableNames().iterator().next());
}
@Test
public void testGetAnalysisWithFilteredDS()
{
JoinDataSource dataSource = JoinDataSource.create(
UnnestDataSource.create(
FilteredDataSource.create(
new TableDataSource("table1"),
TrueDimFilter.instance()
),
new ExpressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING, ExprMacroTable.nil()),
null
),
new TableDataSource("table2"),
"j.",
"x == \"j.x\"",
JoinType.LEFT,
null,
ExprMacroTable.nil(),
null
);
DataSourceAnalysis analysis = dataSource.getAnalysis();
Assert.assertEquals("table1", analysis.getBaseDataSource().getTableNames().iterator().next());
}
@Test
public void test_computeJoinDataSourceCacheKey_keyChangesWithBaseFilter()
{

DruidRels.java

@@ -66,7 +66,7 @@ public class DruidRels
*/
public static boolean isScanOrProject(final DruidRel<?> druidRel, final boolean canBeJoinOrUnion)
{
if (druidRel instanceof DruidQueryRel || (canBeJoinOrUnion && (druidRel instanceof DruidJoinQueryRel
if (druidRel instanceof DruidQueryRel || (canBeJoinOrUnion && (druidRel instanceof DruidJoinQueryRel || druidRel instanceof DruidCorrelateUnnestRel
|| druidRel instanceof DruidUnionDataSourceRel))) {
final PartialDruidQuery partialQuery = druidRel.getPartialDruidQuery();
final PartialDruidQuery.Stage stage = partialQuery.stage();

CalciteJoinQueryTest.java

@@ -38,6 +38,7 @@ import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.Druids;
import org.apache.druid.query.FilteredDataSource;
import org.apache.druid.query.GlobalTableDataSource;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.JoinDataSource;
@@ -49,6 +50,7 @@ import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryException;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.UnionDataSource;
import org.apache.druid.query.UnnestDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
@@ -64,6 +66,7 @@ import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.ExtractionDimensionSpec;
import org.apache.druid.query.extraction.SubstringDimExtractionFn;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.filter.LikeDimFilter;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.ResultRow;
@@ -5914,4 +5917,356 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
)
);
}
@Test
public void testJoinsWithUnnestOnLeft()
{
// The segment map function of MSQ needs some work
// to handle these nested cases.
// Remove this when that's handled.
msqIncompatible();
Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
testQuery(
"with t1 as (\n"
+ "select * from foo, unnest(MV_TO_ARRAY(\"dim3\")) as u(d3)\n"
+ ")\n"
+ "select t1.dim3, t1.d3, t2.dim2 from t1 JOIN numfoo as t2\n"
+ "ON t1.d3 = t2.\"dim2\"",
context,
ImmutableList.of(
newScanQueryBuilder()
.dataSource(
join(
UnnestDataSource.create(
new TableDataSource(CalciteTests.DATASOURCE1),
expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING),
null
),
new QueryDataSource(
newScanQueryBuilder()
.intervals(querySegmentSpec(Filtration.eternity()))
.dataSource(CalciteTests.DATASOURCE3)
.columns("dim2")
.legacy(false)
.context(context)
.build()
),
"_j0.",
"(\"j0.unnest\" == \"_j0.dim2\")",
JoinType.INNER
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("_j0.dim2", "dim3", "j0.unnest")
.context(context)
.build()
),
useDefault ?
ImmutableList.of(
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"}
) : ImmutableList.of(
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"", "", ""}
)
);
}
@Test
public void testJoinsWithUnnestOverFilteredDSOnLeft()
{
// The segment map function of MSQ needs some work
// to handle these nested cases.
// Remove this when that's handled.
msqIncompatible();
Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
testQuery(
"with t1 as (\n"
+ "select * from foo, unnest(MV_TO_ARRAY(\"dim3\")) as u(d3) where dim2='a'\n"
+ ")\n"
+ "select t1.dim3, t1.d3, t2.dim2 from t1 JOIN numfoo as t2\n"
+ "ON t1.d3 = t2.\"dim2\"",
context,
ImmutableList.of(
newScanQueryBuilder()
.dataSource(
join(
UnnestDataSource.create(
FilteredDataSource.create(
new TableDataSource(CalciteTests.DATASOURCE1),
equality("dim2", "a", ColumnType.STRING)
),
expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING),
null
),
new QueryDataSource(
newScanQueryBuilder()
.intervals(querySegmentSpec(Filtration.eternity()))
.dataSource(CalciteTests.DATASOURCE3)
.columns("dim2")
.legacy(false)
.context(context)
.build()
),
"_j0.",
"(\"j0.unnest\" == \"_j0.dim2\")",
JoinType.INNER
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("_j0.dim2", "dim3", "j0.unnest")
.context(context)
.build()
),
useDefault ?
ImmutableList.of(
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"}
) : ImmutableList.of(
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"", "", ""}
)
);
}
@Test
public void testJoinsWithUnnestOverJoin()
{
// The segment map function of MSQ needs some work
// to handle these nested cases.
// Remove this when that's handled.
msqIncompatible();
Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
testQuery(
"with t1 as (\n"
+ "select * from (SELECT * from foo JOIN (select dim2 as t from foo where dim2 IN ('a','b','ab','abc')) ON dim2=t), "
+ " unnest(MV_TO_ARRAY(\"dim3\")) as u(d3) \n"
+ ")\n"
+ "select t1.dim3, t1.d3, t2.dim2 from t1 JOIN numfoo as t2\n"
+ "ON t1.d3 = t2.\"dim2\"",
context,
ImmutableList.of(
newScanQueryBuilder()
.dataSource(
join(
UnnestDataSource.create(
join(
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(
newScanQueryBuilder()
.intervals(querySegmentSpec(Filtration.eternity()))
.dataSource(CalciteTests.DATASOURCE1)
.filters(new InDimFilter("dim2", ImmutableList.of("a", "b", "ab", "abc"), null))
.legacy(false)
.context(context)
.columns("dim2")
.build()
),
"j0.",
"(\"dim2\" == \"j0.dim2\")",
JoinType.INNER
),
expressionVirtualColumn("_j0.unnest", "\"dim3\"", ColumnType.STRING),
null
),
new QueryDataSource(
newScanQueryBuilder()
.intervals(querySegmentSpec(Filtration.eternity()))
.dataSource(CalciteTests.DATASOURCE3)
.columns("dim2")
.legacy(false)
.context(context)
.build()
),
"__j0.",
"(\"_j0.unnest\" == \"__j0.dim2\")",
JoinType.INNER
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("__j0.dim2", "_j0.unnest", "dim3")
.context(context)
.build()
),
useDefault ?
ImmutableList.of(
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"}
) : ImmutableList.of(
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"", "", ""},
new Object[]{"", "", ""},
new Object[]{"", "", ""},
new Object[]{"", "", ""}
)
);
}
@Test
public void testSelfJoinsWithUnnestOnLeftAndRight()
{
// The segment map function of MSQ needs some work
// to handle these nested cases.
// Remove this when that's handled.
msqIncompatible();
Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
testQuery(
"with t1 as (\n"
+ "select * from foo, unnest(MV_TO_ARRAY(\"dim3\")) as u(d3)\n"
+ ")\n"
+ "select t1.dim3, t1.d3, t2.dim2 from t1 JOIN t1 as t2\n"
+ "ON t1.d3 = t2.d3",
context,
ImmutableList.of(
newScanQueryBuilder()
.dataSource(
join(
UnnestDataSource.create(
new TableDataSource(CalciteTests.DATASOURCE1),
expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING),
null
),
new QueryDataSource(
newScanQueryBuilder()
.intervals(querySegmentSpec(Filtration.eternity()))
.dataSource(UnnestDataSource.create(
new TableDataSource(CalciteTests.DATASOURCE1),
expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING),
null
))
.columns("dim2", "j0.unnest")
.legacy(false)
.context(context)
.build()
),
"_j0.",
"(\"j0.unnest\" == \"_j0.j0.unnest\")",
JoinType.INNER
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("_j0.dim2", "dim3", "j0.unnest")
.context(context)
.build()
),
useDefault ?
ImmutableList.of(
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "b", "a"},
new Object[]{"[\"a\",\"b\"]", "b", ""},
new Object[]{"[\"b\",\"c\"]", "b", "a"},
new Object[]{"[\"b\",\"c\"]", "b", ""},
new Object[]{"[\"b\",\"c\"]", "c", ""},
new Object[]{"d", "d", ""}
) : ImmutableList.of(
new Object[]{"[\"a\",\"b\"]", "a", "a"},
new Object[]{"[\"a\",\"b\"]", "b", "a"},
new Object[]{"[\"a\",\"b\"]", "b", null},
new Object[]{"[\"b\",\"c\"]", "b", "a"},
new Object[]{"[\"b\",\"c\"]", "b", null},
new Object[]{"[\"b\",\"c\"]", "c", null},
new Object[]{"d", "d", ""},
new Object[]{"", "", "a"}
)
);
}
@Test
public void testJoinsOverUnnestOverFilterDSOverJoin()
{
// The segment map function of MSQ needs some work
// to handle these nested cases.
// Remove this when that's handled.
msqIncompatible();
Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
testQuery(
"with t1 as (\n"
+ "select * from (SELECT * from foo JOIN (select dim2 as t from foo where dim2 IN ('a','b','ab','abc')) ON dim2=t),\n"
+ "unnest(MV_TO_ARRAY(\"dim3\")) as u(d3) where m1 IN (1,4) and d3='a'\n"
+ ")\n"
+ "select t1.dim3, t1.d3, t2.dim2, t1.m1 from t1 JOIN numfoo as t2\n"
+ "ON t1.d3 = t2.\"dim2\"",
context,
ImmutableList.of(
newScanQueryBuilder()
.dataSource(
join(
UnnestDataSource.create(
FilteredDataSource.create(
join(
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(
newScanQueryBuilder()
.intervals(querySegmentSpec(Filtration.eternity()))
.dataSource(CalciteTests.DATASOURCE1)
.columns("dim2")
.filters(new InDimFilter(
"dim2",
ImmutableList.of("a", "ab", "abc", "b"),
null
))
.legacy(false)
.context(context)
.build()
),
"j0.",
"(\"dim2\" == \"j0.dim2\")",
JoinType.INNER
),
useDefault ?
new InDimFilter("m1", ImmutableList.of("1", "4"), null) :
or(
equality("m1", 1.0, ColumnType.FLOAT),
equality("m1", 4.0, ColumnType.FLOAT)
)
),
expressionVirtualColumn("_j0.unnest", "\"dim3\"", ColumnType.STRING),
equality("_j0.unnest", "a", ColumnType.STRING)
),
new QueryDataSource(
newScanQueryBuilder()
.intervals(querySegmentSpec(Filtration.eternity()))
.dataSource(CalciteTests.DATASOURCE3)
.columns("dim2")
.legacy(false)
.context(context)
.build()
),
"__j0.",
"(\"_j0.unnest\" == \"__j0.dim2\")",
JoinType.INNER
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("__j0.dim2", "_j0.unnest", "dim3", "m1")
.context(context)
.build()
),
ImmutableList.of(
new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f},
new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f},
new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f},
new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f},
new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f},
new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f},
new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f},
new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f}
)
);
}
}