diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidCorrelateUnnestRel.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidCorrelateUnnestRel.java index 5c5bfd8aae4..bcf358272fa 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidCorrelateUnnestRel.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidCorrelateUnnestRel.java @@ -336,7 +336,8 @@ public class DruidCorrelateUnnestRel extends DruidRel RowSignature.builder().add( BASE_UNNEST_OUTPUT_COLUMN, Calcites.getColumnTypeForRelDataType(unnestedType) - ).build() + ).build(), + DruidJoinQueryRel.findExistingJoinPrefixes(leftQuery.getDataSource()) ).rhs; } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java index d0f6f7bef25..5ab29ab13b1 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java @@ -57,6 +57,8 @@ import org.apache.druid.sql.calcite.planner.UnsupportedSQLQueryException; import org.apache.druid.sql.calcite.table.RowSignatures; import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -160,7 +162,12 @@ public class DruidJoinQueryRel extends DruidRel rightDataSource = rightQuery.getDataSource(); } - final Pair prefixSignaturePair = computeJoinRowSignature(leftSignature, rightSignature); + + final Pair prefixSignaturePair = computeJoinRowSignature( + leftSignature, + rightSignature, + findExistingJoinPrefixes(leftDataSource, rightDataSource) + ); VirtualColumnRegistry virtualColumnRegistry = VirtualColumnRegistry.create( prefixSignaturePair.rhs, @@ -380,13 +387,29 @@ public class DruidJoinQueryRel extends DruidRel && DruidRels.druidTableIfLeafRel(right).filter(table -> table.getDataSource().isGlobal()).isPresent()); } + static Set findExistingJoinPrefixes(DataSource... dataSources) + { + final ArrayList copy = new ArrayList<>(Arrays.asList(dataSources)); + + Set prefixes = new HashSet<>(); + while (!copy.isEmpty()) { + DataSource current = copy.remove(0); + copy.addAll(current.getChildren()); + if (current instanceof JoinDataSource) { + JoinDataSource joiner = (JoinDataSource) current; + prefixes.add(joiner.getRightPrefix()); + } + } + return prefixes; + } /** * Returns a Pair of "rightPrefix" (for JoinDataSource) and the signature of rows that will result from * applying that prefix. */ static Pair computeJoinRowSignature( final RowSignature leftSignature, - final RowSignature rightSignature + final RowSignature rightSignature, + final Set prefixes ) { final RowSignature.Builder signatureBuilder = RowSignature.builder(); @@ -395,8 +418,17 @@ public class DruidJoinQueryRel extends DruidRel signatureBuilder.add(column, leftSignature.getColumnType(column).orElse(null)); } - // Need to include the "0" since findUnusedPrefixForDigits only guarantees safety for digit-initiated suffixes - final String rightPrefix = Calcites.findUnusedPrefixForDigits("j", leftSignature.getColumnNames()) + "0."; + StringBuilder base = new StringBuilder("j"); + // the prefixes collection contains all known join prefixes, which might be in use for nested queries but not + // present in the top level row signatures + // loop until we are sure we got a new prefix + String maybePrefix; + do { + // Need to include the "0" since findUnusedPrefixForDigits only guarantees safety for digit-initiated suffixes + maybePrefix = Calcites.findUnusedPrefixForDigits(base.toString(), leftSignature.getColumnNames()) + "0."; + base.insert(0, "_"); + } while (prefixes.contains(maybePrefix)); + final String rightPrefix = maybePrefix; for (final String column : rightSignature.getColumnNames()) { signatureBuilder.add(rightPrefix + column, rightSignature.getColumnType(column).orElse(null)); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java index 620a8cbbb7a..32007ed7abb 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java @@ -61,6 +61,7 @@ import org.apache.druid.query.dimension.ExtractionDimensionSpec; import org.apache.druid.query.extraction.SubstringDimExtractionFn; import org.apache.druid.query.filter.AndDimFilter; import org.apache.druid.query.filter.BoundDimFilter; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.LikeDimFilter; import org.apache.druid.query.filter.NotDimFilter; import org.apache.druid.query.filter.OrDimFilter; @@ -95,8 +96,10 @@ import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; @@ -4766,8 +4769,8 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest .context(queryContext) .build() ), - "j0.", - equalsCondition(makeColumnExpression("v0"), makeColumnExpression("j0.v0")), + "_j0.", + equalsCondition(makeColumnExpression("v0"), makeColumnExpression("_j0.v0")), JoinType.INNER ) ) @@ -4778,7 +4781,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest ImmutableSet.of("a"), true )) - .columns("dim3", "j0.dim3") + .columns("_j0.dim3", "dim3") .context(queryContext) .build() ), @@ -5084,4 +5087,181 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest null ); } + + @Test + @Parameters(source = QueryContextForJoinProvider.class) + public void testRegressionFilteredAggregatorsSubqueryJoins(Map queryContext) + { + cannotVectorize(); + testQuery( + "select\n" + + "count(*) filter (where trim(both from dim1) in (select dim2 from foo)),\n" + + "min(m1) filter (where 'A' not in (select m2 from foo))\n" + + "from foo as t0\n" + + "where __time in (select __time from foo)", + queryContext, + useDefault ? + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource( + join( + join( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDimensions( + new DefaultDimensionSpec("__time", "d0", ColumnType.LONG) + ) + .setGranularity(Granularities.ALL) + .setLimitSpec(NoopLimitSpec.instance()) + .build() + ), + "j0.", + equalsCondition(makeColumnExpression("__time"), makeColumnExpression("j0.d0")), + JoinType.INNER + ), + new QueryDataSource( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG)) + .setDimensions( + new DefaultDimensionSpec("dim2", "d0", ColumnType.STRING), + new DefaultDimensionSpec("v0", "d1", ColumnType.LONG) + ) + .setGranularity(Granularities.ALL) + .setLimitSpec(NoopLimitSpec.instance()) + .build() + ), + "_j0.", + "(trim(\"dim1\",' ') == \"_j0.d0\")", + JoinType.LEFT + ), + new QueryDataSource( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG)) + .setDimFilter(selector("m2", "A", null)) + .setDimensions( + new DefaultDimensionSpec("v0", "d0", ColumnType.LONG) + ) + .setGranularity(Granularities.ALL) + .setLimitSpec(NoopLimitSpec.instance()) + .build() + ), + "__j0.", + "1", + JoinType.LEFT + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .aggregators( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + and( + not(selector("_j0.d1", null, null)), + not(selector("dim1", null, null)) + ), + "a0" + ), + new FilteredAggregatorFactory( + new FloatMinAggregatorFactory("a1", "m1"), + selector("__j0.d0", null, null), + "a1" + ) + ) + .context(queryContext) + .build() + ) : + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource( + join( + join( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDimensions( + new DefaultDimensionSpec("__time", "d0", ColumnType.LONG) + ) + .setGranularity(Granularities.ALL) + .setLimitSpec(NoopLimitSpec.instance()) + .build() + ), + "j0.", + equalsCondition(makeColumnExpression("__time"), makeColumnExpression("j0.d0")), + JoinType.INNER + ), + new QueryDataSource( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG)) + .setDimensions( + new DefaultDimensionSpec("dim2", "d0", ColumnType.STRING), + new DefaultDimensionSpec("v0", "d1", ColumnType.LONG) + ) + .setGranularity(Granularities.ALL) + .setLimitSpec(NoopLimitSpec.instance()) + .build() + ), + "_j0.", + "(trim(\"dim1\",' ') == \"_j0.d0\")", + JoinType.LEFT + ), + new QueryDataSource( + new TopNQueryBuilder().dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters(new InDimFilter("m2", new HashSet<>(Arrays.asList(null, "A")))) + .virtualColumns(expressionVirtualColumn("v0", "notnull(\"m2\")", ColumnType.LONG)) + .dimension(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG)) + .metric(new InvertedTopNMetricSpec(new DimensionTopNMetricSpec(null, StringComparators.LEXICOGRAPHIC))) + .aggregators(new CountAggregatorFactory("a0")) + .threshold(1) + .build() + ), + "__j0.", + "1", + JoinType.LEFT + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .aggregators( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + and( + not(selector("_j0.d1", null, null)), + not(selector("dim1", null, null)) + ), + "a0" + ), + new FilteredAggregatorFactory( + new FloatMinAggregatorFactory("a1", "m1"), + or( + selector("__j0.a0", null, null), + not( + or( + not(expressionFilter("\"__j0.d0\"")), + not(selector("__j0.d0", null, null)) + ) + ) + ), + "a1" + ) + ) + .context(queryContext) + .build() + ), + ImmutableList.of( + new Object[]{useDefault ? 1L : 2L, 1.0f} + ) + ); + } }