mirror of https://github.com/apache/druid.git
fix join and unnest planning to ensure that duplicate join prefixes are not used (#13943)
* fix join and unnest planning to ensure that duplicate join prefixes are not used * wont somebody please think of the children
This commit is contained in:
parent
7bab407495
commit
086eb26b74
|
@ -336,7 +336,8 @@ public class DruidCorrelateUnnestRel extends DruidRel<DruidCorrelateUnnestRel>
|
|||
RowSignature.builder().add(
|
||||
BASE_UNNEST_OUTPUT_COLUMN,
|
||||
Calcites.getColumnTypeForRelDataType(unnestedType)
|
||||
).build()
|
||||
).build(),
|
||||
DruidJoinQueryRel.findExistingJoinPrefixes(leftQuery.getDataSource())
|
||||
).rhs;
|
||||
}
|
||||
|
||||
|
|
|
@ -57,6 +57,8 @@ import org.apache.druid.sql.calcite.planner.UnsupportedSQLQueryException;
|
|||
import org.apache.druid.sql.calcite.table.RowSignatures;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
@ -160,7 +162,12 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
|
|||
rightDataSource = rightQuery.getDataSource();
|
||||
}
|
||||
|
||||
final Pair<String, RowSignature> prefixSignaturePair = computeJoinRowSignature(leftSignature, rightSignature);
|
||||
|
||||
final Pair<String, RowSignature> prefixSignaturePair = computeJoinRowSignature(
|
||||
leftSignature,
|
||||
rightSignature,
|
||||
findExistingJoinPrefixes(leftDataSource, rightDataSource)
|
||||
);
|
||||
|
||||
VirtualColumnRegistry virtualColumnRegistry = VirtualColumnRegistry.create(
|
||||
prefixSignaturePair.rhs,
|
||||
|
@ -380,13 +387,29 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
|
|||
&& DruidRels.druidTableIfLeafRel(right).filter(table -> table.getDataSource().isGlobal()).isPresent());
|
||||
}
|
||||
|
||||
static Set<String> findExistingJoinPrefixes(DataSource... dataSources)
|
||||
{
|
||||
final ArrayList<DataSource> copy = new ArrayList<>(Arrays.asList(dataSources));
|
||||
|
||||
Set<String> prefixes = new HashSet<>();
|
||||
while (!copy.isEmpty()) {
|
||||
DataSource current = copy.remove(0);
|
||||
copy.addAll(current.getChildren());
|
||||
if (current instanceof JoinDataSource) {
|
||||
JoinDataSource joiner = (JoinDataSource) current;
|
||||
prefixes.add(joiner.getRightPrefix());
|
||||
}
|
||||
}
|
||||
return prefixes;
|
||||
}
|
||||
/**
|
||||
* Returns a Pair of "rightPrefix" (for JoinDataSource) and the signature of rows that will result from
|
||||
* applying that prefix.
|
||||
*/
|
||||
static Pair<String, RowSignature> computeJoinRowSignature(
|
||||
final RowSignature leftSignature,
|
||||
final RowSignature rightSignature
|
||||
final RowSignature rightSignature,
|
||||
final Set<String> prefixes
|
||||
)
|
||||
{
|
||||
final RowSignature.Builder signatureBuilder = RowSignature.builder();
|
||||
|
@ -395,8 +418,17 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
|
|||
signatureBuilder.add(column, leftSignature.getColumnType(column).orElse(null));
|
||||
}
|
||||
|
||||
// Need to include the "0" since findUnusedPrefixForDigits only guarantees safety for digit-initiated suffixes
|
||||
final String rightPrefix = Calcites.findUnusedPrefixForDigits("j", leftSignature.getColumnNames()) + "0.";
|
||||
StringBuilder base = new StringBuilder("j");
|
||||
// the prefixes collection contains all known join prefixes, which might be in use for nested queries but not
|
||||
// present in the top level row signatures
|
||||
// loop until we are sure we got a new prefix
|
||||
String maybePrefix;
|
||||
do {
|
||||
// Need to include the "0" since findUnusedPrefixForDigits only guarantees safety for digit-initiated suffixes
|
||||
maybePrefix = Calcites.findUnusedPrefixForDigits(base.toString(), leftSignature.getColumnNames()) + "0.";
|
||||
base.insert(0, "_");
|
||||
} while (prefixes.contains(maybePrefix));
|
||||
final String rightPrefix = maybePrefix;
|
||||
|
||||
for (final String column : rightSignature.getColumnNames()) {
|
||||
signatureBuilder.add(rightPrefix + column, rightSignature.getColumnType(column).orElse(null));
|
||||
|
|
|
@ -61,6 +61,7 @@ import org.apache.druid.query.dimension.ExtractionDimensionSpec;
|
|||
import org.apache.druid.query.extraction.SubstringDimExtractionFn;
|
||||
import org.apache.druid.query.filter.AndDimFilter;
|
||||
import org.apache.druid.query.filter.BoundDimFilter;
|
||||
import org.apache.druid.query.filter.InDimFilter;
|
||||
import org.apache.druid.query.filter.LikeDimFilter;
|
||||
import org.apache.druid.query.filter.NotDimFilter;
|
||||
import org.apache.druid.query.filter.OrDimFilter;
|
||||
|
@ -95,8 +96,10 @@ import org.junit.Ignore;
|
|||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -4766,8 +4769,8 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
|
|||
.context(queryContext)
|
||||
.build()
|
||||
),
|
||||
"j0.",
|
||||
equalsCondition(makeColumnExpression("v0"), makeColumnExpression("j0.v0")),
|
||||
"_j0.",
|
||||
equalsCondition(makeColumnExpression("v0"), makeColumnExpression("_j0.v0")),
|
||||
JoinType.INNER
|
||||
)
|
||||
)
|
||||
|
@ -4778,7 +4781,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
|
|||
ImmutableSet.of("a"),
|
||||
true
|
||||
))
|
||||
.columns("dim3", "j0.dim3")
|
||||
.columns("_j0.dim3", "dim3")
|
||||
.context(queryContext)
|
||||
.build()
|
||||
),
|
||||
|
@ -5084,4 +5087,181 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
|
|||
null
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Parameters(source = QueryContextForJoinProvider.class)
|
||||
public void testRegressionFilteredAggregatorsSubqueryJoins(Map<String, Object> queryContext)
|
||||
{
|
||||
cannotVectorize();
|
||||
testQuery(
|
||||
"select\n" +
|
||||
"count(*) filter (where trim(both from dim1) in (select dim2 from foo)),\n" +
|
||||
"min(m1) filter (where 'A' not in (select m2 from foo))\n" +
|
||||
"from foo as t0\n" +
|
||||
"where __time in (select __time from foo)",
|
||||
queryContext,
|
||||
useDefault ?
|
||||
ImmutableList.of(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(
|
||||
join(
|
||||
join(
|
||||
join(
|
||||
new TableDataSource(CalciteTests.DATASOURCE1),
|
||||
new QueryDataSource(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(CalciteTests.DATASOURCE1)
|
||||
.setInterval(querySegmentSpec(Filtration.eternity()))
|
||||
.setDimensions(
|
||||
new DefaultDimensionSpec("__time", "d0", ColumnType.LONG)
|
||||
)
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setLimitSpec(NoopLimitSpec.instance())
|
||||
.build()
|
||||
),
|
||||
"j0.",
|
||||
equalsCondition(makeColumnExpression("__time"), makeColumnExpression("j0.d0")),
|
||||
JoinType.INNER
|
||||
),
|
||||
new QueryDataSource(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(CalciteTests.DATASOURCE1)
|
||||
.setInterval(querySegmentSpec(Filtration.eternity()))
|
||||
.setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG))
|
||||
.setDimensions(
|
||||
new DefaultDimensionSpec("dim2", "d0", ColumnType.STRING),
|
||||
new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
|
||||
)
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setLimitSpec(NoopLimitSpec.instance())
|
||||
.build()
|
||||
),
|
||||
"_j0.",
|
||||
"(trim(\"dim1\",' ') == \"_j0.d0\")",
|
||||
JoinType.LEFT
|
||||
),
|
||||
new QueryDataSource(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(CalciteTests.DATASOURCE1)
|
||||
.setInterval(querySegmentSpec(Filtration.eternity()))
|
||||
.setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG))
|
||||
.setDimFilter(selector("m2", "A", null))
|
||||
.setDimensions(
|
||||
new DefaultDimensionSpec("v0", "d0", ColumnType.LONG)
|
||||
)
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setLimitSpec(NoopLimitSpec.instance())
|
||||
.build()
|
||||
),
|
||||
"__j0.",
|
||||
"1",
|
||||
JoinType.LEFT
|
||||
)
|
||||
)
|
||||
.intervals(querySegmentSpec(Filtration.eternity()))
|
||||
.aggregators(
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a0"),
|
||||
and(
|
||||
not(selector("_j0.d1", null, null)),
|
||||
not(selector("dim1", null, null))
|
||||
),
|
||||
"a0"
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new FloatMinAggregatorFactory("a1", "m1"),
|
||||
selector("__j0.d0", null, null),
|
||||
"a1"
|
||||
)
|
||||
)
|
||||
.context(queryContext)
|
||||
.build()
|
||||
) :
|
||||
ImmutableList.of(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(
|
||||
join(
|
||||
join(
|
||||
join(
|
||||
new TableDataSource(CalciteTests.DATASOURCE1),
|
||||
new QueryDataSource(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(CalciteTests.DATASOURCE1)
|
||||
.setInterval(querySegmentSpec(Filtration.eternity()))
|
||||
.setDimensions(
|
||||
new DefaultDimensionSpec("__time", "d0", ColumnType.LONG)
|
||||
)
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setLimitSpec(NoopLimitSpec.instance())
|
||||
.build()
|
||||
),
|
||||
"j0.",
|
||||
equalsCondition(makeColumnExpression("__time"), makeColumnExpression("j0.d0")),
|
||||
JoinType.INNER
|
||||
),
|
||||
new QueryDataSource(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(CalciteTests.DATASOURCE1)
|
||||
.setInterval(querySegmentSpec(Filtration.eternity()))
|
||||
.setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG))
|
||||
.setDimensions(
|
||||
new DefaultDimensionSpec("dim2", "d0", ColumnType.STRING),
|
||||
new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
|
||||
)
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setLimitSpec(NoopLimitSpec.instance())
|
||||
.build()
|
||||
),
|
||||
"_j0.",
|
||||
"(trim(\"dim1\",' ') == \"_j0.d0\")",
|
||||
JoinType.LEFT
|
||||
),
|
||||
new QueryDataSource(
|
||||
new TopNQueryBuilder().dataSource(CalciteTests.DATASOURCE1)
|
||||
.intervals(querySegmentSpec(Filtration.eternity()))
|
||||
.filters(new InDimFilter("m2", new HashSet<>(Arrays.asList(null, "A"))))
|
||||
.virtualColumns(expressionVirtualColumn("v0", "notnull(\"m2\")", ColumnType.LONG))
|
||||
.dimension(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG))
|
||||
.metric(new InvertedTopNMetricSpec(new DimensionTopNMetricSpec(null, StringComparators.LEXICOGRAPHIC)))
|
||||
.aggregators(new CountAggregatorFactory("a0"))
|
||||
.threshold(1)
|
||||
.build()
|
||||
),
|
||||
"__j0.",
|
||||
"1",
|
||||
JoinType.LEFT
|
||||
)
|
||||
)
|
||||
.intervals(querySegmentSpec(Filtration.eternity()))
|
||||
.aggregators(
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a0"),
|
||||
and(
|
||||
not(selector("_j0.d1", null, null)),
|
||||
not(selector("dim1", null, null))
|
||||
),
|
||||
"a0"
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new FloatMinAggregatorFactory("a1", "m1"),
|
||||
or(
|
||||
selector("__j0.a0", null, null),
|
||||
not(
|
||||
or(
|
||||
not(expressionFilter("\"__j0.d0\"")),
|
||||
not(selector("__j0.d0", null, null))
|
||||
)
|
||||
)
|
||||
),
|
||||
"a1"
|
||||
)
|
||||
)
|
||||
.context(queryContext)
|
||||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{useDefault ? 1L : 2L, 1.0f}
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue