Supporting filters in the left base table for join datasources (#10697)

* First draft of WHERE-filter support on the left table

* Revert changes in calcite test

* Refactor a bit

* Fix the tests

* Changes

* Adding tests

* Add tests for correlated queries

* Add comment

* Fix typos
Abhishek Agarwal 2021-03-05 00:09:21 +05:30 committed by GitHub
parent 6040c30fcd
commit 1a15987432
28 changed files with 1577 additions and 198 deletions
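In short: a native join datasource can now carry a filter that is applied directly to the left-hand base table before the join, rather than being evaluated after it. A minimal sketch of the new API surface (table and column names are hypothetical):

// JoinDataSource.create(...) now accepts a @Nullable DimFilter for the left side.
// The filter is only legal when the left input is direct table access.
JoinDataSource joined = JoinDataSource.create(
    new TableDataSource("events"),                       // left: base table
    new TableDataSource("countries"),                    // right: joinable datasource
    "c.",                                                // prefix for right-side columns
    "countryIsoCode == \"c.isoCode\"",                   // join condition
    JoinType.INNER,
    new SelectorDimFilter("countryIsoCode", "CA", null), // new: left base-table filter
    ExprMacroTable.nil()
);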

View File

@ -197,6 +197,7 @@ public class IndexedTableJoinCursorBenchmark
hashJoinSegment = closer.register(
new HashJoinSegment(
ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment),
null,
clauses,
preAnalysis
)

View File

@ -160,6 +160,7 @@ public class JoinAndLookupBenchmark
hashJoinLookupStringKeySegment = new HashJoinSegment(
ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment),
null,
joinableClausesLookupStringKey,
preAnalysisLookupStringKey
);
@ -194,6 +195,7 @@ public class JoinAndLookupBenchmark
hashJoinLookupLongKeySegment = new HashJoinSegment(
ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment),
null,
joinableClausesLookupLongKey,
preAnalysisLookupLongKey
);
@ -228,6 +230,7 @@ public class JoinAndLookupBenchmark
hashJoinIndexedTableStringKeySegment = new HashJoinSegment(
ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment),
null,
joinableClausesIndexedTableStringKey,
preAnalysisIndexedStringKey
);
@ -262,6 +265,7 @@ public class JoinAndLookupBenchmark
hashJoinIndexedTableLongKeySegment = new HashJoinSegment(
ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment),
null,
joinableClausesIndexedTableLongKey,
preAnalysisIndexedLongKey
);

View File

@ -27,10 +27,12 @@ import com.google.common.collect.ImmutableList;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.segment.join.JoinConditionAnalysis;
import org.apache.druid.segment.join.JoinPrefixUtils;
import org.apache.druid.segment.join.JoinType;
import javax.annotation.Nullable;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
@ -58,13 +60,17 @@ public class JoinDataSource implements DataSource
private final String rightPrefix;
private final JoinConditionAnalysis conditionAnalysis;
private final JoinType joinType;
// An optional filter on the left side, usable only when the left side is direct table access
@Nullable
private final DimFilter leftFilter;
private JoinDataSource(
DataSource left,
DataSource right,
String rightPrefix,
JoinConditionAnalysis conditionAnalysis,
JoinType joinType
JoinType joinType,
@Nullable DimFilter leftFilter
)
{
this.left = Preconditions.checkNotNull(left, "left");
@ -72,6 +78,12 @@ public class JoinDataSource implements DataSource
this.rightPrefix = JoinPrefixUtils.validatePrefix(rightPrefix);
this.conditionAnalysis = Preconditions.checkNotNull(conditionAnalysis, "conditionAnalysis");
this.joinType = Preconditions.checkNotNull(joinType, "joinType");
// TODO: Add support for union data sources
Preconditions.checkArgument(
leftFilter == null || left instanceof TableDataSource,
"left filter is only supported if left data source is direct table access"
);
this.leftFilter = leftFilter;
}
/**
@ -84,6 +96,7 @@ public class JoinDataSource implements DataSource
@JsonProperty("rightPrefix") String rightPrefix,
@JsonProperty("condition") String condition,
@JsonProperty("joinType") JoinType joinType,
@Nullable @JsonProperty("leftFilter") DimFilter leftFilter,
@JacksonInject ExprMacroTable macroTable
)
{
@ -96,7 +109,8 @@ public class JoinDataSource implements DataSource
StringUtils.nullToEmptyNonDruidDataString(rightPrefix),
macroTable
),
joinType
joinType,
leftFilter
);
}
@ -108,10 +122,26 @@ public class JoinDataSource implements DataSource
final DataSource right,
final String rightPrefix,
final JoinConditionAnalysis conditionAnalysis,
final JoinType joinType
final JoinType joinType,
final DimFilter leftFilter
)
{
return new JoinDataSource(left, right, rightPrefix, conditionAnalysis, joinType);
return new JoinDataSource(left, right, rightPrefix, conditionAnalysis, joinType, leftFilter);
}
/**
* Create a join dataSource from a string condition, without a filter on the left-hand data source.
*/
public static JoinDataSource create(
final DataSource left,
final DataSource right,
final String rightPrefix,
final String condition,
final JoinType joinType,
final ExprMacroTable macroTable
)
{
return create(left, right, rightPrefix, condition, joinType, null, macroTable);
}
@Override
@ -158,6 +188,13 @@ public class JoinDataSource implements DataSource
return joinType;
}
@JsonProperty
@Nullable
public DimFilter getLeftFilter()
{
return leftFilter;
}
@Override
public List<DataSource> getChildren()
{
@ -171,7 +208,14 @@ public class JoinDataSource implements DataSource
throw new IAE("Expected [2] children, got [%d]", children.size());
}
return new JoinDataSource(children.get(0), children.get(1), rightPrefix, conditionAnalysis, joinType);
return new JoinDataSource(
children.get(0),
children.get(1),
rightPrefix,
conditionAnalysis,
joinType,
leftFilter
);
}
@Override
@ -206,13 +250,14 @@ public class JoinDataSource implements DataSource
Objects.equals(right, that.right) &&
Objects.equals(rightPrefix, that.rightPrefix) &&
Objects.equals(conditionAnalysis, that.conditionAnalysis) &&
Objects.equals(leftFilter, that.leftFilter) &&
joinType == that.joinType;
}
@Override
public int hashCode()
{
return Objects.hash(left, right, rightPrefix, conditionAnalysis, joinType);
return Objects.hash(left, right, rightPrefix, conditionAnalysis, joinType, leftFilter);
}
@Override
@ -224,6 +269,7 @@ public class JoinDataSource implements DataSource
", rightPrefix='" + rightPrefix + '\'' +
", condition=" + conditionAnalysis +
", joinType=" + joinType +
", leftFilter=" + leftFilter +
'}';
}
}

View File

@ -26,6 +26,7 @@ import org.apache.druid.guice.annotations.PublicApi;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.planning.DataSourceAnalysis;
import org.apache.druid.query.planning.PreJoinableClause;
import org.apache.druid.query.spec.MultipleSpecificSegmentSpec;
@ -192,6 +193,7 @@ public class Queries
final DataSourceAnalysis analysis = DataSourceAnalysis.forDataSource(query.getDataSource());
DataSource current = newBaseDataSource;
DimFilter joinBaseFilter = analysis.getJoinBaseTableFilter().orElse(null);
for (final PreJoinableClause clause : analysis.getPreJoinableClauses()) {
current = JoinDataSource.create(
@ -199,8 +201,10 @@ public class Queries
clause.getDataSource(),
clause.getPrefix(),
clause.getCondition(),
clause.getJoinType()
clause.getJoinType(),
joinBaseFilter
);
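// Only the innermost clause (the one that reads the base table) may carry the base filter; clear it for the rest.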
joinBaseFilter = null;
}
retVal = query.withDataSource(current);

View File

@ -28,6 +28,7 @@ import org.apache.druid.query.Query;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.UnionDataSource;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.spec.QuerySegmentSpec;
import javax.annotation.Nullable;
@ -80,12 +81,15 @@ public class DataSourceAnalysis
private final DataSource baseDataSource;
@Nullable
private final Query<?> baseQuery;
@Nullable
private final DimFilter joinBaseTableFilter;
private final List<PreJoinableClause> preJoinableClauses;
private DataSourceAnalysis(
DataSource dataSource,
DataSource baseDataSource,
@Nullable Query<?> baseQuery,
@Nullable DimFilter joinBaseTableFilter,
List<PreJoinableClause> preJoinableClauses
)
{
@ -98,6 +102,7 @@ public class DataSourceAnalysis
this.dataSource = dataSource;
this.baseDataSource = baseDataSource;
this.baseQuery = baseQuery;
this.joinBaseTableFilter = joinBaseTableFilter;
this.preJoinableClauses = preJoinableClauses;
}
@ -121,10 +126,10 @@ public class DataSourceAnalysis
}
if (current instanceof JoinDataSource) {
final Pair<DataSource, List<PreJoinableClause>> flattened = flattenJoin((JoinDataSource) current);
return new DataSourceAnalysis(dataSource, flattened.lhs, baseQuery, flattened.rhs);
final Pair<Pair<DataSource, DimFilter>, List<PreJoinableClause>> flattened = flattenJoin((JoinDataSource) current);
return new DataSourceAnalysis(dataSource, flattened.lhs.lhs, baseQuery, flattened.lhs.rhs, flattened.rhs);
} else {
return new DataSourceAnalysis(dataSource, current, baseQuery, Collections.emptyList());
return new DataSourceAnalysis(dataSource, current, baseQuery, null, Collections.emptyList());
}
}
@ -134,14 +139,19 @@ public class DataSourceAnalysis
*
* @throws IllegalArgumentException if dataSource cannot be fully flattened.
*/
private static Pair<DataSource, List<PreJoinableClause>> flattenJoin(final JoinDataSource dataSource)
private static Pair<Pair<DataSource, DimFilter>, List<PreJoinableClause>> flattenJoin(final JoinDataSource dataSource)
{
DataSource current = dataSource;
DimFilter currentDimFilter = null;
final List<PreJoinableClause> preJoinableClauses = new ArrayList<>();
while (current instanceof JoinDataSource) {
final JoinDataSource joinDataSource = (JoinDataSource) current;
current = joinDataSource.getLeft();
if (currentDimFilter != null) {
throw new IAE("Left filters are only allowed when left child is direct table access");
}
currentDimFilter = joinDataSource.getLeftFilter();
preJoinableClauses.add(
new PreJoinableClause(
joinDataSource.getRightPrefix(),
@ -156,7 +166,7 @@ public class DataSourceAnalysis
// going-up order. So reverse them.
Collections.reverse(preJoinableClauses);
return Pair.of(current, preJoinableClauses);
return Pair.of(Pair.of(current, currentDimFilter), preJoinableClauses);
}
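For intuition, a sketch (datasource names hypothetical) of what flattenJoin yields for a left-leaning stack whose innermost join carries a left filter f:

// join(join(table("foo") with leftFilter f, lookup("lyl"), "1."), inline, "2.")
// flattens to:
//   base datasource     -> table("foo")
//   joinBaseTableFilter -> f   (only the innermost, table-reading join may carry one)
//   preJoinableClauses  -> [lookup "1.", inline "2."] (innermost first, after the reverse)
DataSourceAnalysis analysis = DataSourceAnalysis.forDataSource(joinDataSource);
Optional<DimFilter> baseFilter = analysis.getJoinBaseTableFilter(); // Optional.of(f)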
/**
@ -214,11 +224,20 @@ public class DataSourceAnalysis
return Optional.ofNullable(baseQuery);
}
/**
* If the original data source is a join data source and there is a DimFilter on its base table data source, that
* DimFilter is returned here.
*/
public Optional<DimFilter> getJoinBaseTableFilter()
{
return Optional.ofNullable(joinBaseTableFilter);
}
/**
* Returns the {@link QuerySegmentSpec} that is associated with the base datasource, if any. This only happens
* when there is an outer query datasource. In this case, the base querySegmentSpec is the one associated with the
* innermost subquery.
*
* <p>
* This {@link QuerySegmentSpec} is taken from the query returned by {@link #getBaseQuery()}.
*
* @return the query segment spec associated with the base datasource if {@link #isQuery()} is true, else empty

View File

@ -22,6 +22,7 @@ package org.apache.druid.segment.join;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.guava.CloseQuietly;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.query.filter.Filter;
import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.SegmentReference;
import org.apache.druid.segment.StorageAdapter;
@ -43,6 +44,7 @@ import java.util.Optional;
public class HashJoinSegment implements SegmentReference
{
private final SegmentReference baseSegment;
private final Filter baseFilter;
private final List<JoinableClause> clauses;
private final JoinFilterPreAnalysis joinFilterPreAnalysis;
@ -54,11 +56,13 @@ public class HashJoinSegment implements SegmentReference
*/
public HashJoinSegment(
SegmentReference baseSegment,
@Nullable Filter baseFilter,
List<JoinableClause> clauses,
JoinFilterPreAnalysis joinFilterPreAnalysis
)
{
this.baseSegment = baseSegment;
this.baseFilter = baseFilter;
this.clauses = clauses;
this.joinFilterPreAnalysis = joinFilterPreAnalysis;
@ -93,7 +97,12 @@ public class HashJoinSegment implements SegmentReference
@Override
public StorageAdapter asStorageAdapter()
{
return new HashJoinSegmentStorageAdapter(baseSegment.asStorageAdapter(), clauses, joinFilterPreAnalysis);
return new HashJoinSegmentStorageAdapter(
baseSegment.asStorageAdapter(),
baseFilter,
clauses,
joinFilterPreAnalysis
);
}
@Override

View File

@ -56,6 +56,7 @@ import java.util.Set;
public class HashJoinSegmentStorageAdapter implements StorageAdapter
{
private final StorageAdapter baseAdapter;
private final Filter baseFilter;
private final List<JoinableClause> clauses;
private final JoinFilterPreAnalysis joinFilterPreAnalysis;
@ -69,8 +70,25 @@ public class HashJoinSegmentStorageAdapter implements StorageAdapter
final List<JoinableClause> clauses,
final JoinFilterPreAnalysis joinFilterPreAnalysis
)
{
this(baseAdapter, null, clauses, joinFilterPreAnalysis);
}
/**
* @param baseAdapter A StorageAdapter for the left-hand side base segment
* @param baseFilter A filter for the left-hand side base segment
* @param clauses The right-hand side clauses. The caller is responsible for ensuring that there are no
*                              duplicate prefixes or prefixes that shadow each other across the clauses
* @param joinFilterPreAnalysis Pre-analysis for the query we expect to run on this storage adapter
*/
HashJoinSegmentStorageAdapter(
final StorageAdapter baseAdapter,
final Filter baseFilter,
final List<JoinableClause> clauses,
final JoinFilterPreAnalysis joinFilterPreAnalysis
)
{
this.baseAdapter = baseAdapter;
this.baseFilter = baseFilter;
this.clauses = clauses;
this.joinFilterPreAnalysis = joinFilterPreAnalysis;
}
@ -239,14 +257,12 @@ public class HashJoinSegmentStorageAdapter implements StorageAdapter
postJoinVirtualColumns
);
JoinFilterSplit joinFilterSplit = JoinFilterAnalyzer.splitFilter(joinFilterPreAnalysis);
// Merge the base-table filter specified by the user with the base-table filter pushed down from the join.
JoinFilterSplit joinFilterSplit = JoinFilterAnalyzer.splitFilter(joinFilterPreAnalysis, baseFilter);
preJoinVirtualColumns.addAll(joinFilterSplit.getPushDownVirtualColumns());
// Soon, we will need a way to push filters past a join when possible. This could potentially be done right here
// (by splitting out pushable pieces of 'filter') or it could be done at a higher level (i.e. in the SQL planner).
//
// If it's done in the SQL planner, that will likely mean adding a 'baseFilter' parameter to this class that would
// be passed in to the below baseAdapter.makeCursors call (instead of the null filter).
final Sequence<Cursor> baseCursorSequence = baseAdapter.makeCursors(
joinFilterSplit.getBaseTableFilter().isPresent() ? joinFilterSplit.getBaseTableFilter().get() : null,
interval,

View File

@ -23,13 +23,13 @@ public enum JoinType
{
INNER {
@Override
boolean isLefty()
public boolean isLefty()
{
return false;
}
@Override
boolean isRighty()
public boolean isRighty()
{
return false;
}
@ -37,13 +37,13 @@ public enum JoinType
LEFT {
@Override
boolean isLefty()
public boolean isLefty()
{
return true;
}
@Override
boolean isRighty()
public boolean isRighty()
{
return false;
}
@ -51,13 +51,13 @@ public enum JoinType
RIGHT {
@Override
boolean isLefty()
public boolean isLefty()
{
return false;
}
@Override
boolean isRighty()
public boolean isRighty()
{
return true;
}
@ -65,13 +65,13 @@ public enum JoinType
FULL {
@Override
boolean isLefty()
public boolean isLefty()
{
return true;
}
@Override
boolean isRighty()
public boolean isRighty()
{
return true;
}
@ -80,10 +80,10 @@ public enum JoinType
/**
* "Lefty" joins (LEFT or FULL) always include the full left-hand side, and can generate nulls on the right.
*/
abstract boolean isLefty();
public abstract boolean isLefty();
/**
* "Righty" joins (RIGHT or FULL) always include the full right-hand side, and can generate nulls on the left.
*/
abstract boolean isRighty();
public abstract boolean isRighty();
}
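These methods are made public so that code outside the package (for example, the SQL planner reasoning about where a left filter may be applied) can branch on join shape; a trivial usage sketch:

// A "lefty" join keeps every left-hand row and can generate nulls on the right;
// a "righty" join keeps every right-hand row and can generate nulls on the left.
boolean keepsAllLeftRows = JoinType.LEFT.isLefty();   // true
boolean canNullLeftSide  = JoinType.RIGHT.isRighty(); // true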

View File

@ -24,6 +24,7 @@ import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.Query;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.filter.Filter;
import org.apache.druid.query.planning.DataSourceAnalysis;
import org.apache.druid.query.planning.PreJoinableClause;
import org.apache.druid.segment.SegmentReference;
@ -69,6 +70,7 @@ public class JoinableFactoryWrapper
* query from the end user.
*/
public Function<SegmentReference, SegmentReference> createSegmentMapFn(
final Filter baseFilter,
final List<PreJoinableClause> clauses,
final AtomicLong cpuTimeAccumulator,
final Query<?> query
@ -94,6 +96,7 @@ public class JoinableFactoryWrapper
return baseSegment ->
new HashJoinSegment(
baseSegment,
baseFilter,
joinableClauses.getJoinableClauses(),
joinFilterPreAnalysis
);
@ -103,7 +106,8 @@ public class JoinableFactoryWrapper
}
/**
* Compute a cache key prefix for data sources that participate in the RHS of a join. This key prefix
* Compute a cache key prefix for a join data source. This includes the data sources that participate in the RHS of a
* join, as well as any query-specific constructs associated with the join data source, such as the base table filter. This key prefix
* can be used in segment level cache or result level cache. The function can return following wrapped in an
* Optional
* - Non-empty byte array - If there is join datasource involved and caching is possible. The result includes
@ -126,6 +130,9 @@ public class JoinableFactoryWrapper
final CacheKeyBuilder keyBuilder;
keyBuilder = new CacheKeyBuilder(JOIN_OPERATION);
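// Fold the base-table filter into the key: queries differing only in the left filter must not share cache entries.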
if (dataSourceAnalysis.getJoinBaseTableFilter().isPresent()) {
keyBuilder.appendCacheable(dataSourceAnalysis.getJoinBaseTableFilter().get());
}
for (PreJoinableClause clause : clauses) {
Optional<byte[]> bytes = joinableFactory.computeJoinCacheKey(clause.getDataSource(), clause.getCondition());
if (!bytes.isPresent()) {

View File

@ -33,6 +33,7 @@ import org.apache.druid.segment.filter.OrFilter;
import org.apache.druid.segment.filter.SelectorFilter;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
@ -137,18 +138,27 @@ public class JoinFilterAnalyzer
return preAnalysisBuilder.withCorrelations(correlations).build();
}
/**
* @param joinFilterPreAnalysis The pre-analysis computed by {@link #computeJoinFilterPreAnalysis}
*
* @return A JoinFilterSplit indicating what parts of the filter should be applied pre-join and post-join
*/
public static JoinFilterSplit splitFilter(
JoinFilterPreAnalysis joinFilterPreAnalysis
)
{
return splitFilter(joinFilterPreAnalysis, null);
}
/**
* @param joinFilterPreAnalysis The pre-analysis computed by {@link #computeJoinFilterPreAnalysis}
* @param baseFilter            Filter on the base table that was specified in the query itself
*
* @return A JoinFilterSplit indicating what parts of the filter should be applied pre-join and post-join
*/
public static JoinFilterSplit splitFilter(
JoinFilterPreAnalysis joinFilterPreAnalysis,
@Nullable Filter baseFilter
)
{
if (joinFilterPreAnalysis.getOriginalFilter() == null || !joinFilterPreAnalysis.isEnableFilterPushDown()) {
return new JoinFilterSplit(
null,
baseFilter,
joinFilterPreAnalysis.getOriginalFilter(),
ImmutableSet.of()
);
@ -159,6 +169,10 @@ public class JoinFilterAnalyzer
List<Filter> rightFilters = new ArrayList<>();
Map<Expr, VirtualColumn> pushDownVirtualColumnsForLhsExprs = new HashMap<>();
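// The user-specified base-table filter always applies before the join, so seed the left-side filter list with it.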
if (null != baseFilter) {
leftFilters.add(baseFilter);
}
for (Filter baseTableFilter : joinFilterPreAnalysis.getNormalizedBaseTableClauses()) {
if (!Filters.filterMatchesNull(baseTableFilter)) {
leftFilters.add(baseTableFilter);

View File

@ -24,12 +24,14 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.filter.TrueDimFilter;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.join.JoinType;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mockito;
import java.util.Collections;
@ -152,11 +154,52 @@ public class JoinDataSourceTest
public void test_serde() throws Exception
{
final ObjectMapper jsonMapper = TestHelper.makeJsonMapper();
JoinDataSource joinDataSource = JoinDataSource.create(
new TableDataSource("table1"),
new TableDataSource("table2"),
"j.",
"x == \"j.x\"",
JoinType.LEFT,
TrueDimFilter.instance(),
ExprMacroTable.nil()
);
final JoinDataSource deserialized = (JoinDataSource) jsonMapper.readValue(
jsonMapper.writeValueAsString(joinTableToLookup),
jsonMapper.writeValueAsString(joinDataSource),
DataSource.class
);
Assert.assertEquals(joinTableToLookup, deserialized);
Assert.assertEquals(joinDataSource, deserialized);
}
@Test
public void testException_leftFilterOnNonTableSource()
{
expectedException.expect(IllegalArgumentException.class);
expectedException.expectMessage("left filter is only supported if left data source is direct table access");
JoinDataSource ignored = JoinDataSource.create(
new QueryDataSource(Mockito.mock(Query.class)),
new TableDataSource("table"),
"j.",
"x == \"j.x\"",
JoinType.LEFT,
TrueDimFilter.instance(),
ExprMacroTable.nil()
);
}
@Test
public void testLeftFilter()
{
JoinDataSource dataSource = JoinDataSource.create(
new TableDataSource("table1"),
new TableDataSource("table2"),
"j.",
"x == \"j.x\"",
JoinType.LEFT,
TrueDimFilter.instance(),
ExprMacroTable.nil()
);
Assert.assertEquals(TrueDimFilter.instance(), dataSource.getLeftFilter());
}
}

View File

@ -30,6 +30,7 @@ import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.ConstantPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.query.filter.TrueDimFilter;
import org.apache.druid.query.spec.MultipleSpecificSegmentSpec;
import org.apache.druid.query.timeseries.TimeseriesQuery;
import org.apache.druid.query.timeseries.TimeseriesResultValue;
@ -423,6 +424,7 @@ public class QueriesTest
"j0.",
"\"foo.x\" == \"bar.x\"",
JoinType.INNER,
null,
ExprMacroTable.nil()
)
)
@ -459,6 +461,7 @@ public class QueriesTest
"j0.",
"\"foo.x\" == \"bar.x\"",
JoinType.INNER,
null,
ExprMacroTable.nil()
)
)
@ -479,4 +482,78 @@ public class QueriesTest
)
);
}
@Test
public void testWithBaseDataSourcedBaseFilterWithMultiJoin()
{
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
.dataSource(
new QueryDataSource(
Druids.newTimeseriesQueryBuilder()
.dataSource(
JoinDataSource.create(
JoinDataSource.create(
new TableDataSource("foo"),
new TableDataSource("bar"),
"j1.",
"\"foo.x\" == \"bar.x\"",
JoinType.INNER,
TrueDimFilter.instance(),
ExprMacroTable.nil()
),
new TableDataSource("foo_outer"),
"j0.",
"\"foo_outer.x\" == \"bar.x\"",
JoinType.INNER,
null,
ExprMacroTable.nil()
)
)
.intervals("2000/3000")
.granularity(Granularities.ALL)
.build()
)
)
.intervals("2000/3000")
.granularity(Granularities.ALL)
.build(),
Queries.withBaseDataSource(
Druids.newTimeseriesQueryBuilder()
.dataSource(
new QueryDataSource(
Druids.newTimeseriesQueryBuilder()
.dataSource(
JoinDataSource.create(
JoinDataSource.create(
new TableDataSource("foo_inner"),
new TableDataSource("bar"),
"j1.",
"\"foo.x\" == \"bar.x\"",
JoinType.INNER,
TrueDimFilter.instance(),
ExprMacroTable.nil()
),
new TableDataSource("foo_outer"),
"j0.",
"\"foo_outer.x\" == \"bar.x\"",
JoinType.INNER,
null,
ExprMacroTable.nil()
)
)
.intervals("2000/3000")
.granularity(Granularities.ALL)
.build()
)
)
.intervals("2000/3000")
.granularity(Granularities.ALL)
.build(),
new TableDataSource("foo")
)
);
}
}

View File

@ -32,6 +32,8 @@ import org.apache.druid.query.LookupDataSource;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.UnionDataSource;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.filter.TrueDimFilter;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.apache.druid.segment.column.RowSignature;
@ -237,6 +239,54 @@ public class DataSourceAnalysisTest
Assert.assertEquals(joinDataSource, analysis.getDataSource());
Assert.assertEquals(TABLE_FOO, analysis.getBaseDataSource());
Assert.assertEquals(Optional.of(TABLE_FOO), analysis.getBaseTableDataSource());
Assert.assertEquals(Optional.empty(), analysis.getJoinBaseTableFilter());
Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource());
Assert.assertEquals(Optional.empty(), analysis.getBaseQuery());
Assert.assertEquals(Optional.empty(), analysis.getBaseQuerySegmentSpec());
Assert.assertEquals(
ImmutableList.of(
new PreJoinableClause("1.", LOOKUP_LOOKYLOO, JoinType.INNER, joinClause("1.")),
new PreJoinableClause("2.", INLINE, JoinType.LEFT, joinClause("2.")),
new PreJoinableClause("3.", subquery(LOOKUP_LOOKYLOO), JoinType.FULL, joinClause("3."))
),
analysis.getPreJoinableClauses()
);
Assert.assertTrue(analysis.isJoin());
}
@Test
public void testJoinSimpleLeftLeaningWithLeftFilter()
{
final JoinDataSource joinDataSource =
join(
join(
join(
TABLE_FOO,
LOOKUP_LOOKYLOO,
"1.",
JoinType.INNER,
TrueDimFilter.instance()
),
INLINE,
"2.",
JoinType.LEFT
),
subquery(LOOKUP_LOOKYLOO),
"3.",
JoinType.FULL
);
final DataSourceAnalysis analysis = DataSourceAnalysis.forDataSource(joinDataSource);
Assert.assertTrue(analysis.isConcreteBased());
Assert.assertTrue(analysis.isConcreteTableBased());
Assert.assertFalse(analysis.isGlobal());
Assert.assertFalse(analysis.isQuery());
Assert.assertEquals(joinDataSource, analysis.getDataSource());
Assert.assertEquals(TABLE_FOO, analysis.getBaseDataSource());
Assert.assertEquals(Optional.of(TABLE_FOO), analysis.getBaseTableDataSource());
Assert.assertEquals(TrueDimFilter.instance(), analysis.getJoinBaseTableFilter().orElse(null));
Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource());
Assert.assertEquals(Optional.empty(), analysis.getBaseQuery());
@ -290,6 +340,54 @@ public class DataSourceAnalysisTest
Assert.assertEquals(joinDataSource, analysis.getDataSource());
Assert.assertEquals(TABLE_FOO, analysis.getBaseDataSource());
Assert.assertEquals(Optional.of(TABLE_FOO), analysis.getBaseTableDataSource());
Assert.assertEquals(Optional.empty(), analysis.getJoinBaseTableFilter());
Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource());
Assert.assertEquals(Optional.empty(), analysis.getBaseQuery());
Assert.assertEquals(Optional.empty(), analysis.getBaseQuerySegmentSpec());
Assert.assertEquals(
ImmutableList.of(
new PreJoinableClause("3.", rightLeaningJoinStack, JoinType.RIGHT, joinClause("3."))
),
analysis.getPreJoinableClauses()
);
Assert.assertTrue(analysis.isJoin());
}
@Test
public void testJoinSimpleRightLeaningWithLeftFilter()
{
final JoinDataSource rightLeaningJoinStack =
join(
LOOKUP_LOOKYLOO,
join(
INLINE,
subquery(LOOKUP_LOOKYLOO),
"1.",
JoinType.LEFT
),
"2.",
JoinType.FULL
);
final JoinDataSource joinDataSource =
join(
TABLE_FOO,
rightLeaningJoinStack,
"3.",
JoinType.RIGHT,
TrueDimFilter.instance()
);
final DataSourceAnalysis analysis = DataSourceAnalysis.forDataSource(joinDataSource);
Assert.assertTrue(analysis.isConcreteBased());
Assert.assertTrue(analysis.isConcreteTableBased());
Assert.assertFalse(analysis.isGlobal());
Assert.assertFalse(analysis.isQuery());
Assert.assertEquals(joinDataSource, analysis.getDataSource());
Assert.assertEquals(TABLE_FOO, analysis.getBaseDataSource());
Assert.assertEquals(Optional.of(TABLE_FOO), analysis.getBaseTableDataSource());
Assert.assertEquals(TrueDimFilter.instance(), analysis.getJoinBaseTableFilter().orElse(null));
Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource());
Assert.assertEquals(Optional.empty(), analysis.getBaseQuery());
Assert.assertEquals(Optional.empty(), analysis.getBaseQuerySegmentSpec());
@ -309,7 +407,8 @@ public class DataSourceAnalysisTest
TABLE_FOO,
subquery(TABLE_FOO),
"1.",
JoinType.INNER
JoinType.INNER,
TrueDimFilter.instance()
);
final DataSourceAnalysis analysis = DataSourceAnalysis.forDataSource(joinDataSource);
@ -320,6 +419,7 @@ public class DataSourceAnalysisTest
Assert.assertFalse(analysis.isQuery());
Assert.assertEquals(joinDataSource, analysis.getDataSource());
Assert.assertEquals(TABLE_FOO, analysis.getBaseDataSource());
Assert.assertEquals(TrueDimFilter.instance(), analysis.getJoinBaseTableFilter().orElse(null));
Assert.assertEquals(Optional.of(TABLE_FOO), analysis.getBaseTableDataSource());
Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource());
Assert.assertEquals(
@ -350,6 +450,7 @@ public class DataSourceAnalysisTest
Assert.assertFalse(analysis.isQuery());
Assert.assertEquals(joinDataSource, analysis.getDataSource());
Assert.assertEquals(Optional.empty(), analysis.getBaseTableDataSource());
Assert.assertEquals(Optional.empty(), analysis.getJoinBaseTableFilter());
Assert.assertEquals(Optional.of(unionDataSource), analysis.getBaseUnionDataSource());
Assert.assertEquals(unionDataSource, analysis.getBaseDataSource());
Assert.assertEquals(Optional.empty(), analysis.getBaseQuery());
@ -373,7 +474,8 @@ public class DataSourceAnalysisTest
TABLE_FOO,
LOOKUP_LOOKYLOO,
"1.",
JoinType.INNER
JoinType.INNER,
TrueDimFilter.instance()
)
)
);
@ -386,6 +488,7 @@ public class DataSourceAnalysisTest
Assert.assertTrue(analysis.isQuery());
Assert.assertEquals(queryDataSource, analysis.getDataSource());
Assert.assertEquals(TABLE_FOO, analysis.getBaseDataSource());
Assert.assertEquals(TrueDimFilter.instance(), analysis.getJoinBaseTableFilter().orElse(null));
Assert.assertEquals(Optional.of(TABLE_FOO), analysis.getBaseTableDataSource());
Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource());
Assert.assertEquals(
@ -395,7 +498,8 @@ public class DataSourceAnalysisTest
TABLE_FOO,
LOOKUP_LOOKYLOO,
"1.",
JoinType.INNER
JoinType.INNER,
TrueDimFilter.instance()
)
).getQuery()
),
@ -436,6 +540,7 @@ public class DataSourceAnalysisTest
Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource());
Assert.assertEquals(Optional.empty(), analysis.getBaseQuery());
Assert.assertEquals(Optional.empty(), analysis.getBaseQuerySegmentSpec());
Assert.assertEquals(Optional.empty(), analysis.getJoinBaseTableFilter());
Assert.assertEquals(
ImmutableList.of(
new PreJoinableClause("1.", LOOKUP_LOOKYLOO, JoinType.INNER, joinClause("1."))
@ -467,6 +572,7 @@ public class DataSourceAnalysisTest
Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource());
Assert.assertEquals(Optional.empty(), analysis.getBaseQuery());
Assert.assertEquals(Optional.empty(), analysis.getBaseQuerySegmentSpec());
Assert.assertEquals(Optional.empty(), analysis.getJoinBaseTableFilter());
Assert.assertEquals(
ImmutableList.of(
new PreJoinableClause("1.", TABLE_FOO, JoinType.INNER, joinClause("1."))
@ -484,7 +590,7 @@ public class DataSourceAnalysisTest
.withNonnullFields("dataSource")
// These fields are not necessary, because they're wholly determined by "dataSource"
.withIgnoredFields("baseDataSource", "baseQuery", "preJoinableClauses")
.withIgnoredFields("baseDataSource", "baseQuery", "preJoinableClauses", "joinBaseTableFilter")
.verify();
}
@ -495,7 +601,8 @@ public class DataSourceAnalysisTest
final DataSource left,
final DataSource right,
final String rightPrefix,
final JoinType joinType
final JoinType joinType,
final DimFilter dimFilter
)
{
return JoinDataSource.create(
@ -504,10 +611,21 @@ public class DataSourceAnalysisTest
rightPrefix,
joinClause(rightPrefix).getOriginalExpression(),
joinType,
dimFilter,
ExprMacroTable.nil()
);
}
private static JoinDataSource join(
final DataSource left,
final DataSource right,
final String rightPrefix,
final JoinType joinType
)
{
return join(left, right, rightPrefix, joinType, null);
}
/**
* Generate a join clause that joins on a column named "x" on both sides.
*/

View File

@ -35,6 +35,7 @@ import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.filter.Filters;
import org.apache.druid.segment.filter.SelectorFilter;
import org.apache.druid.segment.join.filter.JoinFilterPreAnalysis;
import org.apache.druid.segment.join.lookup.LookupJoinable;
@ -43,6 +44,7 @@ import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@ -2019,6 +2021,207 @@ public class HashJoinSegmentStorageAdapterTest extends BaseHashJoinSegmentStorag
);
}
@Test
public void test_makeCursors_factToCountryLeftWithBaseFilter()
{
List<JoinableClause> joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.LEFT));
JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis(
null,
joinableClauses,
VirtualColumns.EMPTY
);
JoinTestHelper.verifyCursors(
new HashJoinSegmentStorageAdapter(
factSegment.asStorageAdapter(),
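// Base-table filter: keep the CA rows plus one row whose ISO code matches no country, so the LEFT join must null-fill the right side for it.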
Filters.or(Arrays.asList(
new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(),
new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter()
)),
joinableClauses,
joinFilterPreAnalysis
).makeCursors(
null,
Intervals.ETERNITY,
VirtualColumns.EMPTY,
Granularities.ALL,
false,
null
),
ImmutableList.of(
"page",
"countryIsoCode",
FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryIsoCode",
FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryName",
FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryNumber"
),
ImmutableList.of(
new Object[]{"Didier Leclair", "CA", "CA", "Canada", 1L},
new Object[]{"Les Argonautes", "CA", "CA", "Canada", 1L},
new Object[]{"Sarah Michelle Gellar", "CA", "CA", "Canada", 1L},
new Object[]{"Orange Soda", "MatchNothing", null, null, NULL_COUNTRY}
)
);
}
@Test
public void test_makeCursors_factToCountryInnerWithBaseFilter()
{
List<JoinableClause> joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.INNER));
JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis(
null,
joinableClauses,
VirtualColumns.EMPTY
);
JoinTestHelper.verifyCursors(
new HashJoinSegmentStorageAdapter(
factSegment.asStorageAdapter(),
Filters.or(Arrays.asList(
new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(),
new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter()
)),
joinableClauses,
joinFilterPreAnalysis
).makeCursors(
null,
Intervals.ETERNITY,
VirtualColumns.EMPTY,
Granularities.ALL,
false,
null
),
ImmutableList.of(
"page",
"countryIsoCode",
FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryIsoCode",
FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryName",
FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryNumber"
),
ImmutableList.of(
new Object[]{"Didier Leclair", "CA", "CA", "Canada", 1L},
new Object[]{"Les Argonautes", "CA", "CA", "Canada", 1L},
new Object[]{"Sarah Michelle Gellar", "CA", "CA", "Canada", 1L}
)
);
}
@Test
public void test_makeCursors_factToCountryRightWithBaseFilter()
{
List<JoinableClause> joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.RIGHT));
JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis(
null,
joinableClauses,
VirtualColumns.EMPTY
);
JoinTestHelper.verifyCursors(
new HashJoinSegmentStorageAdapter(
factSegment.asStorageAdapter(),
Filters.or(Arrays.asList(
new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(),
new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter()
)),
joinableClauses,
joinFilterPreAnalysis
).makeCursors(
null,
Intervals.ETERNITY,
VirtualColumns.EMPTY,
Granularities.ALL,
false,
null
),
ImmutableList.of(
"page",
"countryIsoCode",
FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryIsoCode",
FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryName",
FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryNumber"
),
ImmutableList.of(
new Object[]{"Didier Leclair", "CA", "CA", "Canada", 1L},
new Object[]{"Les Argonautes", "CA", "CA", "Canada", 1L},
new Object[]{"Sarah Michelle Gellar", "CA", "CA", "Canada", 1L},
new Object[]{null, null, "AU", "Australia", 0L},
new Object[]{null, null, "CL", "Chile", 2L},
new Object[]{null, null, "DE", "Germany", 3L},
new Object[]{null, null, "EC", "Ecuador", 4L},
new Object[]{null, null, "FR", "France", 5L},
new Object[]{null, null, "GB", "United Kingdom", 6L},
new Object[]{null, null, "IT", "Italy", 7L},
new Object[]{null, null, "JP", "Japan", 8L},
new Object[]{null, null, "KR", "Republic of Korea", 9L},
new Object[]{null, null, "MX", "Mexico", 10L},
new Object[]{null, null, "NO", "Norway", 11L},
new Object[]{null, null, "SV", "El Salvador", 12L},
new Object[]{null, null, "US", "United States", 13L},
new Object[]{null, null, "AX", "Atlantis", 14L},
new Object[]{null, null, "SU", "States United", 15L},
new Object[]{null, null, "USCA", "Usca", 16L},
new Object[]{null, null, "MMMM", "Fourems", 205L}
)
);
}
@Test
public void test_makeCursors_factToCountryFullWithBaseFilter()
{
List<JoinableClause> joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.FULL));
JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis(
null,
joinableClauses,
VirtualColumns.EMPTY
);
JoinTestHelper.verifyCursors(
new HashJoinSegmentStorageAdapter(
factSegment.asStorageAdapter(),
Filters.or(Arrays.asList(
new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(),
new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter()
)),
joinableClauses,
joinFilterPreAnalysis
).makeCursors(
null,
Intervals.ETERNITY,
VirtualColumns.EMPTY,
Granularities.ALL,
false,
null
),
ImmutableList.of(
"page",
"countryIsoCode",
FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryIsoCode",
FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryName",
FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryNumber"
),
ImmutableList.of(
new Object[]{"Didier Leclair", "CA", "CA", "Canada", 1L},
new Object[]{"Les Argonautes", "CA", "CA", "Canada", 1L},
new Object[]{"Sarah Michelle Gellar", "CA", "CA", "Canada", 1L},
new Object[]{"Orange Soda", "MatchNothing", null, null, NullHandling.sqlCompatible() ? null : 0L},
new Object[]{null, null, "AU", "Australia", 0L},
new Object[]{null, null, "CL", "Chile", 2L},
new Object[]{null, null, "DE", "Germany", 3L},
new Object[]{null, null, "EC", "Ecuador", 4L},
new Object[]{null, null, "FR", "France", 5L},
new Object[]{null, null, "GB", "United Kingdom", 6L},
new Object[]{null, null, "IT", "Italy", 7L},
new Object[]{null, null, "JP", "Japan", 8L},
new Object[]{null, null, "KR", "Republic of Korea", 9L},
new Object[]{null, null, "MX", "Mexico", 10L},
new Object[]{null, null, "NO", "Norway", 11L},
new Object[]{null, null, "SV", "El Salvador", 12L},
new Object[]{null, null, "US", "United States", 13L},
new Object[]{null, null, "AX", "Atlantis", 14L},
new Object[]{null, null, "SU", "States United", 15L},
new Object[]{null, null, "USCA", "Usca", 16L},
new Object[]{null, null, "MMMM", "Fourems", 205L}
)
);
}
@Test
public void test_determineBaseColumnsWithPreAndPostJoinVirtualColumns()
{

View File

@ -182,6 +182,7 @@ public class HashJoinSegmentTest extends InitializedNullHandlingTest
};
hashJoinSegment = new HashJoinSegment(
testWrapper,
null,
joinableClauses,
null
)
@ -210,6 +211,7 @@ public class HashJoinSegmentTest extends InitializedNullHandlingTest
final HashJoinSegment ignored = new HashJoinSegment(
ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment),
null,
joinableClauses,
null
);

View File

@ -40,6 +40,7 @@ import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.filter.AndFilter;
import org.apache.druid.segment.filter.BoundFilter;
import org.apache.druid.segment.filter.FalseFilter;
import org.apache.druid.segment.filter.Filters;
import org.apache.druid.segment.filter.OrFilter;
import org.apache.druid.segment.filter.SelectorFilter;
import org.apache.druid.segment.join.filter.JoinFilterAnalyzer;
@ -54,6 +55,7 @@ import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.junit.Assert;
import org.junit.Test;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
@ -2650,4 +2652,29 @@ public class JoinFilterAnalyzerTest extends BaseHashJoinSegmentStorageAdapterTes
.withNonnullFields("virtualColumns")
.verify();
}
@Test
public void test_filterPushDown_baseTableFilter()
{
Filter originalFilter = new SelectorFilter("channel", "#en.wikipedia");
Filter baseTableFilter = new SelectorFilter("countryIsoCode", "CA");
List<JoinableClause> joinableClauses = ImmutableList.of(
factToRegion(JoinType.LEFT),
regionToCountry(JoinType.LEFT)
);
JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis(
originalFilter,
joinableClauses,
VirtualColumns.EMPTY
);
JoinFilterSplit expectedFilterSplit = new JoinFilterSplit(
Filters.and(Arrays.asList(originalFilter, baseTableFilter)),
null,
ImmutableSet.of()
);
JoinFilterSplit actualFilterSplit = JoinFilterAnalyzer.splitFilter(joinFilterPreAnalysis, baseTableFilter);
Assert.assertEquals(expectedFilterSplit, actualFilterSplit);
}
}

View File

@ -33,6 +33,8 @@ import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.TestQuery;
import org.apache.druid.query.extraction.MapLookupExtractor;
import org.apache.druid.query.filter.FalseDimFilter;
import org.apache.druid.query.filter.TrueDimFilter;
import org.apache.druid.query.planning.DataSourceAnalysis;
import org.apache.druid.query.planning.PreJoinableClause;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
@ -72,6 +74,7 @@ public class JoinableFactoryWrapperTest
public void test_createSegmentMapFn_noClauses()
{
final Function<SegmentReference, SegmentReference> segmentMapFn = NOOP_JOINABLE_FACTORY_WRAPPER.createSegmentMapFn(
null,
ImmutableList.of(),
new AtomicLong(),
null
@ -95,6 +98,7 @@ public class JoinableFactoryWrapperTest
expectedException.expectMessage("dataSource is not joinable");
final Function<SegmentReference, SegmentReference> ignored = NOOP_JOINABLE_FACTORY_WRAPPER.createSegmentMapFn(
null,
ImmutableList.of(clause),
new AtomicLong(),
null
@ -138,6 +142,7 @@ public class JoinableFactoryWrapperTest
}
});
final Function<SegmentReference, SegmentReference> segmentMapFn = joinableFactoryWrapper.createSegmentMapFn(
null,
ImmutableList.of(clause),
new AtomicLong(),
new TestQuery(
@ -157,6 +162,7 @@ public class JoinableFactoryWrapperTest
DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class);
DataSource dataSource = new NoopDataSource();
EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.emptyList());
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty());
EasyMock.expect(analysis.getDataSource()).andReturn(dataSource);
EasyMock.replay(analysis);
JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey());
@ -172,11 +178,11 @@ public class JoinableFactoryWrapperTest
@Test
public void test_computeJoinDataSourceCacheKey_noHashJoin()
{
PreJoinableClause clause1 = makeGlobalPreJoinableClause("dataSource_1", "x == \"j.x\"", "j.");
PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_2", "x != \"h.x\"", "h.");
DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class);
EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Arrays.asList(clause1, clause2)).anyTimes();
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.of(TrueDimFilter.instance())).anyTimes();
EasyMock.replay(analysis);
JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey());
Optional<byte[]> cacheKey = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis);
@ -192,6 +198,7 @@ public class JoinableFactoryWrapperTest
PreJoinableClause clause2 = makePreJoinableClause(dataSource, "x == \"h.x\"", "h.", JoinType.LEFT);
DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class);
EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Arrays.asList(clause1, clause2)).anyTimes();
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.of(TrueDimFilter.instance())).anyTimes();
EasyMock.replay(analysis);
JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey());
Optional<byte[]> cacheKey = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis);
@ -207,6 +214,7 @@ public class JoinableFactoryWrapperTest
PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_2", "x == \"h.x\"", "h.");
DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class);
EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Arrays.asList(clause1, clause2)).anyTimes();
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes();
EasyMock.replay(analysis);
JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey());
Optional<byte[]> cacheKey = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis);
@ -218,6 +226,7 @@ public class JoinableFactoryWrapperTest
public void test_computeJoinDataSourceCacheKey_keyChangesWithExpression()
{
DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class);
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes();
JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey());
PreJoinableClause clause1 = makeGlobalPreJoinableClause("dataSource_1", "y == \"j.y\"", "j.");
@ -231,6 +240,7 @@ public class JoinableFactoryWrapperTest
PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_1", "x == \"j.x\"", "j.");
EasyMock.reset(analysis);
EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause2)).anyTimes();
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes();
EasyMock.replay(analysis);
Optional<byte[]> cacheKey2 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis);
Assert.assertTrue(cacheKey2.isPresent());
@ -242,6 +252,7 @@ public class JoinableFactoryWrapperTest
public void test_computeJoinDataSourceCacheKey_keyChangesWithJoinType()
{
DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class);
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes();
JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey());
PreJoinableClause clause1 = makeGlobalPreJoinableClause("dataSource_1", "x == \"j.x\"", "j.", JoinType.LEFT);
@ -255,6 +266,7 @@ public class JoinableFactoryWrapperTest
PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_1", "x == \"j.x\"", "j.", JoinType.INNER);
EasyMock.reset(analysis);
EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause2)).anyTimes();
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes();
EasyMock.replay(analysis);
Optional<byte[]> cacheKey2 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis);
Assert.assertTrue(cacheKey2.isPresent());
@ -266,6 +278,7 @@ public class JoinableFactoryWrapperTest
public void test_computeJoinDataSourceCacheKey_keyChangesWithPrefix()
{
DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class);
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes();
JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey());
PreJoinableClause clause1 = makeGlobalPreJoinableClause("dataSource_1", "abc == xyz", "ab");
@ -279,6 +292,33 @@ public class JoinableFactoryWrapperTest
PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_1", "abc == xyz", "xy");
EasyMock.reset(analysis);
EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause2)).anyTimes();
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes();
EasyMock.replay(analysis);
Optional<byte[]> cacheKey2 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis);
Assert.assertTrue(cacheKey2.isPresent());
Assert.assertFalse(Arrays.equals(cacheKey1.get(), cacheKey2.get()));
}
@Test
public void test_computeJoinDataSourceCacheKey_keyChangesWithBaseFilter()
{
DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class);
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.of(TrueDimFilter.instance())).anyTimes();
JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey());
PreJoinableClause clause1 = makeGlobalPreJoinableClause("dataSource_1", "abc == xyz", "ab");
EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause1)).anyTimes();
EasyMock.replay(analysis);
Optional<byte[]> cacheKey1 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis);
Assert.assertTrue(cacheKey1.isPresent());
Assert.assertNotEquals(0, cacheKey1.get().length);
PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_1", "abc == xyz", "ab");
EasyMock.reset(analysis);
EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause2)).anyTimes();
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.of(FalseDimFilter.instance())).anyTimes();
EasyMock.replay(analysis);
Optional<byte[]> cacheKey2 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis);
Assert.assertTrue(cacheKey2.isPresent());
@ -290,6 +330,7 @@ public class JoinableFactoryWrapperTest
public void test_computeJoinDataSourceCacheKey_keyChangesWithJoinable()
{
DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class);
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes();
JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey());
PreJoinableClause clause1 = makeGlobalPreJoinableClause("dataSource_1", "x == \"j.x\"", "j.");
@ -303,6 +344,8 @@ public class JoinableFactoryWrapperTest
PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_2", "x == \"j.x\"", "j.");
EasyMock.reset(analysis);
EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause2)).anyTimes();
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes();
EasyMock.replay(analysis);
Optional<byte[]> cacheKey2 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis);
Assert.assertTrue(cacheKey2.isPresent());
@ -318,6 +361,7 @@ public class JoinableFactoryWrapperTest
PreJoinableClause clause1 = makeGlobalPreJoinableClause("dataSource_1", "x == \"j.x\"", "j.");
EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause1)).anyTimes();
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes();
EasyMock.replay(analysis);
Optional<byte[]> cacheKey1 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis);
@ -327,6 +371,7 @@ public class JoinableFactoryWrapperTest
PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_1", "x == \"j.x\"", "j.");
EasyMock.reset(analysis);
EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause2)).anyTimes();
EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes();
EasyMock.replay(analysis);
Optional<byte[]> cacheKey2 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis);
Assert.assertTrue(cacheKey2.isPresent());

View File

@ -59,6 +59,7 @@ import org.apache.druid.query.spec.SpecificSegmentQueryRunner;
import org.apache.druid.query.spec.SpecificSegmentSpec;
import org.apache.druid.segment.SegmentReference;
import org.apache.druid.segment.StorageAdapter;
import org.apache.druid.segment.filter.Filters;
import org.apache.druid.segment.join.JoinableFactory;
import org.apache.druid.segment.join.JoinableFactoryWrapper;
import org.apache.druid.segment.realtime.FireHydrant;
@ -172,6 +173,7 @@ public class SinkQuerySegmentWalker implements QuerySegmentWalker
// segmentMapFn maps each base Segment into a joined Segment if necessary.
final Function<SegmentReference, SegmentReference> segmentMapFn = joinableFactoryWrapper.createSegmentMapFn(
analysis.getJoinBaseTableFilter().map(Filters::toFilter).orElse(null),
analysis.getPreJoinableClauses(),
cpuTimeAccumulator,
analysis.getBaseQuery().orElse(query)

View File

@ -36,6 +36,7 @@ import org.apache.druid.query.planning.DataSourceAnalysis;
import org.apache.druid.segment.ReferenceCountingSegment;
import org.apache.druid.segment.SegmentReference;
import org.apache.druid.segment.SegmentWrangler;
import org.apache.druid.segment.filter.Filters;
import org.apache.druid.segment.join.JoinableFactory;
import org.apache.druid.segment.join.JoinableFactoryWrapper;
import org.joda.time.Interval;
@ -95,6 +96,7 @@ public class LocalQuerySegmentWalker implements QuerySegmentWalker
final AtomicLong cpuAccumulator = new AtomicLong(0L);
final Function<SegmentReference, SegmentReference> segmentMapFn = joinableFactoryWrapper.createSegmentMapFn(
analysis.getJoinBaseTableFilter().map(Filters::toFilter).orElse(null),
analysis.getPreJoinableClauses(),
cpuAccumulator,
analysis.getBaseQuery().orElse(query)

View File

@ -59,6 +59,7 @@ import org.apache.druid.query.spec.SpecificSegmentSpec;
import org.apache.druid.segment.ReferenceCountingSegment;
import org.apache.druid.segment.SegmentReference;
import org.apache.druid.segment.StorageAdapter;
import org.apache.druid.segment.filter.Filters;
import org.apache.druid.segment.join.JoinableFactory;
import org.apache.druid.segment.join.JoinableFactoryWrapper;
import org.apache.druid.server.SegmentManager;
@ -200,6 +201,7 @@ public class ServerManager implements QuerySegmentWalker
// segmentMapFn maps each base Segment into a joined Segment if necessary.
final Function<SegmentReference, SegmentReference> segmentMapFn = joinableFactoryWrapper.createSegmentMapFn(
analysis.getJoinBaseTableFilter().map(Filters::toFilter).orElse(null),
analysis.getPreJoinableClauses(),
cpuTimeAccumulator,
analysis.getBaseQuery().orElse(query)

View File

@ -47,6 +47,7 @@ import org.apache.druid.query.spec.SpecificSegmentSpec;
import org.apache.druid.segment.ReferenceCountingSegment;
import org.apache.druid.segment.Segment;
import org.apache.druid.segment.SegmentReference;
import org.apache.druid.segment.filter.Filters;
import org.apache.druid.segment.join.JoinableFactory;
import org.apache.druid.segment.join.JoinableFactoryWrapper;
import org.apache.druid.timeline.TimelineObjectHolder;
@ -143,6 +144,7 @@ public class TestClusterQuerySegmentWalker implements QuerySegmentWalker
}
final Function<SegmentReference, SegmentReference> segmentMapFn = joinableFactoryWrapper.createSegmentMapFn(
analysis.getJoinBaseTableFilter().map(Filters::toFilter).orElse(null),
analysis.getPreJoinableClauses(),
new AtomicLong(),
analysis.getBaseQuery().orElse(query)

View File

@ -31,6 +31,7 @@ import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
@ -38,6 +39,7 @@ import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.guava.Sequence;
@ -45,13 +47,16 @@ import org.apache.druid.query.DataSource;
import org.apache.druid.query.JoinDataSource;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.join.JoinType;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.table.RowSignatures;
import javax.annotation.Nullable;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@ -64,6 +69,7 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
{
private static final TableDataSource DUMMY_DATA_SOURCE = new TableDataSource("__join__");
private final Filter leftFilter;
private final PartialDruidQuery partialQuery;
private final Join joinRel;
private RelNode left;
@ -73,6 +79,7 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
RelOptCluster cluster,
RelTraitSet traitSet,
Join joinRel,
Filter leftFilter,
PartialDruidQuery partialQuery,
QueryMaker queryMaker
)
@ -81,6 +88,7 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
this.joinRel = joinRel;
this.left = joinRel.getLeft();
this.right = joinRel.getRight();
this.leftFilter = leftFilter;
this.partialQuery = partialQuery;
}
@ -89,6 +97,7 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
*/
public static DruidJoinQueryRel create(
final Join joinRel,
final Filter leftFilter,
final QueryMaker queryMaker
)
{
@ -96,6 +105,7 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
joinRel.getCluster(),
joinRel.getTraitSet(),
joinRel,
leftFilter,
PartialDruidQuery.create(joinRel),
queryMaker
);
@ -125,6 +135,7 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
getCluster(),
getTraitSet().plusAll(newQueryBuilder.getRelTraits()),
joinRel,
leftFilter,
newQueryBuilder,
getQueryMaker()
);
@ -145,6 +156,9 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
if (computeLeftRequiresSubquery(leftDruidRel)) {
leftDataSource = new QueryDataSource(leftQuery.getQuery());
if (leftFilter != null) {
throw new ISE("Filter on left table is supposed to be null if left child is a query source");
}
} else {
leftDataSource = leftQuery.getDataSource();
}
@ -177,6 +191,7 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
prefixSignaturePair.lhs,
condition.getExpression(),
toDruidJoinType(joinRel.getJoinType()),
getDimFilter(getPlannerContext(), leftSignature, leftFilter),
getPlannerContext().getExprMacroTable()
),
prefixSignaturePair.rhs,
@ -214,6 +229,7 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
.map(input -> RelOptRule.convert(input, DruidConvention.instance()))
.collect(Collectors.toList())
),
leftFilter,
partialQuery,
getQueryMaker()
);
@ -252,6 +268,7 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
getCluster(),
traitSet,
joinRel.copy(joinRel.getTraitSet(), inputs),
leftFilter,
getPartialDruidQuery(),
getQueryMaker()
);
@ -312,7 +329,7 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
return planner.getCostFactory().makeCost(cost, 0, 0);
}
private static JoinType toDruidJoinType(JoinRelType calciteJoinType)
public static JoinType toDruidJoinType(JoinRelType calciteJoinType)
{
switch (calciteJoinType) {
case LEFT:
@ -377,4 +394,28 @@ public class DruidJoinQueryRel extends DruidRel<DruidJoinQueryRel>
return (DruidRel<?>) Iterables.getFirst(subset.getRels(), null);
}
}
@Nullable
private static DimFilter getDimFilter(
final PlannerContext plannerContext,
final RowSignature rowSignature,
@Nullable final Filter filter
)
{
if (filter == null) {
return null;
}
final RexNode condition = filter.getCondition();
final DimFilter dimFilter = Expressions.toFilter(
plannerContext,
rowSignature,
null,
condition
);
if (dimFilter == null) {
throw new CannotBuildQueryException(filter, condition);
} else {
return dimFilter;
}
}
}

View File

@ -19,6 +19,7 @@
package org.apache.druid.sql.calcite.rel;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSortedMap;
@ -42,9 +43,11 @@ import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.JoinDataSource;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.aggregation.PostAggregator;
@ -648,6 +651,52 @@ public class DruidQuery
return VirtualColumns.create(columns);
}
/**
* Returns a pair of DataSource and Filtration object created from the query filter. If the data source is
* a join datasource, the returned data source may be altered: the left filter of the join datasource may
* be stripped of its time filters, which are promoted into the query's intervals.
* TODO: should we optimize the base table filter just like we do with query filters?
*/
@VisibleForTesting
static Pair<DataSource, Filtration> getFiltration(
DataSource dataSource,
DimFilter filter,
VirtualColumnRegistry virtualColumnRegistry
)
{
if (!(dataSource instanceof JoinDataSource)) {
return Pair.of(dataSource, toFiltration(filter, virtualColumnRegistry));
}
JoinDataSource joinDataSource = (JoinDataSource) dataSource;
if (joinDataSource.getLeftFilter() == null) {
return Pair.of(dataSource, toFiltration(filter, virtualColumnRegistry));
}
//TODO: We should avoid promoting the time filter to an interval for right outer and full outer joins. This is
// not done now because we apply the intervals to the left base table today irrespective of the join type.
// If the join is left or inner, we can pull the intervals up to the query. This is done
// so that the broker can prune the segments to query.
Filtration leftFiltration = Filtration.create(joinDataSource.getLeftFilter())
.optimize(virtualColumnRegistry.getFullRowSignature());
// Add the intervals from the join's left filter to the query filtration
Filtration queryFiltration = Filtration.create(filter, leftFiltration.getIntervals())
.optimize(virtualColumnRegistry.getFullRowSignature());
JoinDataSource newDataSource = JoinDataSource.create(
joinDataSource.getLeft(),
joinDataSource.getRight(),
joinDataSource.getRightPrefix(),
joinDataSource.getConditionAnalysis(),
joinDataSource.getJoinType(),
leftFiltration.getDimFilter()
);
return Pair.of(newDataSource, queryFiltration);
}
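// Illustrative sketch, not part of this change (it mirrors DruidQueryTest later in this diff,
// and assumes NullHandling has been initialized as that test does): a time bound inside the
// join's left filter is promoted into the query's intervals, and only the non-time part of
// the left filter remains on the rewritten join datasource.
//
//   DataSource joinWithTimeFilter = JoinDataSource.create(
//       new TableDataSource("left"),
//       new TableDataSource("right"),
//       "r.",
//       "c == \"r.c\"",
//       JoinType.INNER,
//       new AndDimFilter(
//           new SelectorDimFilter("column", "value", null),
//           new BoundDimFilter("__time", "100", "200", false, true, null, null, StringComparators.NUMERIC)
//       ),
//       ExprMacroTable.nil()
//   );
//   Pair<DataSource, Filtration> rewritten = getFiltration(
//       joinWithTimeFilter,
//       new SelectorDimFilter("column_2", "value_2", null),   // plain query filter, no time bound
//       VirtualColumnRegistry.create(RowSignature.empty())
//   );
//   // rewritten.lhs is the join carrying only the selector as its left filter;
//   // rewritten.rhs.getIntervals() equals [Intervals.utc(100, 200)].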
private static Filtration toFiltration(DimFilter filter, VirtualColumnRegistry virtualColumnRegistry)
{
return Filtration.create(filter).optimize(virtualColumnRegistry.getFullRowSignature());
}
public DataSource getDataSource()
{
return dataSource;
@ -793,7 +842,13 @@ public class DruidQuery
return null;
}
final Filtration filtration = Filtration.create(filter).optimize(virtualColumnRegistry.getFullRowSignature());
final Pair<DataSource, Filtration> dataSourceFiltrationPair = getFiltration(
dataSource,
filter,
virtualColumnRegistry
);
final DataSource newDataSource = dataSourceFiltrationPair.lhs;
final Filtration filtration = dataSourceFiltrationPair.rhs;
final List<PostAggregator> postAggregators = new ArrayList<>(grouping.getPostAggregators());
if (sorting != null && sorting.getProjection() != null) {
@ -801,7 +856,7 @@ public class DruidQuery
}
return new TimeseriesQuery(
dataSource,
newDataSource,
filtration.getQuerySegmentSpec(),
descending,
getVirtualColumns(false),
@ -872,7 +927,13 @@ public class DruidQuery
return null;
}
final Filtration filtration = Filtration.create(filter).optimize(virtualColumnRegistry.getFullRowSignature());
final Pair<DataSource, Filtration> dataSourceFiltrationPair = getFiltration(
dataSource,
filter,
virtualColumnRegistry
);
final DataSource newDataSource = dataSourceFiltrationPair.lhs;
final Filtration filtration = dataSourceFiltrationPair.rhs;
final List<PostAggregator> postAggregators = new ArrayList<>(grouping.getPostAggregators());
if (sorting.getProjection() != null) {
@ -880,7 +941,7 @@ public class DruidQuery
}
return new TopNQuery(
dataSource,
newDataSource,
getVirtualColumns(true),
dimensionSpec,
topNMetricSpec,
@ -911,7 +972,13 @@ public class DruidQuery
return null;
}
final Filtration filtration = Filtration.create(filter).optimize(virtualColumnRegistry.getFullRowSignature());
final Pair<DataSource, Filtration> dataSourceFiltrationPair = getFiltration(
dataSource,
filter,
virtualColumnRegistry
);
final DataSource newDataSource = dataSourceFiltrationPair.lhs;
final Filtration filtration = dataSourceFiltrationPair.rhs;
final DimFilterHavingSpec havingSpec;
if (grouping.getHavingFilter() != null) {
@ -930,7 +997,7 @@ public class DruidQuery
}
return new GroupByQuery(
dataSource,
newDataSource,
filtration.getQuerySegmentSpec(),
getVirtualColumns(true),
filtration.getDimFilter(),
@ -963,7 +1030,14 @@ public class DruidQuery
throw new ISE("Cannot convert to Scan query without any columns.");
}
final Filtration filtration = Filtration.create(filter).optimize(virtualColumnRegistry.getFullRowSignature());
final Pair<DataSource, Filtration> dataSourceFiltrationPair = getFiltration(
dataSource,
filter,
virtualColumnRegistry
);
final DataSource newDataSource = dataSourceFiltrationPair.lhs;
final Filtration filtration = dataSourceFiltrationPair.rhs;
final ScanQuery.Order order;
long scanOffset = 0L;
long scanLimit = 0L;
@ -1008,7 +1082,7 @@ public class DruidQuery
}
return new ScanQuery(
dataSource,
newDataSource,
filtration.getQuerySegmentSpec(),
getVirtualColumns(true),
ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST,

View File

@ -26,6 +26,7 @@ import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
@ -97,18 +98,19 @@ public class DruidJoinRule extends RelOptRule
final DruidRel<?> newLeft;
final DruidRel<?> newRight;
final Filter leftFilter;
final List<RexNode> newProjectExprs = new ArrayList<>();
// Already verified to be present in "matches", so just call "get".
// Can't be final, because we're going to reassign it up to a couple of times.
ConditionAnalysis conditionAnalysis = analyzeCondition(join.getCondition(), join.getLeft().getRowType()).get();
if (left.getPartialDruidQuery().stage() == PartialDruidQuery.Stage.SELECT_PROJECT
&& left.getPartialDruidQuery().getWhereFilter() == null) {
if (left.getPartialDruidQuery().stage() == PartialDruidQuery.Stage.SELECT_PROJECT) {
// Swap the left-side projection above the join, so the left side is a simple scan or mapping. This helps us
// avoid subqueries.
final RelNode leftScan = left.getPartialDruidQuery().getScan();
final Project leftProject = left.getPartialDruidQuery().getSelectProject();
leftFilter = left.getPartialDruidQuery().getWhereFilter();
// Left-side projection expressions rewritten to be on top of the join.
newProjectExprs.addAll(leftProject.getProjects());
@ -121,6 +123,7 @@ public class DruidJoinRule extends RelOptRule
}
newLeft = left;
leftFilter = null;
}
if (right.getPartialDruidQuery().stage() == PartialDruidQuery.Stage.SELECT_PROJECT
@ -163,6 +166,7 @@ public class DruidJoinRule extends RelOptRule
join.getJoinType(),
join.isSemiJoinDone()
),
leftFilter,
left.getQueryMaker()
);
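An illustrative sketch of this rule change's effect (hedged; the datasource and filter shapes
mirror the CalciteQueryTest expectations later in this diff, and "rightSide" is a hypothetical
right-hand QueryDataSource): a WHERE clause on the left-hand table no longer forces the left
side into a subquery; the left side stays a direct table access and the filter is carried as
the join datasource's left filter.

    join(
        new TableDataSource(CalciteTests.DATASOURCE1),     // previously: a QueryDataSource over a filtered scan
        rightSide,                                         // hypothetical right-hand QueryDataSource
        "j0.",
        "(\"dim2\" == \"j0.dim2\")",
        JoinType.INNER,
        bound("dim2", "a", "a", false, false, null, null)  // the left table's WHERE, now a left filter
    )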

View File

@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.druid.annotations.UsedByJUnitParamsRunner;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.hll.VersionOneHyperLogLogCollector;
import org.apache.druid.java.util.common.DateTimes;
@ -374,7 +375,8 @@ public class BaseCalciteQueryTest extends CalciteTestBase
DataSource right,
String rightPrefix,
String condition,
JoinType joinType
JoinType joinType,
DimFilter filter
)
{
return JoinDataSource.create(
@ -383,10 +385,22 @@ public class BaseCalciteQueryTest extends CalciteTestBase
rightPrefix,
condition,
joinType,
filter,
CalciteTests.createExprMacroTable()
);
}
public static JoinDataSource join(
DataSource left,
DataSource right,
String rightPrefix,
String condition,
JoinType joinType
)
{
return join(left, right, rightPrefix, condition, joinType, null);
}
public static String equalsCondition(DruidExpression left, DruidExpression right)
{
return StringUtils.format("(%s == %s)", left.getExpression(), right.getExpression());
@ -828,4 +842,39 @@ public class BaseCalciteQueryTest extends CalciteTestBase
{
skipVectorize = true;
}
/**
* This is a provider of query contexts that should be used by join tests.
* It tests various configs that can be passed to join queries. All the configs provided by this provider should
* have the join query engine return the same results.
*/
public static class QueryContextForJoinProvider
{
@UsedByJUnitParamsRunner
public static Object[] provideQueryContexts()
{
return new Object[]{
// default behavior
QUERY_CONTEXT_DEFAULT,
// filter value re-writes enabled
new ImmutableMap.Builder<String, Object>()
.putAll(QUERY_CONTEXT_DEFAULT)
.put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, true)
.put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, true)
.build(),
// rewrite values enabled but filter re-writes disabled.
// This should drive the same behavior as the previous config
new ImmutableMap.Builder<String, Object>()
.putAll(QUERY_CONTEXT_DEFAULT)
.put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, true)
.put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, false)
.build(),
// filter re-writes disabled
new ImmutableMap.Builder<String, Object>()
.putAll(QUERY_CONTEXT_DEFAULT)
.put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, false)
.build(),
};
}
}
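// Illustrative usage, a hedged sketch (it mirrors CalciteCorrelatedQueryTest later in this
// diff; the method name and query strings are placeholders): JUnitParams invokes the annotated
// test once per context returned by provideQueryContexts(), and the test class is assumed to
// be annotated with @RunWith(JUnitParamsRunner.class).
//
//   @Test
//   @Parameters(source = QueryContextForJoinProvider.class)
//   public void testSomeJoinQuery(Map<String, Object> queryContext) throws Exception
//   {
//     cannotVectorize();
//     testQuery("SELECT ...", queryContext, expectedQueries, expectedResults);
//   }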
}

View File

@ -0,0 +1,481 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.apache.druid.data.input.InputRow;
import org.apache.druid.data.input.MapBasedInputRow;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.granularity.AllGranularity;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.LongMaxAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.aggregation.any.LongAnyAggregatorFactory;
import org.apache.druid.query.aggregation.cardinality.CardinalityAggregatorFactory;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.expression.TestExprMacroTable;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.segment.IndexBuilder;
import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.join.JoinType;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import java.io.File;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@RunWith(JUnitParamsRunner.class)
public class CalciteCorrelatedQueryTest extends BaseCalciteQueryTest
{
private static final IncrementalIndexSchema INDEX_SCHEMA = new IncrementalIndexSchema.Builder()
.withMetrics(
new CountAggregatorFactory("cnt")
)
.withRollup(false)
.withMinTimestamp(DateTimes.of("2020-12-31").getMillis())
.build();
private static final List<String> DIMENSIONS = ImmutableList.of("user", "country", "city");
@Before
public void setup() throws Exception
{
final QueryableIndex index1 = IndexBuilder
.create()
.tmpDir(new File(temporaryFolder.newFolder(), "1"))
.segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
.schema(INDEX_SCHEMA)
.rows(getRawRows())
.buildMMappedIndex();
final DataSegment segment = DataSegment.builder()
.dataSource("visits")
.interval(index1.getDataInterval())
.version("1")
.shardSpec(new LinearShardSpec(0))
.size(0)
.build();
walker.add(segment, index1);
}
@Test
@Parameters(source = QueryContextForJoinProvider.class)
public void testCorrelatedSubquery(Map<String, Object> queryContext) throws Exception
{
cannotVectorize();
testQuery(
"select country, ANY_VALUE(\n"
+ " select avg(\"users\") from (\n"
+ " select floor(__time to day), count(distinct user) \"users\" from visits f where f.country = visits.country group by 1\n"
+ " )\n"
+ " ) as \"DAU\"\n"
+ "from visits \n"
+ "group by 1",
queryContext,
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(
join(
new TableDataSource("visits"),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(
GroupByQuery.builder()
.setDataSource("visits")
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setVirtualColumns(new ExpressionVirtualColumn(
"v0",
"timestamp_floor(\"__time\",'P1D',null,'UTC')",
ValueType.LONG,
TestExprMacroTable.INSTANCE
))
.setDimFilter(not(selector("country", null, null)))
.setDimensions(
new DefaultDimensionSpec(
"v0",
"d0",
ValueType.LONG
),
new DefaultDimensionSpec(
"country",
"d1"
)
)
.setAggregatorSpecs(new CardinalityAggregatorFactory(
"a0:a",
null,
Collections.singletonList(
new DefaultDimensionSpec(
"user",
"user"
)),
false,
true
))
.setPostAggregatorSpecs(Collections.singletonList(new HyperUniqueFinalizingPostAggregator(
"a0",
"a0:a"
)))
.setContext(queryContext)
.setGranularity(new AllGranularity())
.build()
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("d1", "_d0"))
.setAggregatorSpecs(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count")
)
.setPostAggregatorSpecs(Collections.singletonList(new ArithmeticPostAggregator(
"_a0",
"quotient",
Arrays
.asList(
new FieldAccessPostAggregator(null, "_a0:sum"),
new FieldAccessPostAggregator(null, "_a0:count")
)
)))
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
"j0.",
equalsCondition(
DruidExpression.fromColumn("country"),
DruidExpression.fromColumn("j0._d0")
),
JoinType.LEFT
)
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("country", "d0"))
.setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0"))
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
ImmutableList.of(
new Object[]{"India", 2L},
new Object[]{"USA", 1L},
new Object[]{"canada", 3L}
)
);
}
@Test
@Parameters(source = QueryContextForJoinProvider.class)
public void testCorrelatedSubqueryWithLeftFilter(Map<String, Object> queryContext) throws Exception
{
cannotVectorize();
testQuery(
"select country, ANY_VALUE(\n"
+ " select max(\"users\") from (\n"
+ " select floor(__time to day), count(*) \"users\" from visits f where f.country = visits.country group by 1\n"
+ " )\n"
+ " ) as \"dailyVisits\"\n"
+ "from visits \n"
+ " where city = 'B' and __time between '2021-01-01 01:00:00' AND '2021-01-02 23:59:59'"
+ " group by 1",
queryContext,
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(
join(
new TableDataSource("visits"),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(
GroupByQuery.builder()
.setDataSource("visits")
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setVirtualColumns(new ExpressionVirtualColumn(
"v0",
"timestamp_floor(\"__time\",'P1D',null,'UTC')",
ValueType.LONG,
TestExprMacroTable.INSTANCE
))
.setDimFilter(not(selector("country", null, null)))
.setDimensions(
new DefaultDimensionSpec(
"v0",
"d0",
ValueType.LONG
),
new DefaultDimensionSpec(
"country",
"d1"
)
)
.setAggregatorSpecs(new CountAggregatorFactory("a0"))
.setContext(queryContext)
.setGranularity(new AllGranularity())
.build()
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("d1", "_d0"))
.setAggregatorSpecs(
new LongMaxAggregatorFactory("_a0", "a0")
)
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
"j0.",
equalsCondition(
DruidExpression.fromColumn("country"),
DruidExpression.fromColumn("j0._d0")
),
JoinType.LEFT,
selector("city", "B", null)
)
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.of(
"2021-01-01T01:00:00.000Z/2021-01-02T23:59:59.001Z")))
.setDimensions(new DefaultDimensionSpec("country", "d0"))
.setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0"))
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
ImmutableList.of(
new Object[]{"canada", 4L}
)
);
}
@Test
@Parameters(source = QueryContextForJoinProvider.class)
public void testCorrelatedSubqueryWithCorrelatedQueryFilter(Map<String, Object> queryContext) throws Exception
{
cannotVectorize();
testQuery(
"select country, ANY_VALUE(\n"
+ " select max(\"users\") from (\n"
+ " select floor(__time to day), count(user) \"users\" from visits f where f.country = visits.country and f.city = 'A' group by 1\n"
+ " )\n"
+ " ) as \"dailyVisits\"\n"
+ "from visits \n"
+ " where city = 'B'"
+ " group by 1",
queryContext,
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(
join(
new TableDataSource("visits"),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(
GroupByQuery.builder()
.setDataSource("visits")
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setVirtualColumns(new ExpressionVirtualColumn(
"v0",
"timestamp_floor(\"__time\",'P1D',null,'UTC')",
ValueType.LONG,
TestExprMacroTable.INSTANCE
))
.setDimensions(
new DefaultDimensionSpec(
"v0",
"d0",
ValueType.LONG
),
new DefaultDimensionSpec(
"country",
"d1"
)
)
.setAggregatorSpecs(new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
not(selector("user", null, null))
))
.setDimFilter(and(
selector("city", "A", null),
not(selector("country", null, null))
))
.setContext(queryContext)
.setGranularity(new AllGranularity())
.build()
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("d1", "_d0"))
.setAggregatorSpecs(
new LongMaxAggregatorFactory("_a0", "a0")
)
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
"j0.",
equalsCondition(
DruidExpression.fromColumn("country"),
DruidExpression.fromColumn("j0._d0")
),
JoinType.LEFT,
selector("city", "B", null)
)
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("country", "d0"))
.setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0"))
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
ImmutableList.of(
new Object[]{"canada", 2L}
)
);
}
@Test
@Parameters(source = QueryContextForJoinProvider.class)
public void testCorrelatedSubqueryWithCorrelatedQueryFilter_Scan(Map<String, Object> queryContext) throws Exception
{
cannotVectorize();
testQuery(
"select country, ANY_VALUE(\n"
+ " select max(\"users\") from (\n"
+ " select floor(__time to day), count(user) \"users\" from visits f where f.country = visits.country and f.city = 'A' group by 1\n"
+ " )\n"
+ " ) as \"dailyVisits\"\n"
+ "from visits \n"
+ " where city = 'B'"
+ " group by 1",
queryContext,
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(
join(
new TableDataSource("visits"),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(
GroupByQuery.builder()
.setDataSource("visits")
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setVirtualColumns(new ExpressionVirtualColumn(
"v0",
"timestamp_floor(\"__time\",'P1D',null,'UTC')",
ValueType.LONG,
TestExprMacroTable.INSTANCE
))
.setDimensions(
new DefaultDimensionSpec(
"v0",
"d0",
ValueType.LONG
),
new DefaultDimensionSpec(
"country",
"d1"
)
)
.setAggregatorSpecs(new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
not(selector("user", null, null))
))
.setDimFilter(and(
selector("city", "A", null),
not(selector("country", null, null))
))
.setContext(queryContext)
.setGranularity(new AllGranularity())
.build()
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("d1", "_d0"))
.setAggregatorSpecs(
new LongMaxAggregatorFactory("_a0", "a0")
)
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
"j0.",
equalsCondition(
DruidExpression.fromColumn("country"),
DruidExpression.fromColumn("j0._d0")
),
JoinType.LEFT,
selector("city", "B", null)
)
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("country", "d0"))
.setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0"))
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
ImmutableList.of(
new Object[]{"canada", 2L}
)
);
}
private List<InputRow> getRawRows()
{
return ImmutableList.of(
toRow("2021-01-01T01:00:00Z", ImmutableMap.of("user", "alice", "country", "canada", "city", "A")),
toRow("2021-01-01T02:00:00Z", ImmutableMap.of("user", "alice", "country", "canada", "city", "B")),
toRow("2021-01-01T03:00:00Z", ImmutableMap.of("user", "bob", "country", "canada", "city", "A")),
toRow("2021-01-01T04:00:00Z", ImmutableMap.of("user", "alice", "country", "India", "city", "Y")),
toRow("2021-01-02T01:00:00Z", ImmutableMap.of("user", "alice", "country", "canada", "city", "A")),
toRow("2021-01-02T02:00:00Z", ImmutableMap.of("user", "bob", "country", "canada", "city", "A")),
toRow("2021-01-02T03:00:00Z", ImmutableMap.of("user", "foo", "country", "canada", "city", "B")),
toRow("2021-01-02T04:00:00Z", ImmutableMap.of("user", "bar", "country", "canada", "city", "B")),
toRow("2021-01-02T05:00:00Z", ImmutableMap.of("user", "alice", "country", "India", "city", "X")),
toRow("2021-01-02T06:00:00Z", ImmutableMap.of("user", "bob", "country", "India", "city", "X")),
toRow("2021-01-02T07:00:00Z", ImmutableMap.of("user", "foo", "country", "India", "city", "X")),
toRow("2021-01-03T01:00:00Z", ImmutableMap.of("user", "foo", "country", "USA", "city", "M"))
);
}
private MapBasedInputRow toRow(String time, Map<String, Object> event)
{
return new MapBasedInputRow(DateTimes.ISO_DATE_OPTIONAL_TIME.parse(time), DIMENSIONS, event);
}
}

View File

@ -25,7 +25,6 @@ import com.google.common.collect.ImmutableMap;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.druid.annotations.UsedByJUnitParamsRunner;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.IAE;
@ -85,7 +84,6 @@ import org.apache.druid.query.filter.BoundDimFilter;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.filter.LikeDimFilter;
import org.apache.druid.query.filter.NotDimFilter;
import org.apache.druid.query.filter.RegexDimFilter;
import org.apache.druid.query.filter.SelectorDimFilter;
import org.apache.druid.query.groupby.GroupByQuery;
@ -5365,14 +5363,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.dataSource(
join(
join(
new QueryDataSource(
newScanQueryBuilder().dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("dim1", "dim2")
.filters(selector("dim2", "a", null))
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(
newScanQueryBuilder().dataSource(CalciteTests.DATASOURCE3)
.intervals(querySegmentSpec(Filtration.eternity()))
@ -5382,7 +5373,8 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
),
"j0.",
"(\"dim2\" == \"j0.dim2\")",
JoinType.INNER
JoinType.INNER,
bound("dim2", "a", "a", false, false, null, null)
),
new QueryDataSource(
newScanQueryBuilder().dataSource(CalciteTests.DATASOURCE1)
@ -13227,15 +13219,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.builder()
.setDataSource(
join(
new QueryDataSource(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Intervals.of("2001-01-02T00:00:00.000Z/146140482-04-24T15:36:27.903Z")))
.columns("dim1")
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.context(queryContext)
.build()
),
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
@ -13254,7 +13238,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
)
)
.setGranularity(Granularities.ALL)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setInterval(querySegmentSpec(Intervals.of("2001-01-02T00:00:00.000Z/146140482-04-24T15:36:27.903Z")))
.setDimFilter(selector("dim1", "def", null))
.setDimensions(
dimensions(
@ -16037,24 +16021,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
newScanQueryBuilder()
.dataSource(
join(
new QueryDataSource(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(
querySegmentSpec(
Intervals.utc(
DateTimes.of("1999-01-01").getMillis(),
JodaUtils.MAX_INSTANT
)
)
)
.filters(new SelectorDimFilter("dim1", "10.1", null))
.virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING))
.columns(ImmutableList.of("__time", "v0"))
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.context(queryContext)
.build()
),
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
@ -16074,14 +16041,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.build()
),
"j0.",
equalsCondition(DruidExpression.fromColumn("v0"), DruidExpression.fromColumn("j0.v0")),
JoinType.LEFT
equalsCondition(DruidExpression.fromExpression("'10.1'"), DruidExpression.fromColumn("j0.v0")),
JoinType.LEFT,
selector("dim1", "10.1", null)
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.virtualColumns(expressionVirtualColumn("_v0", "\'10.1\'", ValueType.STRING))
.columns("__time", "_v0")
.filters(new SelectorDimFilter("v0", "10.1", null))
.intervals(querySegmentSpec(
Intervals.utc(
DateTimes.of("1999-01-01").getMillis(),
JodaUtils.MAX_INSTANT
)
))
.virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING))
.columns("__time", "v0")
.context(queryContext)
.build()
),
@ -16106,17 +16078,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
newScanQueryBuilder()
.dataSource(
join(
new QueryDataSource(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.filters(new SelectorDimFilter("dim1", "10.1", null))
.virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING))
.columns(ImmutableList.of("__time", "v0"))
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.context(queryContext)
.build()
),
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
@ -16128,14 +16090,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.build()
),
"j0.",
equalsCondition(DruidExpression.fromColumn("v0"), DruidExpression.fromColumn("j0.dim1")),
JoinType.LEFT
equalsCondition(DruidExpression.fromExpression("'10.1'"), DruidExpression.fromColumn("j0.dim1")),
JoinType.LEFT,
selector("dim1", "10.1", null)
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.virtualColumns(expressionVirtualColumn("_v0", "\'10.1\'", ValueType.STRING))
.columns("__time", "_v0")
.filters(new SelectorDimFilter("v0", "10.1", null))
.virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING))
.columns("__time", "v0")
.context(queryContext)
.build()
),
@ -16160,17 +16122,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
newScanQueryBuilder()
.dataSource(
join(
new QueryDataSource(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.filters(new SelectorDimFilter("dim1", "10.1", null))
.virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING))
.columns(ImmutableList.of("__time", "v0"))
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.context(queryContext)
.build()
),
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
@ -16182,13 +16134,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.build()
),
"j0.",
equalsCondition(DruidExpression.fromColumn("v0"), DruidExpression.fromColumn("j0.dim1")),
JoinType.LEFT
equalsCondition(DruidExpression.fromExpression("'10.1'"), DruidExpression.fromColumn("j0.dim1")),
JoinType.LEFT,
selector("dim1", "10.1", null)
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.virtualColumns(expressionVirtualColumn("_v0", "\'10.1\'", ValueType.STRING))
.columns("__time", "_v0")
.virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING))
.columns("__time", "v0")
.context(queryContext)
.build()
),
@ -16213,17 +16166,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
newScanQueryBuilder()
.dataSource(
join(
new QueryDataSource(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.filters(new SelectorDimFilter("dim1", "10.1", null))
.virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING))
.columns(ImmutableList.of("__time", "v0"))
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.context(queryContext)
.build()
),
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
@ -16235,14 +16178,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.build()
),
"j0.",
equalsCondition(DruidExpression.fromColumn("v0"), DruidExpression.fromColumn("j0.dim1")),
JoinType.INNER
equalsCondition(DruidExpression.fromExpression("'10.1'"), DruidExpression.fromColumn("j0.dim1")),
JoinType.INNER,
selector("dim1", "10.1", null)
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.virtualColumns(expressionVirtualColumn("_v0", "\'10.1\'", ValueType.STRING))
.columns("__time", "_v0")
.filters(new NotDimFilter(new SelectorDimFilter("v0", null, null)))
.virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING))
.columns("__time", "v0")
.context(queryContext)
.build()
),
@ -16267,17 +16210,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
newScanQueryBuilder()
.dataSource(
join(
new QueryDataSource(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.filters(new SelectorDimFilter("dim1", "10.1", null))
.virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING))
.columns(ImmutableList.of("__time", "v0"))
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.context(queryContext)
.build()
),
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
@ -16289,13 +16222,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.build()
),
"j0.",
equalsCondition(DruidExpression.fromColumn("v0"), DruidExpression.fromColumn("j0.dim1")),
JoinType.INNER
equalsCondition(DruidExpression.fromExpression("'10.1'"), DruidExpression.fromColumn("j0.dim1")),
JoinType.INNER,
selector("dim1", "10.1", null)
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.virtualColumns(expressionVirtualColumn("_v0", "\'10.1\'", ValueType.STRING))
.columns("__time", "_v0")
.virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING))
.columns("__time", "v0")
.context(queryContext)
.build()
),
@ -16710,38 +16644,4 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
);
}
/**
* This is a provider of query contexts that should be used by join tests.
* It tests various configs that can be passed to join queries. All the configs provided by this provider should
* have the join query engine return the same results.
*/
public static class QueryContextForJoinProvider
{
@UsedByJUnitParamsRunner
public static Object[] provideQueryContexts()
{
return new Object[] {
// default behavior
QUERY_CONTEXT_DEFAULT,
// filter value re-writes enabled
new ImmutableMap.Builder<String, Object>()
.putAll(QUERY_CONTEXT_DEFAULT)
.put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, true)
.put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, true)
.build(),
// rewrite values enabled but filter re-writes disabled.
// This should drive the same behavior as the previous config
new ImmutableMap.Builder<String, Object>()
.putAll(QUERY_CONTEXT_DEFAULT)
.put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, true)
.put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, false)
.build(),
// filter re-writes disabled
new ImmutableMap.Builder<String, Object>()
.putAll(QUERY_CONTEXT_DEFAULT)
.put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, false)
.build(),
};
}
}
}

View File

@ -0,0 +1,187 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.rel;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.JoinDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.filter.AndDimFilter;
import org.apache.druid.query.filter.BoundDimFilter;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.filter.SelectorDimFilter;
import org.apache.druid.query.ordering.StringComparators;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.join.JoinType;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.joda.time.Interval;
import org.junit.Assert;
import org.junit.Test;
import java.util.Collections;
public class DruidQueryTest
{
static {
NullHandling.initializeForTests();
}
private final DimFilter selectorFilter = new SelectorDimFilter("column", "value", null);
private final DimFilter otherFilter = new SelectorDimFilter("column_2", "value_2", null);
private final DimFilter filterWithInterval = new AndDimFilter(
selectorFilter,
new BoundDimFilter("__time", "100", "200", false, true, null, null, StringComparators.NUMERIC)
);
@Test
public void test_filtration_noJoinAndInterval()
{
DataSource dataSource = new TableDataSource("test");
Pair<DataSource, Filtration> pair = DruidQuery.getFiltration(
dataSource,
selectorFilter,
VirtualColumnRegistry.create(RowSignature.empty())
);
verify(pair, dataSource, selectorFilter, Intervals.ETERNITY);
}
@Test
public void test_filtration_intervalInQueryFilter()
{
DataSource dataSource = new TableDataSource("test");
Pair<DataSource, Filtration> pair = DruidQuery.getFiltration(
dataSource,
filterWithInterval,
VirtualColumnRegistry.create(RowSignature.empty())
);
verify(pair, dataSource, selectorFilter, Intervals.utc(100, 200));
}
@Test
public void test_filtration_joinDataSource_intervalInQueryFilter()
{
DataSource dataSource = join(JoinType.INNER, otherFilter);
Pair<DataSource, Filtration> pair = DruidQuery.getFiltration(
dataSource,
filterWithInterval,
VirtualColumnRegistry.create(RowSignature.empty())
);
verify(pair, dataSource, selectorFilter, Intervals.utc(100, 200));
}
@Test
public void test_filtration_joinDataSource_intervalInBaseTableFilter_inner()
{
DataSource dataSource = join(JoinType.INNER, filterWithInterval);
DataSource expectedDataSource = join(JoinType.INNER, selectorFilter);
Pair<DataSource, Filtration> pair = DruidQuery.getFiltration(
dataSource,
otherFilter,
VirtualColumnRegistry.create(RowSignature.empty())
);
verify(pair, expectedDataSource, otherFilter, Intervals.utc(100, 200));
}
@Test
public void test_filtration_joinDataSource_intervalInBaseTableFilter_left()
{
DataSource dataSource = join(JoinType.LEFT, filterWithInterval);
DataSource expectedDataSource = join(JoinType.LEFT, selectorFilter);
Pair<DataSource, Filtration> pair = DruidQuery.getFiltration(
dataSource,
otherFilter,
VirtualColumnRegistry.create(RowSignature.empty())
);
verify(pair, expectedDataSource, otherFilter, Intervals.utc(100, 200));
}
@Test
public void test_filtration_joinDataSource_intervalInBaseTableFilter_right()
{
DataSource dataSource = join(JoinType.RIGHT, filterWithInterval);
DataSource expectedDataSource = join(JoinType.RIGHT, selectorFilter);
Pair<DataSource, Filtration> pair = DruidQuery.getFiltration(
dataSource,
otherFilter,
VirtualColumnRegistry.create(RowSignature.empty())
);
verify(pair, expectedDataSource, otherFilter, Intervals.utc(100, 200));
}
@Test
public void test_filtration_joinDataSource_intervalInBaseTableFilter_full()
{
DataSource dataSource = join(JoinType.FULL, filterWithInterval);
DataSource expectedDataSource = join(JoinType.FULL, selectorFilter);
Pair<DataSource, Filtration> pair = DruidQuery.getFiltration(
dataSource,
otherFilter,
VirtualColumnRegistry.create(RowSignature.empty())
);
verify(pair, expectedDataSource, otherFilter, Intervals.utc(100, 200));
}
@Test
public void test_filtration_intervalsInBothFilters()
{
DataSource dataSource = join(JoinType.INNER, filterWithInterval);
DataSource expectedDataSource = join(JoinType.INNER, selectorFilter);
DimFilter queryFilter = new AndDimFilter(
otherFilter,
new BoundDimFilter("__time", "150", "250", false, true, null, null, StringComparators.NUMERIC)
);
Pair<DataSource, Filtration> pair = DruidQuery.getFiltration(
dataSource,
queryFilter,
VirtualColumnRegistry.create(RowSignature.empty())
);
verify(pair, expectedDataSource, otherFilter, Intervals.utc(150, 200));
}
private JoinDataSource join(JoinType joinType, DimFilter filter)
{
return JoinDataSource.create(
new TableDataSource("left"),
new TableDataSource("right"),
"r.",
"c == \"r.c\"",
joinType,
filter,
ExprMacroTable.nil()
);
}
private void verify(
Pair<DataSource, Filtration> pair,
DataSource dataSource,
DimFilter columnFilter,
Interval interval
)
{
Assert.assertEquals(dataSource, pair.lhs);
Assert.assertEquals("dim-filter: " + pair.rhs.getDimFilter(), columnFilter, pair.rhs.getDimFilter());
Assert.assertEquals(Collections.singletonList(interval), pair.rhs.getIntervals());
}
}