diff --git a/processing/src/main/java/org/apache/druid/query/JoinDataSource.java b/processing/src/main/java/org/apache/druid/query/JoinDataSource.java index 77734edf025..220f18a9485 100644 --- a/processing/src/main/java/org/apache/druid/query/JoinDataSource.java +++ b/processing/src/main/java/org/apache/druid/query/JoinDataSource.java @@ -476,10 +476,25 @@ public class JoinDataSource implements DataSource .orElse(null) ) ); - + final Function baseMapFn; + // A join data source is not concrete + // And isConcrete() of an unnest datasource delegates to its base + // Hence, in the case of a Join -> Unnest -> Join + // if we just use isConcrete on the left + // the segment map function for the unnest would never get called + // This calls us to delegate to the segmentMapFunction of the left + // only when it is not a JoinDataSource + if (left instanceof JoinDataSource) { + baseMapFn = Function.identity(); + } else { + baseMapFn = left.createSegmentMapFunction( + query, + cpuTimeAccumulator + ); + } return baseSegment -> new HashJoinSegment( - baseSegment, + baseMapFn.apply(baseSegment), baseFilterToUse, GuavaUtils.firstNonNull(clausesToUse, ImmutableList.of()), joinFilterPreAnalysis @@ -501,18 +516,39 @@ public class JoinDataSource implements DataSource DimFilter currentDimFilter = null; final List preJoinableClauses = new ArrayList<>(); - while (current instanceof JoinDataSource) { - final JoinDataSource joinDataSource = (JoinDataSource) current; - current = joinDataSource.getLeft(); - currentDimFilter = validateLeftFilter(current, joinDataSource.getLeftFilter()); - preJoinableClauses.add( - new PreJoinableClause( - joinDataSource.getRightPrefix(), - joinDataSource.getRight(), - joinDataSource.getJoinType(), - joinDataSource.getConditionAnalysis() - ) - ); + // There can be queries like + // Join of Unnest of Join of Unnest of Filter + // so these checks are needed to be ORed + // to get the base + // This method is called to get the analysis for the join data source + // Since the analysis of an UnnestDS or FilteredDS always delegates to its base + // To obtain the base data source underneath a Join + // we also iterate through the base of the FilterDS and UnnestDS in its path + // the base of which can be a concrete data source + // This also means that an addition of a new datasource + // Will need an instanceof check here + // A future work should look into if the flattenJoin + // can be refactored to omit these instanceof checks + while (current instanceof JoinDataSource || current instanceof UnnestDataSource || current instanceof FilteredDataSource) { + if (current instanceof JoinDataSource) { + final JoinDataSource joinDataSource = (JoinDataSource) current; + current = joinDataSource.getLeft(); + currentDimFilter = validateLeftFilter(current, joinDataSource.getLeftFilter()); + preJoinableClauses.add( + new PreJoinableClause( + joinDataSource.getRightPrefix(), + joinDataSource.getRight(), + joinDataSource.getJoinType(), + joinDataSource.getConditionAnalysis() + ) + ); + } else if (current instanceof UnnestDataSource) { + final UnnestDataSource unnestDataSource = (UnnestDataSource) current; + current = unnestDataSource.getBase(); + } else { + final FilteredDataSource filteredDataSource = (FilteredDataSource) current; + current = filteredDataSource.getBase(); + } } // Join clauses were added in the order we saw them while traversing down, but we need to apply them in the diff --git a/processing/src/test/java/org/apache/druid/query/JoinDataSourceTest.java b/processing/src/test/java/org/apache/druid/query/JoinDataSourceTest.java index b23c0b92dbb..b821bc49c4e 100644 --- a/processing/src/test/java/org/apache/druid/query/JoinDataSourceTest.java +++ b/processing/src/test/java/org/apache/druid/query/JoinDataSourceTest.java @@ -29,11 +29,14 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.TrueDimFilter; +import org.apache.druid.query.planning.DataSourceAnalysis; import org.apache.druid.segment.TestHelper; +import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.join.JoinConditionAnalysis; import org.apache.druid.segment.join.JoinType; import org.apache.druid.segment.join.JoinableFactoryWrapper; import org.apache.druid.segment.join.NoopJoinableFactory; +import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.easymock.Mock; import org.junit.Assert; import org.junit.Rule; @@ -433,6 +436,51 @@ public class JoinDataSourceTest Assert.assertFalse(Arrays.equals(cacheKey1, cacheKey2)); } + @Test + public void testGetAnalysisWithUnnestDS() + { + JoinDataSource dataSource = JoinDataSource.create( + UnnestDataSource.create( + new TableDataSource("table1"), + new ExpressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING, ExprMacroTable.nil()), + null + ), + new TableDataSource("table2"), + "j.", + "x == \"j.x\"", + JoinType.LEFT, + null, + ExprMacroTable.nil(), + null + ); + DataSourceAnalysis analysis = dataSource.getAnalysis(); + Assert.assertEquals("table1", analysis.getBaseDataSource().getTableNames().iterator().next()); + } + + @Test + public void testGetAnalysisWithFilteredDS() + { + JoinDataSource dataSource = JoinDataSource.create( + UnnestDataSource.create( + FilteredDataSource.create( + new TableDataSource("table1"), + TrueDimFilter.instance() + ), + new ExpressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING, ExprMacroTable.nil()), + null + ), + new TableDataSource("table2"), + "j.", + "x == \"j.x\"", + JoinType.LEFT, + null, + ExprMacroTable.nil(), + null + ); + DataSourceAnalysis analysis = dataSource.getAnalysis(); + Assert.assertEquals("table1", analysis.getBaseDataSource().getTableNames().iterator().next()); + } + @Test public void test_computeJoinDataSourceCacheKey_keyChangesWithBaseFilter() { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidRels.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidRels.java index c35c872544f..1627329c75e 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidRels.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidRels.java @@ -66,7 +66,7 @@ public class DruidRels */ public static boolean isScanOrProject(final DruidRel druidRel, final boolean canBeJoinOrUnion) { - if (druidRel instanceof DruidQueryRel || (canBeJoinOrUnion && (druidRel instanceof DruidJoinQueryRel + if (druidRel instanceof DruidQueryRel || (canBeJoinOrUnion && (druidRel instanceof DruidJoinQueryRel || druidRel instanceof DruidCorrelateUnnestRel || druidRel instanceof DruidUnionDataSourceRel))) { final PartialDruidQuery partialQuery = druidRel.getPartialDruidQuery(); final PartialDruidQuery.Stage stage = partialQuery.stage(); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java index d1300ff19b2..e8a72833960 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java @@ -38,6 +38,7 @@ import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.DataSource; import org.apache.druid.query.Druids; +import org.apache.druid.query.FilteredDataSource; import org.apache.druid.query.GlobalTableDataSource; import org.apache.druid.query.InlineDataSource; import org.apache.druid.query.JoinDataSource; @@ -49,6 +50,7 @@ import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.QueryException; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.UnionDataSource; +import org.apache.druid.query.UnnestDataSource; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory; @@ -64,6 +66,7 @@ import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.ExtractionDimensionSpec; import org.apache.druid.query.extraction.SubstringDimExtractionFn; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.LikeDimFilter; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.groupby.ResultRow; @@ -5914,4 +5917,356 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest ) ); } + + @Test + public void testJoinsWithUnnestOnLeft() + { + // Segment map function of MSQ needs some work + // To handle these nested cases + // Remove this when that's handled + msqIncompatible(); + Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + testQuery( + "with t1 as (\n" + + "select * from foo, unnest(MV_TO_ARRAY(\"dim3\")) as u(d3)\n" + + ")\n" + + "select t1.dim3, t1.d3, t2.dim2 from t1 JOIN numfoo as t2\n" + + "ON t1.d3 = t2.\"dim2\"", + context, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + UnnestDataSource.create( + new TableDataSource(CalciteTests.DATASOURCE1), + expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING), + null + ), + new QueryDataSource( + newScanQueryBuilder() + .intervals(querySegmentSpec(Filtration.eternity())) + .dataSource(CalciteTests.DATASOURCE3) + .columns("dim2") + .legacy(false) + .context(context) + .build() + ), + "_j0.", + "(\"j0.unnest\" == \"_j0.dim2\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("_j0.dim2", "dim3", "j0.unnest") + .context(context) + .build() + ), + useDefault ? + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"} + ) : ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"", "", ""} + ) + ); + } + + @Test + public void testJoinsWithUnnestOverFilteredDSOnLeft() + { + // Segment map function of MSQ needs some work + // To handle these nested cases + // Remove this when that's handled + msqIncompatible(); + Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + testQuery( + "with t1 as (\n" + + "select * from foo, unnest(MV_TO_ARRAY(\"dim3\")) as u(d3) where dim2='a'\n" + + ")\n" + + "select t1.dim3, t1.d3, t2.dim2 from t1 JOIN numfoo as t2\n" + + "ON t1.d3 = t2.\"dim2\"", + context, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + UnnestDataSource.create( + FilteredDataSource.create( + new TableDataSource(CalciteTests.DATASOURCE1), + equality("dim2", "a", ColumnType.STRING) + ), + expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING), + null + ), + new QueryDataSource( + newScanQueryBuilder() + .intervals(querySegmentSpec(Filtration.eternity())) + .dataSource(CalciteTests.DATASOURCE3) + .columns("dim2") + .legacy(false) + .context(context) + .build() + ), + "_j0.", + "(\"j0.unnest\" == \"_j0.dim2\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("_j0.dim2", "dim3", "j0.unnest") + .context(context) + .build() + ), + useDefault ? + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"} + ) : ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"", "", ""} + ) + ); + } + + @Test + public void testJoinsWithUnnestOverJoin() + { + // Segment map function of MSQ needs some work + // To handle these nested cases + // Remove this when that's handled + msqIncompatible(); + Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + testQuery( + "with t1 as (\n" + + "select * from (SELECT * from foo JOIN (select dim2 as t from foo where dim2 IN ('a','b','ab','abc')) ON dim2=t), " + + " unnest(MV_TO_ARRAY(\"dim3\")) as u(d3) \n" + + ")\n" + + "select t1.dim3, t1.d3, t2.dim2 from t1 JOIN numfoo as t2\n" + + "ON t1.d3 = t2.\"dim2\"", + context, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + UnnestDataSource.create( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + newScanQueryBuilder() + .intervals(querySegmentSpec(Filtration.eternity())) + .dataSource(CalciteTests.DATASOURCE1) + .filters(new InDimFilter("dim2", ImmutableList.of("a", "b", "ab", "abc"), null)) + .legacy(false) + .context(context) + .columns("dim2") + .build() + ), + "j0.", + "(\"dim2\" == \"j0.dim2\")", + JoinType.INNER + ), + expressionVirtualColumn("_j0.unnest", "\"dim3\"", ColumnType.STRING), + null + ), + new QueryDataSource( + newScanQueryBuilder() + .intervals(querySegmentSpec(Filtration.eternity())) + .dataSource(CalciteTests.DATASOURCE3) + .columns("dim2") + .legacy(false) + .context(context) + .build() + ), + "__j0.", + "(\"_j0.unnest\" == \"__j0.dim2\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("__j0.dim2", "_j0.unnest", "dim3") + .context(context) + .build() + ), + useDefault ? + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"} + ) : ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"", "", ""}, + new Object[]{"", "", ""}, + new Object[]{"", "", ""}, + new Object[]{"", "", ""} + ) + ); + } + + @Test + public void testSelfJoinsWithUnnestOnLeftAndRight() + { + // Segment map function of MSQ needs some work + // To handle these nested cases + // Remove this when that's handled + msqIncompatible(); + Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + testQuery( + "with t1 as (\n" + + "select * from foo, unnest(MV_TO_ARRAY(\"dim3\")) as u(d3)\n" + + ")\n" + + "select t1.dim3, t1.d3, t2.dim2 from t1 JOIN t1 as t2\n" + + "ON t1.d3 = t2.d3", + context, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + UnnestDataSource.create( + new TableDataSource(CalciteTests.DATASOURCE1), + expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING), + null + ), + new QueryDataSource( + newScanQueryBuilder() + .intervals(querySegmentSpec(Filtration.eternity())) + .dataSource(UnnestDataSource.create( + new TableDataSource(CalciteTests.DATASOURCE1), + expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING), + null + )) + .columns("dim2", "j0.unnest") + .legacy(false) + .context(context) + .build() + ), + "_j0.", + "(\"j0.unnest\" == \"_j0.j0.unnest\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("_j0.dim2", "dim3", "j0.unnest") + .context(context) + .build() + ), + useDefault ? + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "b", "a"}, + new Object[]{"[\"a\",\"b\"]", "b", ""}, + new Object[]{"[\"b\",\"c\"]", "b", "a"}, + new Object[]{"[\"b\",\"c\"]", "b", ""}, + new Object[]{"[\"b\",\"c\"]", "c", ""}, + new Object[]{"d", "d", ""} + ) : ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "b", "a"}, + new Object[]{"[\"a\",\"b\"]", "b", null}, + new Object[]{"[\"b\",\"c\"]", "b", "a"}, + new Object[]{"[\"b\",\"c\"]", "b", null}, + new Object[]{"[\"b\",\"c\"]", "c", null}, + new Object[]{"d", "d", ""}, + new Object[]{"", "", "a"} + ) + ); + } + + @Test + public void testJoinsOverUnnestOverFilterDSOverJoin() + { + // Segment map function of MSQ needs some work + // To handle these nested cases + // Remove this when that's handled + msqIncompatible(); + Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + testQuery( + "with t1 as (\n" + + "select * from (SELECT * from foo JOIN (select dim2 as t from foo where dim2 IN ('a','b','ab','abc')) ON dim2=t),\n" + + "unnest(MV_TO_ARRAY(\"dim3\")) as u(d3) where m1 IN (1,4) and d3='a'\n" + + ")\n" + + "select t1.dim3, t1.d3, t2.dim2, t1.m1 from t1 JOIN numfoo as t2\n" + + "ON t1.d3 = t2.\"dim2\"", + context, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + UnnestDataSource.create( + FilteredDataSource.create( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + newScanQueryBuilder() + .intervals(querySegmentSpec(Filtration.eternity())) + .dataSource(CalciteTests.DATASOURCE1) + .columns("dim2") + .filters(new InDimFilter( + "dim2", + ImmutableList.of("a", "ab", "abc", "b"), + null + )) + .legacy(false) + .context(context) + .build() + ), + "j0.", + "(\"dim2\" == \"j0.dim2\")", + JoinType.INNER + ), + useDefault ? + new InDimFilter("m1", ImmutableList.of("1", "4"), null) : + or( + equality("m1", 1.0, ColumnType.FLOAT), + equality("m1", 4.0, ColumnType.FLOAT) + ) + ), + expressionVirtualColumn("_j0.unnest", "\"dim3\"", ColumnType.STRING), + equality("_j0.unnest", "a", ColumnType.STRING) + ), + new QueryDataSource( + newScanQueryBuilder() + .intervals(querySegmentSpec(Filtration.eternity())) + .dataSource(CalciteTests.DATASOURCE3) + .columns("dim2") + .legacy(false) + .context(context) + .build() + ), + "__j0.", + "(\"_j0.unnest\" == \"__j0.dim2\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("__j0.dim2", "_j0.unnest", "dim3", "m1") + .context(context) + .build() + ), + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f}, + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f}, + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f}, + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f}, + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f}, + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f}, + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f}, + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f} + ) + ); + } }