fix join and unnest planning to ensure that duplicate join prefixes are not used (#13943)

* fix join and unnest planning to ensure that duplicate join prefixes are not used

* won't somebody please think of the children
This commit is contained in:
Clint Wylie 2023-03-22 12:53:55 -07:00 committed by GitHub
parent 7bab407495
commit 086eb26b74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 221 additions and 8 deletions

View File

@ -336,7 +336,8 @@ public class DruidCorrelateUnnestRel extends DruidRel<DruidCorrelateUnnestRel>
RowSignature.builder().add(
BASE_UNNEST_OUTPUT_COLUMN,
Calcites.getColumnTypeForRelDataType(unnestedType)
).build()
).build(),
DruidJoinQueryRel.findExistingJoinPrefixes(leftQuery.getDataSource())
).rhs;
}

View File

@ -57,6 +57,8 @@ import org.apache.druid.sql.calcite.planner.UnsupportedSQLQueryException;
import org.apache.druid.sql.calcite.table.RowSignatures;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@ -160,7 +162,12 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
rightDataSource = rightQuery.getDataSource();
}
final Pair<String, RowSignature> prefixSignaturePair = computeJoinRowSignature(leftSignature, rightSignature);
final Pair<String, RowSignature> prefixSignaturePair = computeJoinRowSignature(
leftSignature,
rightSignature,
findExistingJoinPrefixes(leftDataSource, rightDataSource)
);
VirtualColumnRegistry virtualColumnRegistry = VirtualColumnRegistry.create(
prefixSignaturePair.rhs,
@ -380,13 +387,29 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
&& DruidRels.druidTableIfLeafRel(right).filter(table -> table.getDataSource().isGlobal()).isPresent());
}
/**
 * Collects the right-hand-side prefix of every {@link JoinDataSource} reachable from the given datasources,
 * walking each datasource tree breadth-first via {@code getChildren()}.
 *
 * The returned set lets prefix generation (see computeJoinRowSignature) avoid picking a prefix that is
 * already in use by a nested join, even when that prefix does not appear in the top-level row signature.
 *
 * @param dataSources roots of the datasource trees to scan
 * @return set of all "rightPrefix" values found on nested {@link JoinDataSource}s; empty if none
 */
static Set<String> findExistingJoinPrefixes(DataSource... dataSources)
{
  // Breadth-first walk using an index cursor into a growing list. The original remove(0)-style worklist
  // is O(n) per removal on ArrayList (the tail shifts left each time); advancing an index visits the same
  // nodes in the same order without the quadratic cost.
  final List<DataSource> toVisit = new ArrayList<>(Arrays.asList(dataSources));
  final Set<String> prefixes = new HashSet<>();
  for (int i = 0; i < toVisit.size(); i++) {
    final DataSource current = toVisit.get(i);
    // Children are appended, so nested joins (including join-within-join) are eventually visited too.
    toVisit.addAll(current.getChildren());
    if (current instanceof JoinDataSource) {
      prefixes.add(((JoinDataSource) current).getRightPrefix());
    }
  }
  return prefixes;
}
/**
* Returns a Pair of "rightPrefix" (for JoinDataSource) and the signature of rows that will result from
* applying that prefix.
*/
static Pair<String, RowSignature> computeJoinRowSignature(
final RowSignature leftSignature,
final RowSignature rightSignature
final RowSignature rightSignature,
final Set<String> prefixes
)
{
final RowSignature.Builder signatureBuilder = RowSignature.builder();
@ -395,8 +418,17 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
signatureBuilder.add(column, leftSignature.getColumnType(column).orElse(null));
}
// Need to include the "0" since findUnusedPrefixForDigits only guarantees safety for digit-initiated suffixes
final String rightPrefix = Calcites.findUnusedPrefixForDigits("j", leftSignature.getColumnNames()) + "0.";
StringBuilder base = new StringBuilder("j");
// the prefixes collection contains all known join prefixes, which might be in use for nested queries but not
// present in the top level row signatures
// loop until we are sure we got a new prefix
String maybePrefix;
do {
// Need to include the "0" since findUnusedPrefixForDigits only guarantees safety for digit-initiated suffixes
maybePrefix = Calcites.findUnusedPrefixForDigits(base.toString(), leftSignature.getColumnNames()) + "0.";
base.insert(0, "_");
} while (prefixes.contains(maybePrefix));
final String rightPrefix = maybePrefix;
for (final String column : rightSignature.getColumnNames()) {
signatureBuilder.add(rightPrefix + column, rightSignature.getColumnType(column).orElse(null));

View File

@ -61,6 +61,7 @@ import org.apache.druid.query.dimension.ExtractionDimensionSpec;
import org.apache.druid.query.extraction.SubstringDimExtractionFn;
import org.apache.druid.query.filter.AndDimFilter;
import org.apache.druid.query.filter.BoundDimFilter;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.filter.LikeDimFilter;
import org.apache.druid.query.filter.NotDimFilter;
import org.apache.druid.query.filter.OrDimFilter;
@ -95,8 +96,10 @@ import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@ -4766,8 +4769,8 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
.context(queryContext)
.build()
),
"j0.",
equalsCondition(makeColumnExpression("v0"), makeColumnExpression("j0.v0")),
"_j0.",
equalsCondition(makeColumnExpression("v0"), makeColumnExpression("_j0.v0")),
JoinType.INNER
)
)
@ -4778,7 +4781,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
ImmutableSet.of("a"),
true
))
.columns("dim3", "j0.dim3")
.columns("_j0.dim3", "dim3")
.context(queryContext)
.build()
),
@ -5084,4 +5087,181 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
null
);
}
// Regression test for duplicate join prefixes (#13943): a query planning three stacked subquery joins
// must assign each join level a distinct right-side prefix ("j0.", "_j0.", "__j0.") rather than reusing
// a prefix already taken by an inner join.
@Test
@Parameters(source = QueryContextForJoinProvider.class)
public void testRegressionFilteredAggregatorsSubqueryJoins(Map<String, Object> queryContext)
{
cannotVectorize();
testQuery(
"select\n" +
"count(*) filter (where trim(both from dim1) in (select dim2 from foo)),\n" +
"min(m1) filter (where 'A' not in (select m2 from foo))\n" +
"from foo as t0\n" +
"where __time in (select __time from foo)",
queryContext,
useDefault ?
// Expected plan when "useDefault" (default-value null handling) is active.
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(
// Three stacked joins — each level must carry its own distinct prefix.
join(
join(
join(
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setDimensions(
new DefaultDimensionSpec("__time", "d0", ColumnType.LONG)
)
.setGranularity(Granularities.ALL)
.setLimitSpec(NoopLimitSpec.instance())
.build()
),
"j0.",
equalsCondition(makeColumnExpression("__time"), makeColumnExpression("j0.d0")),
JoinType.INNER
),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG))
.setDimensions(
new DefaultDimensionSpec("dim2", "d0", ColumnType.STRING),
new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
)
.setGranularity(Granularities.ALL)
.setLimitSpec(NoopLimitSpec.instance())
.build()
),
"_j0.",
"(trim(\"dim1\",' ') == \"_j0.d0\")",
JoinType.LEFT
),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG))
.setDimFilter(selector("m2", "A", null))
.setDimensions(
new DefaultDimensionSpec("v0", "d0", ColumnType.LONG)
)
.setGranularity(Granularities.ALL)
.setLimitSpec(NoopLimitSpec.instance())
.build()
),
// Third join level: "__j0." — distinct from the inner "j0." and "_j0." prefixes.
"__j0.",
"1",
JoinType.LEFT
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.aggregators(
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
and(
not(selector("_j0.d1", null, null)),
not(selector("dim1", null, null))
),
"a0"
),
new FilteredAggregatorFactory(
new FloatMinAggregatorFactory("a1", "m1"),
selector("__j0.d0", null, null),
"a1"
)
)
.context(queryContext)
.build()
) :
// Expected plan under SQL-compatible null handling; same three-prefix join stack, but the third
// subquery plans as a TopN and the "min" filter becomes a null-aware OR expression.
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(
join(
join(
join(
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setDimensions(
new DefaultDimensionSpec("__time", "d0", ColumnType.LONG)
)
.setGranularity(Granularities.ALL)
.setLimitSpec(NoopLimitSpec.instance())
.build()
),
"j0.",
equalsCondition(makeColumnExpression("__time"), makeColumnExpression("j0.d0")),
JoinType.INNER
),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG))
.setDimensions(
new DefaultDimensionSpec("dim2", "d0", ColumnType.STRING),
new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
)
.setGranularity(Granularities.ALL)
.setLimitSpec(NoopLimitSpec.instance())
.build()
),
"_j0.",
"(trim(\"dim1\",' ') == \"_j0.d0\")",
JoinType.LEFT
),
new QueryDataSource(
new TopNQueryBuilder().dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.filters(new InDimFilter("m2", new HashSet<>(Arrays.asList(null, "A"))))
.virtualColumns(expressionVirtualColumn("v0", "notnull(\"m2\")", ColumnType.LONG))
.dimension(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG))
.metric(new InvertedTopNMetricSpec(new DimensionTopNMetricSpec(null, StringComparators.LEXICOGRAPHIC)))
.aggregators(new CountAggregatorFactory("a0"))
.threshold(1)
.build()
),
"__j0.",
"1",
JoinType.LEFT
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.aggregators(
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
and(
not(selector("_j0.d1", null, null)),
not(selector("dim1", null, null))
),
"a0"
),
new FilteredAggregatorFactory(
new FloatMinAggregatorFactory("a1", "m1"),
or(
selector("__j0.a0", null, null),
not(
or(
not(expressionFilter("\"__j0.d0\"")),
not(selector("__j0.d0", null, null))
)
)
),
"a1"
)
)
.context(queryContext)
.build()
),
// Expected result row; the filtered count differs between the two null-handling modes.
ImmutableList.of(
new Object[]{useDefault ? 1L : 2L, 1.0f}
)
);
}
}