From 1a15987432fab044fc83a99af8ab497d887aeda1 Mon Sep 17 00:00:00 2001 From: Abhishek Agarwal <1477457+abhishekagarwal87@users.noreply.github.com> Date: Fri, 5 Mar 2021 00:09:21 +0530 Subject: [PATCH] Supporting filters in the left base table for join datasources (#10697) * where filter left first draft * Revert changes in calcite test * Refactor a bit * Fixing the Tests * Changes * Adding tests * Add tests for correlated queries * Add comment * Fix typos --- .../IndexedTableJoinCursorBenchmark.java | 1 + .../benchmark/JoinAndLookupBenchmark.java | 4 + .../apache/druid/query/JoinDataSource.java | 58 ++- .../java/org/apache/druid/query/Queries.java | 6 +- .../query/planning/DataSourceAnalysis.java | 31 +- .../druid/segment/join/HashJoinSegment.java | 11 +- .../join/HashJoinSegmentStorageAdapter.java | 28 +- .../apache/druid/segment/join/JoinType.java | 20 +- .../segment/join/JoinableFactoryWrapper.java | 9 +- .../join/filter/JoinFilterAnalyzer.java | 26 +- .../druid/query/JoinDataSourceTest.java | 47 +- .../org/apache/druid/query/QueriesTest.java | 77 +++ .../planning/DataSourceAnalysisTest.java | 128 ++++- .../HashJoinSegmentStorageAdapterTest.java | 203 ++++++++ .../segment/join/HashJoinSegmentTest.java | 2 + .../segment/join/JoinFilterAnalyzerTest.java | 27 + .../join/JoinableFactoryWrapperTest.java | 47 +- .../appenderator/SinkQuerySegmentWalker.java | 2 + .../druid/server/LocalQuerySegmentWalker.java | 2 + .../server/coordination/ServerManager.java | 2 + .../server/TestClusterQuerySegmentWalker.java | 2 + .../sql/calcite/rel/DruidJoinQueryRel.java | 43 +- .../druid/sql/calcite/rel/DruidQuery.java | 90 +++- .../druid/sql/calcite/rule/DruidJoinRule.java | 8 +- .../sql/calcite/BaseCalciteQueryTest.java | 51 +- .../calcite/CalciteCorrelatedQueryTest.java | 481 ++++++++++++++++++ .../druid/sql/calcite/CalciteQueryTest.java | 182 ++----- .../druid/sql/calcite/rel/DruidQueryTest.java | 187 +++++++ 28 files changed, 1577 insertions(+), 198 deletions(-) create mode 100644 sql/src/test/java/org/apache/druid/sql/calcite/CalciteCorrelatedQueryTest.java create mode 100644 sql/src/test/java/org/apache/druid/sql/calcite/rel/DruidQueryTest.java diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/IndexedTableJoinCursorBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/IndexedTableJoinCursorBenchmark.java index 839f7608a36..1ed1c37bc3a 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/IndexedTableJoinCursorBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/IndexedTableJoinCursorBenchmark.java @@ -197,6 +197,7 @@ public class IndexedTableJoinCursorBenchmark hashJoinSegment = closer.register( new HashJoinSegment( ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment), + null, clauses, preAnalysis ) diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/JoinAndLookupBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/JoinAndLookupBenchmark.java index 72b92ff711f..8d68e6b79c5 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/JoinAndLookupBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/JoinAndLookupBenchmark.java @@ -160,6 +160,7 @@ public class JoinAndLookupBenchmark hashJoinLookupStringKeySegment = new HashJoinSegment( ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment), + null, joinableClausesLookupStringKey, preAnalysisLookupStringKey ); @@ -194,6 +195,7 @@ public class JoinAndLookupBenchmark hashJoinLookupLongKeySegment = new HashJoinSegment( 
ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment), + null, joinableClausesLookupLongKey, preAnalysisLookupLongKey ); @@ -228,6 +230,7 @@ public class JoinAndLookupBenchmark hashJoinIndexedTableStringKeySegment = new HashJoinSegment( ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment), + null, joinableClausesIndexedTableStringKey, preAnalysisIndexedStringKey ); @@ -262,6 +265,7 @@ public class JoinAndLookupBenchmark hashJoinIndexedTableLongKeySegment = new HashJoinSegment( ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment), + null, joinableClausesIndexedTableLongKey, preAnalysisIndexedLongKey ); diff --git a/processing/src/main/java/org/apache/druid/query/JoinDataSource.java b/processing/src/main/java/org/apache/druid/query/JoinDataSource.java index b5856e26fcf..93438a5ec9b 100644 --- a/processing/src/main/java/org/apache/druid/query/JoinDataSource.java +++ b/processing/src/main/java/org/apache/druid/query/JoinDataSource.java @@ -27,10 +27,12 @@ import com.google.common.collect.ImmutableList; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.query.filter.DimFilter; import org.apache.druid.segment.join.JoinConditionAnalysis; import org.apache.druid.segment.join.JoinPrefixUtils; import org.apache.druid.segment.join.JoinType; +import javax.annotation.Nullable; import java.util.HashSet; import java.util.List; import java.util.Objects; @@ -58,13 +60,17 @@ public class JoinDataSource implements DataSource private final String rightPrefix; private final JoinConditionAnalysis conditionAnalysis; private final JoinType joinType; + // An optional filter on the left side if left is direct table access + @Nullable + private final DimFilter leftFilter; private JoinDataSource( DataSource left, DataSource right, String rightPrefix, JoinConditionAnalysis conditionAnalysis, - JoinType joinType + JoinType joinType, + @Nullable DimFilter leftFilter ) { this.left = Preconditions.checkNotNull(left, "left"); @@ -72,6 +78,12 @@ public class JoinDataSource implements DataSource this.rightPrefix = JoinPrefixUtils.validatePrefix(rightPrefix); this.conditionAnalysis = Preconditions.checkNotNull(conditionAnalysis, "conditionAnalysis"); this.joinType = Preconditions.checkNotNull(joinType, "joinType"); + //TODO: Add support for union data sources + Preconditions.checkArgument( + leftFilter == null || left instanceof TableDataSource, + "left filter is only supported if left data source is direct table access" + ); + this.leftFilter = leftFilter; } /** @@ -84,6 +96,7 @@ public class JoinDataSource implements DataSource @JsonProperty("rightPrefix") String rightPrefix, @JsonProperty("condition") String condition, @JsonProperty("joinType") JoinType joinType, + @Nullable @JsonProperty("leftFilter") DimFilter leftFilter, @JacksonInject ExprMacroTable macroTable ) { @@ -96,7 +109,8 @@ public class JoinDataSource implements DataSource StringUtils.nullToEmptyNonDruidDataString(rightPrefix), macroTable ), - joinType + joinType, + leftFilter ); } @@ -108,10 +122,26 @@ public class JoinDataSource implements DataSource final DataSource right, final String rightPrefix, final JoinConditionAnalysis conditionAnalysis, - final JoinType joinType + final JoinType joinType, + final DimFilter leftFilter ) { - return new JoinDataSource(left, right, rightPrefix, conditionAnalysis, joinType); + return new JoinDataSource(left, right, rightPrefix, conditionAnalysis, joinType, 
leftFilter); + } + + /** + * Create a join dataSource from an existing {@link JoinConditionAnalysis}. + */ + public static JoinDataSource create( + final DataSource left, + final DataSource right, + final String rightPrefix, + final String condition, + final JoinType joinType, + final ExprMacroTable macroTable + ) + { + return create(left, right, rightPrefix, condition, joinType, null, macroTable); } @Override @@ -158,6 +188,13 @@ public class JoinDataSource implements DataSource return joinType; } + @JsonProperty + @Nullable + public DimFilter getLeftFilter() + { + return leftFilter; + } + @Override public List getChildren() { @@ -171,7 +208,14 @@ public class JoinDataSource implements DataSource throw new IAE("Expected [2] children, got [%d]", children.size()); } - return new JoinDataSource(children.get(0), children.get(1), rightPrefix, conditionAnalysis, joinType); + return new JoinDataSource( + children.get(0), + children.get(1), + rightPrefix, + conditionAnalysis, + joinType, + leftFilter + ); } @Override @@ -206,13 +250,14 @@ public class JoinDataSource implements DataSource Objects.equals(right, that.right) && Objects.equals(rightPrefix, that.rightPrefix) && Objects.equals(conditionAnalysis, that.conditionAnalysis) && + Objects.equals(leftFilter, that.leftFilter) && joinType == that.joinType; } @Override public int hashCode() { - return Objects.hash(left, right, rightPrefix, conditionAnalysis, joinType); + return Objects.hash(left, right, rightPrefix, conditionAnalysis, joinType, leftFilter); } @Override @@ -224,6 +269,7 @@ public class JoinDataSource implements DataSource ", rightPrefix='" + rightPrefix + '\'' + ", condition=" + conditionAnalysis + ", joinType=" + joinType + + ", leftFilter=" + leftFilter + '}'; } } diff --git a/processing/src/main/java/org/apache/druid/query/Queries.java b/processing/src/main/java/org/apache/druid/query/Queries.java index 7c6c4bfe9fc..e25a88ea38b 100644 --- a/processing/src/main/java/org/apache/druid/query/Queries.java +++ b/processing/src/main/java/org/apache/druid/query/Queries.java @@ -26,6 +26,7 @@ import org.apache.druid.guice.annotations.PublicApi; import org.apache.druid.java.util.common.ISE; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.PostAggregator; +import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.planning.DataSourceAnalysis; import org.apache.druid.query.planning.PreJoinableClause; import org.apache.druid.query.spec.MultipleSpecificSegmentSpec; @@ -192,6 +193,7 @@ public class Queries final DataSourceAnalysis analysis = DataSourceAnalysis.forDataSource(query.getDataSource()); DataSource current = newBaseDataSource; + DimFilter joinBaseFilter = analysis.getJoinBaseTableFilter().orElse(null); for (final PreJoinableClause clause : analysis.getPreJoinableClauses()) { current = JoinDataSource.create( @@ -199,8 +201,10 @@ public class Queries clause.getDataSource(), clause.getPrefix(), clause.getCondition(), - clause.getJoinType() + clause.getJoinType(), + joinBaseFilter ); + joinBaseFilter = null; } retVal = query.withDataSource(current); diff --git a/processing/src/main/java/org/apache/druid/query/planning/DataSourceAnalysis.java b/processing/src/main/java/org/apache/druid/query/planning/DataSourceAnalysis.java index bfa7644e357..061554f3d0c 100644 --- a/processing/src/main/java/org/apache/druid/query/planning/DataSourceAnalysis.java +++ b/processing/src/main/java/org/apache/druid/query/planning/DataSourceAnalysis.java @@ -28,6 +28,7 @@ import 
org.apache.druid.query.Query; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.UnionDataSource; +import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.spec.QuerySegmentSpec; import javax.annotation.Nullable; @@ -80,12 +81,15 @@ public class DataSourceAnalysis private final DataSource baseDataSource; @Nullable private final Query baseQuery; + @Nullable + private final DimFilter joinBaseTableFilter; private final List preJoinableClauses; private DataSourceAnalysis( DataSource dataSource, DataSource baseDataSource, @Nullable Query baseQuery, + @Nullable DimFilter joinBaseTableFilter, List preJoinableClauses ) { @@ -98,6 +102,7 @@ public class DataSourceAnalysis this.dataSource = dataSource; this.baseDataSource = baseDataSource; this.baseQuery = baseQuery; + this.joinBaseTableFilter = joinBaseTableFilter; this.preJoinableClauses = preJoinableClauses; } @@ -121,10 +126,10 @@ public class DataSourceAnalysis } if (current instanceof JoinDataSource) { - final Pair> flattened = flattenJoin((JoinDataSource) current); - return new DataSourceAnalysis(dataSource, flattened.lhs, baseQuery, flattened.rhs); + final Pair, List> flattened = flattenJoin((JoinDataSource) current); + return new DataSourceAnalysis(dataSource, flattened.lhs.lhs, baseQuery, flattened.lhs.rhs, flattened.rhs); } else { - return new DataSourceAnalysis(dataSource, current, baseQuery, Collections.emptyList()); + return new DataSourceAnalysis(dataSource, current, baseQuery, null, Collections.emptyList()); } } @@ -134,14 +139,19 @@ public class DataSourceAnalysis * * @throws IllegalArgumentException if dataSource cannot be fully flattened. */ - private static Pair> flattenJoin(final JoinDataSource dataSource) + private static Pair, List> flattenJoin(final JoinDataSource dataSource) { DataSource current = dataSource; + DimFilter currentDimFilter = null; final List preJoinableClauses = new ArrayList<>(); while (current instanceof JoinDataSource) { final JoinDataSource joinDataSource = (JoinDataSource) current; current = joinDataSource.getLeft(); + if (currentDimFilter != null) { + throw new IAE("Left filters are only allowed when left child is direct table access"); + } + currentDimFilter = joinDataSource.getLeftFilter(); preJoinableClauses.add( new PreJoinableClause( joinDataSource.getRightPrefix(), @@ -156,7 +166,7 @@ public class DataSourceAnalysis // going-up order. So reverse them. Collections.reverse(preJoinableClauses); - return Pair.of(current, preJoinableClauses); + return Pair.of(Pair.of(current, currentDimFilter), preJoinableClauses); } /** @@ -214,11 +224,20 @@ public class DataSourceAnalysis return Optional.ofNullable(baseQuery); } + /** + * If the original data source is a join data source and there is a DimFilter on the base table data source, + * that DimFilter is returned here + */ + public Optional getJoinBaseTableFilter() + { + return Optional.ofNullable(joinBaseTableFilter); + } + /** * Returns the {@link QuerySegmentSpec} that is associated with the base datasource, if any. This only happens * when there is an outer query datasource. In this case, the base querySegmentSpec is the one associated with the * innermost subquery. - * + *
+ * <p>
* This {@link QuerySegmentSpec} is taken from the query returned by {@link #getBaseQuery()}. * * @return the query segment spec associated with the base datasource if {@link #isQuery()} is true, else empty diff --git a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java index bcfd109970d..2002ee10099 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java +++ b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java @@ -22,6 +22,7 @@ package org.apache.druid.segment.join; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.guava.CloseQuietly; import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.query.filter.Filter; import org.apache.druid.segment.QueryableIndex; import org.apache.druid.segment.SegmentReference; import org.apache.druid.segment.StorageAdapter; @@ -43,6 +44,7 @@ import java.util.Optional; public class HashJoinSegment implements SegmentReference { private final SegmentReference baseSegment; + private final Filter baseFilter; private final List clauses; private final JoinFilterPreAnalysis joinFilterPreAnalysis; @@ -54,11 +56,13 @@ public class HashJoinSegment implements SegmentReference */ public HashJoinSegment( SegmentReference baseSegment, + @Nullable Filter baseFilter, List clauses, JoinFilterPreAnalysis joinFilterPreAnalysis ) { this.baseSegment = baseSegment; + this.baseFilter = baseFilter; this.clauses = clauses; this.joinFilterPreAnalysis = joinFilterPreAnalysis; @@ -93,7 +97,12 @@ public class HashJoinSegment implements SegmentReference @Override public StorageAdapter asStorageAdapter() { - return new HashJoinSegmentStorageAdapter(baseSegment.asStorageAdapter(), clauses, joinFilterPreAnalysis); + return new HashJoinSegmentStorageAdapter( + baseSegment.asStorageAdapter(), + baseFilter, + clauses, + joinFilterPreAnalysis + ); } @Override diff --git a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java index d6517c1bfbc..7d3cde3236d 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java +++ b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java @@ -56,6 +56,7 @@ import java.util.Set; public class HashJoinSegmentStorageAdapter implements StorageAdapter { private final StorageAdapter baseAdapter; + private final Filter baseFilter; private final List clauses; private final JoinFilterPreAnalysis joinFilterPreAnalysis; @@ -69,8 +70,25 @@ public class HashJoinSegmentStorageAdapter implements StorageAdapter final List clauses, final JoinFilterPreAnalysis joinFilterPreAnalysis ) + { + this(baseAdapter, null, clauses, joinFilterPreAnalysis); + } + + /** + * @param baseAdapter A StorageAdapter for the left-hand side base segment + * @param baseFilter A filter for the left-hand side base segment + * @param clauses The right-hand side clauses. 
The caller is responsible for ensuring that there are no + * duplicate prefixes or prefixes that shadow each other across the clauses + * @param joinFilterPreAnalysis Pre-analysis for the query we expect to run on this storage adapter + */ + HashJoinSegmentStorageAdapter( + final StorageAdapter baseAdapter, + final Filter baseFilter, + final List<JoinableClause> clauses, + final JoinFilterPreAnalysis joinFilterPreAnalysis + ) { this.baseAdapter = baseAdapter; + this.baseFilter = baseFilter; this.clauses = clauses; this.joinFilterPreAnalysis = joinFilterPreAnalysis; } @@ -239,14 +257,12 @@ public class HashJoinSegmentStorageAdapter implements StorageAdapter postJoinVirtualColumns ); - JoinFilterSplit joinFilterSplit = JoinFilterAnalyzer.splitFilter(joinFilterPreAnalysis); + // We merge the filter on the base table specified by the user with the filter on the base table that is + // pushed down from the join + JoinFilterSplit joinFilterSplit = JoinFilterAnalyzer.splitFilter(joinFilterPreAnalysis, baseFilter); preJoinVirtualColumns.addAll(joinFilterSplit.getPushDownVirtualColumns()); - // Soon, we will need a way to push filters past a join when possible. This could potentially be done right here - // (by splitting out pushable pieces of 'filter') or it could be done at a higher level (i.e. in the SQL planner). - // - // If it's done in the SQL planner, that will likely mean adding a 'baseFilter' parameter to this class that would - // be passed in to the below baseAdapter.makeCursors call (instead of the null filter). + final Sequence<Cursor> baseCursorSequence = baseAdapter.makeCursors( joinFilterSplit.getBaseTableFilter().isPresent() ? joinFilterSplit.getBaseTableFilter().get() : null, interval, diff --git a/processing/src/main/java/org/apache/druid/segment/join/JoinType.java b/processing/src/main/java/org/apache/druid/segment/join/JoinType.java index e628d7d0d47..e343e7e64fd 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/JoinType.java +++ b/processing/src/main/java/org/apache/druid/segment/join/JoinType.java @@ -23,13 +23,13 @@ public enum JoinType { INNER { @Override - boolean isLefty() + public boolean isLefty() { return false; } @Override - boolean isRighty() + public boolean isRighty() { return false; } @@ -37,13 +37,13 @@ LEFT { @Override - boolean isLefty() + public boolean isLefty() { return true; } @Override - boolean isRighty() + public boolean isRighty() { return false; } @@ -51,13 +51,13 @@ RIGHT { @Override - boolean isLefty() + public boolean isLefty() { return false; } @Override - boolean isRighty() + public boolean isRighty() { return true; } @@ -65,13 +65,13 @@ FULL { @Override - boolean isLefty() + public boolean isLefty() { return true; } @Override - boolean isRighty() + public boolean isRighty() { return true; } @@ -80,10 +80,10 @@ /** * "Lefty" joins (LEFT or FULL) always include the full left-hand side, and can generate nulls on the right. */ - abstract boolean isLefty(); + public abstract boolean isLefty(); /** * "Righty" joins (RIGHT or FULL) always include the full right-hand side, and can generate nulls on the left.
*/ - abstract boolean isRighty(); + public abstract boolean isRighty(); } diff --git a/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java b/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java index df7e7f3ba64..b076b1ad825 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java +++ b/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java @@ -24,6 +24,7 @@ import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.query.Query; import org.apache.druid.query.cache.CacheKeyBuilder; +import org.apache.druid.query.filter.Filter; import org.apache.druid.query.planning.DataSourceAnalysis; import org.apache.druid.query.planning.PreJoinableClause; import org.apache.druid.segment.SegmentReference; @@ -69,6 +70,7 @@ public class JoinableFactoryWrapper * query from the end user. */ public Function<SegmentReference, SegmentReference> createSegmentMapFn( + final Filter baseFilter, final List<PreJoinableClause> clauses, final AtomicLong cpuTimeAccumulator, final Query<?> query ) @@ -94,6 +96,7 @@ return baseSegment -> new HashJoinSegment( baseSegment, + baseFilter, joinableClauses.getJoinableClauses(), joinFilterPreAnalysis ); @@ -103,7 +106,8 @@ } /** - * Compute a cache key prefix for data sources that participate in the RHS of a join. This key prefix + * Compute a cache key prefix for a join data source. This includes the data sources that participate in the RHS of a + * join as well as any query-specific constructs associated with the join data source, such as the base table filter. This key prefix * can be used in segment level cache or result level cache. The function can return the following wrapped in an * Optional * - Non-empty byte array - If there is a join datasource involved and caching is possible.
The result includes @@ -126,6 +130,9 @@ final CacheKeyBuilder keyBuilder; keyBuilder = new CacheKeyBuilder(JOIN_OPERATION); + if (dataSourceAnalysis.getJoinBaseTableFilter().isPresent()) { + keyBuilder.appendCacheable(dataSourceAnalysis.getJoinBaseTableFilter().get()); + } for (PreJoinableClause clause : clauses) { Optional<byte[]> bytes = joinableFactory.computeJoinCacheKey(clause.getDataSource(), clause.getCondition()); if (!bytes.isPresent()) { diff --git a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java index cc9f244dc04..1fb9cc61d17 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java +++ b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java @@ -33,6 +33,7 @@ import org.apache.druid.segment.filter.OrFilter; import org.apache.druid.segment.filter.SelectorFilter; import org.apache.druid.segment.virtual.ExpressionVirtualColumn; +import javax.annotation.Nullable; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; @@ -137,18 +138,27 @@ public class JoinFilterAnalyzer return preAnalysisBuilder.withCorrelations(correlations).build(); } - /** - * @param joinFilterPreAnalysis The pre-analysis computed by {@link #computeJoinFilterPreAnalysis)} - * - * @return A JoinFilterSplit indicating what parts of the filter should be applied pre-join and post-join - */ public static JoinFilterSplit splitFilter( JoinFilterPreAnalysis joinFilterPreAnalysis ) + { + return splitFilter(joinFilterPreAnalysis, null); + } + + /** + * @param joinFilterPreAnalysis The pre-analysis computed by {@link #computeJoinFilterPreAnalysis} + * @param baseFilter Filter on the base table that was specified in the query itself + * + * @return A JoinFilterSplit indicating what parts of the filter should be applied pre-join and post-join + */ + public static JoinFilterSplit splitFilter( + JoinFilterPreAnalysis joinFilterPreAnalysis, + @Nullable Filter baseFilter + ) { if (joinFilterPreAnalysis.getOriginalFilter() == null || !joinFilterPreAnalysis.isEnableFilterPushDown()) { return new JoinFilterSplit( - null, + baseFilter, joinFilterPreAnalysis.getOriginalFilter(), ImmutableSet.of() ); @@ -159,6 +169,10 @@ List<Filter> leftFilters = new ArrayList<>(); List<Filter> rightFilters = new ArrayList<>(); Map<Expr, VirtualColumn> pushDownVirtualColumnsForLhsExprs = new HashMap<>(); + if (null != baseFilter) { + leftFilters.add(baseFilter); + } + for (Filter baseTableFilter : joinFilterPreAnalysis.getNormalizedBaseTableClauses()) { if (!Filters.filterMatchesNull(baseTableFilter)) { leftFilters.add(baseTableFilter); diff --git a/processing/src/test/java/org/apache/druid/query/JoinDataSourceTest.java b/processing/src/test/java/org/apache/druid/query/JoinDataSourceTest.java index bc79731e47b..7a76a097657 100644 --- a/processing/src/test/java/org/apache/druid/query/JoinDataSourceTest.java +++ b/processing/src/test/java/org/apache/druid/query/JoinDataSourceTest.java @@ -24,12 +24,14 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import nl.jqno.equalsverifier.EqualsVerifier; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.query.filter.TrueDimFilter; import org.apache.druid.segment.TestHelper; import org.apache.druid.segment.join.JoinType; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException;
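A note on the splitFilter change above: the caller-supplied base-table filter is simply appended to the pre-join (left) filter list before the split is computed, so it ends up ANDed with whatever pieces of the original query filter can be pushed down to the base table. A minimal sketch of the resulting contract, reusing the fixture names from the analyzer tests later in this patch (makeDefaultConfigPreAnalysis and joinableClauses are test helpers, not production API):

    // Assumed fixtures from the tests in this patch: join clauses over the fact table
    // and a pre-analysis built from the original query filter.
    Filter originalFilter = new SelectorFilter("channel", "#en.wikipedia");
    Filter baseTableFilter = new SelectorFilter("countryIsoCode", "CA");
    JoinFilterPreAnalysis preAnalysis =
        makeDefaultConfigPreAnalysis(originalFilter, joinableClauses, VirtualColumns.EMPTY);

    JoinFilterSplit split = JoinFilterAnalyzer.splitFilter(preAnalysis, baseTableFilter);

    // Matches the expectation in the analyzer test below: split.getBaseTableFilter()
    // is AND(originalFilter, baseTableFilter), nothing remains to filter post-join,
    // and no push-down virtual columns are needed.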
+import org.mockito.Mockito; import java.util.Collections; @@ -152,11 +154,52 @@ public class JoinDataSourceTest public void test_serde() throws Exception { final ObjectMapper jsonMapper = TestHelper.makeJsonMapper(); + JoinDataSource joinDataSource = JoinDataSource.create( + new TableDataSource("table1"), + new TableDataSource("table2"), + "j.", + "x == \"j.x\"", + JoinType.LEFT, + TrueDimFilter.instance(), + ExprMacroTable.nil() + ); + final JoinDataSource deserialized = (JoinDataSource) jsonMapper.readValue( - jsonMapper.writeValueAsString(joinTableToLookup), + jsonMapper.writeValueAsString(joinDataSource), DataSource.class ); - Assert.assertEquals(joinTableToLookup, deserialized); + Assert.assertEquals(joinDataSource, deserialized); + } + + @Test + public void testException_leftFilterOnNonTableSource() + { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("left filter is only supported if left data source is direct table access"); + JoinDataSource ignored = JoinDataSource.create( + new QueryDataSource(Mockito.mock(Query.class)), + new TableDataSource("table"), + "j.", + "x == \"j.x\"", + JoinType.LEFT, + TrueDimFilter.instance(), + ExprMacroTable.nil() + ); + } + + @Test + public void testLeftFilter() + { + JoinDataSource dataSource = JoinDataSource.create( + new TableDataSource("table1"), + new TableDataSource("table2"), + "j.", + "x == \"j.x\"", + JoinType.LEFT, + TrueDimFilter.instance(), + ExprMacroTable.nil() + ); + Assert.assertEquals(TrueDimFilter.instance(), dataSource.getLeftFilter()); } } diff --git a/processing/src/test/java/org/apache/druid/query/QueriesTest.java b/processing/src/test/java/org/apache/druid/query/QueriesTest.java index 78dcafd7352..ea16df63920 100644 --- a/processing/src/test/java/org/apache/druid/query/QueriesTest.java +++ b/processing/src/test/java/org/apache/druid/query/QueriesTest.java @@ -30,6 +30,7 @@ import org.apache.druid.query.aggregation.PostAggregator; import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator; import org.apache.druid.query.aggregation.post.ConstantPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; +import org.apache.druid.query.filter.TrueDimFilter; import org.apache.druid.query.spec.MultipleSpecificSegmentSpec; import org.apache.druid.query.timeseries.TimeseriesQuery; import org.apache.druid.query.timeseries.TimeseriesResultValue; @@ -423,6 +424,7 @@ public class QueriesTest "j0.", "\"foo.x\" == \"bar.x\"", JoinType.INNER, + null, ExprMacroTable.nil() ) ) @@ -459,6 +461,7 @@ public class QueriesTest "j0.", "\"foo.x\" == \"bar.x\"", JoinType.INNER, + null, ExprMacroTable.nil() ) ) @@ -479,4 +482,78 @@ public class QueriesTest ) ); } + + @Test + public void testWithBaseDataSourcedBaseFilterWithMultiJoin() + { + Assert.assertEquals( + Druids.newTimeseriesQueryBuilder() + .dataSource( + new QueryDataSource( + Druids.newTimeseriesQueryBuilder() + .dataSource( + JoinDataSource.create( + JoinDataSource.create( + new TableDataSource("foo"), + new TableDataSource("bar"), + "j1.", + "\"foo.x\" == \"bar.x\"", + JoinType.INNER, + TrueDimFilter.instance(), + ExprMacroTable.nil() + ), + new TableDataSource("foo_outer"), + "j0.", + "\"foo_outer.x\" == \"bar.x\"", + JoinType.INNER, + null, + ExprMacroTable.nil() + ) + + ) + .intervals("2000/3000") + .granularity(Granularities.ALL) + .build() + ) + ) + .intervals("2000/3000") + .granularity(Granularities.ALL) + .build(), + Queries.withBaseDataSource( + Druids.newTimeseriesQueryBuilder() + 
.dataSource( + new QueryDataSource( + Druids.newTimeseriesQueryBuilder() + .dataSource( + JoinDataSource.create( + JoinDataSource.create( + new TableDataSource("foo_inner"), + new TableDataSource("bar"), + "j1.", + "\"foo.x\" == \"bar.x\"", + JoinType.INNER, + TrueDimFilter.instance(), + ExprMacroTable.nil() + ), + new TableDataSource("foo_outer"), + "j0.", + "\"foo_outer.x\" == \"bar.x\"", + JoinType.INNER, + null, + ExprMacroTable.nil() + ) + + ) + .intervals("2000/3000") + .granularity(Granularities.ALL) + .build() + ) + ) + .intervals("2000/3000") + .granularity(Granularities.ALL) + .build(), + new TableDataSource("foo") + ) + ); + } } diff --git a/processing/src/test/java/org/apache/druid/query/planning/DataSourceAnalysisTest.java b/processing/src/test/java/org/apache/druid/query/planning/DataSourceAnalysisTest.java index 82d66351470..93e75c3f190 100644 --- a/processing/src/test/java/org/apache/druid/query/planning/DataSourceAnalysisTest.java +++ b/processing/src/test/java/org/apache/druid/query/planning/DataSourceAnalysisTest.java @@ -32,6 +32,8 @@ import org.apache.druid.query.LookupDataSource; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.UnionDataSource; +import org.apache.druid.query.filter.DimFilter; +import org.apache.druid.query.filter.TrueDimFilter; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; import org.apache.druid.segment.column.RowSignature; @@ -237,6 +239,54 @@ public class DataSourceAnalysisTest Assert.assertEquals(joinDataSource, analysis.getDataSource()); Assert.assertEquals(TABLE_FOO, analysis.getBaseDataSource()); Assert.assertEquals(Optional.of(TABLE_FOO), analysis.getBaseTableDataSource()); + Assert.assertEquals(Optional.empty(), analysis.getJoinBaseTableFilter()); + Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource()); + Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource()); + Assert.assertEquals(Optional.empty(), analysis.getBaseQuery()); + Assert.assertEquals(Optional.empty(), analysis.getBaseQuerySegmentSpec()); + Assert.assertEquals( + ImmutableList.of( + new PreJoinableClause("1.", LOOKUP_LOOKYLOO, JoinType.INNER, joinClause("1.")), + new PreJoinableClause("2.", INLINE, JoinType.LEFT, joinClause("2.")), + new PreJoinableClause("3.", subquery(LOOKUP_LOOKYLOO), JoinType.FULL, joinClause("3.")) + ), + analysis.getPreJoinableClauses() + ); + Assert.assertTrue(analysis.isJoin()); + } + + @Test + public void testJoinSimpleLeftLeaningWithLeftFilter() + { + final JoinDataSource joinDataSource = + join( + join( + join( + TABLE_FOO, + LOOKUP_LOOKYLOO, + "1.", + JoinType.INNER, + TrueDimFilter.instance() + ), + INLINE, + "2.", + JoinType.LEFT + ), + subquery(LOOKUP_LOOKYLOO), + "3.", + JoinType.FULL + ); + + final DataSourceAnalysis analysis = DataSourceAnalysis.forDataSource(joinDataSource); + + Assert.assertTrue(analysis.isConcreteBased()); + Assert.assertTrue(analysis.isConcreteTableBased()); + Assert.assertFalse(analysis.isGlobal()); + Assert.assertFalse(analysis.isQuery()); + Assert.assertEquals(joinDataSource, analysis.getDataSource()); + Assert.assertEquals(TABLE_FOO, analysis.getBaseDataSource()); + Assert.assertEquals(Optional.of(TABLE_FOO), analysis.getBaseTableDataSource()); + Assert.assertEquals(TrueDimFilter.instance(), analysis.getJoinBaseTableFilter().orElse(null)); Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource()); 
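For reference alongside these analysis tests, a minimal sketch of the new API surface (the table names, prefix, and condition are illustrative; TrueDimFilter stands in for a real left filter):

    // A filter attached to the left side of a join, where the left side is direct table access...
    JoinDataSource join = JoinDataSource.create(
        new TableDataSource("foo"),        // left (base) table
        new TableDataSource("lookyloo"),   // right-hand side
        "j.",
        "x == \"j.x\"",
        JoinType.INNER,
        TrueDimFilter.instance(),          // leftFilter
        ExprMacroTable.nil()
    );

    // ...is lifted out of the join tree during analysis and surfaced separately, so
    // that segment walkers can apply it when they build the base-table cursor.
    DataSourceAnalysis analysis = DataSourceAnalysis.forDataSource(join);
    Optional<DimFilter> leftFilter = analysis.getJoinBaseTableFilter(); // contains TrueDimFilter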
Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource()); Assert.assertEquals(Optional.empty(), analysis.getBaseQuery()); @@ -290,6 +340,54 @@ public class DataSourceAnalysisTest Assert.assertEquals(joinDataSource, analysis.getDataSource()); Assert.assertEquals(TABLE_FOO, analysis.getBaseDataSource()); Assert.assertEquals(Optional.of(TABLE_FOO), analysis.getBaseTableDataSource()); + Assert.assertEquals(Optional.empty(), analysis.getJoinBaseTableFilter()); + Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource()); + Assert.assertEquals(Optional.empty(), analysis.getBaseQuery()); + Assert.assertEquals(Optional.empty(), analysis.getBaseQuerySegmentSpec()); + Assert.assertEquals( + ImmutableList.of( + new PreJoinableClause("3.", rightLeaningJoinStack, JoinType.RIGHT, joinClause("3.")) + ), + analysis.getPreJoinableClauses() + ); + Assert.assertTrue(analysis.isJoin()); + } + + @Test + public void testJoinSimpleRightLeaningWithLeftFilter() + { + final JoinDataSource rightLeaningJoinStack = + join( + LOOKUP_LOOKYLOO, + join( + INLINE, + subquery(LOOKUP_LOOKYLOO), + "1.", + JoinType.LEFT + ), + "2.", + JoinType.FULL + ); + + final JoinDataSource joinDataSource = + join( + TABLE_FOO, + rightLeaningJoinStack, + "3.", + JoinType.RIGHT, + TrueDimFilter.instance() + ); + + final DataSourceAnalysis analysis = DataSourceAnalysis.forDataSource(joinDataSource); + + Assert.assertTrue(analysis.isConcreteBased()); + Assert.assertTrue(analysis.isConcreteTableBased()); + Assert.assertFalse(analysis.isGlobal()); + Assert.assertFalse(analysis.isQuery()); + Assert.assertEquals(joinDataSource, analysis.getDataSource()); + Assert.assertEquals(TABLE_FOO, analysis.getBaseDataSource()); + Assert.assertEquals(Optional.of(TABLE_FOO), analysis.getBaseTableDataSource()); + Assert.assertEquals(TrueDimFilter.instance(), analysis.getJoinBaseTableFilter().orElse(null)); Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource()); Assert.assertEquals(Optional.empty(), analysis.getBaseQuery()); Assert.assertEquals(Optional.empty(), analysis.getBaseQuerySegmentSpec()); @@ -309,7 +407,8 @@ public class DataSourceAnalysisTest TABLE_FOO, subquery(TABLE_FOO), "1.", - JoinType.INNER + JoinType.INNER, + TrueDimFilter.instance() ); final DataSourceAnalysis analysis = DataSourceAnalysis.forDataSource(joinDataSource); @@ -320,6 +419,7 @@ public class DataSourceAnalysisTest Assert.assertFalse(analysis.isQuery()); Assert.assertEquals(joinDataSource, analysis.getDataSource()); Assert.assertEquals(TABLE_FOO, analysis.getBaseDataSource()); + Assert.assertEquals(TrueDimFilter.instance(), analysis.getJoinBaseTableFilter().orElse(null)); Assert.assertEquals(Optional.of(TABLE_FOO), analysis.getBaseTableDataSource()); Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource()); Assert.assertEquals( @@ -350,6 +450,7 @@ public class DataSourceAnalysisTest Assert.assertFalse(analysis.isQuery()); Assert.assertEquals(joinDataSource, analysis.getDataSource()); Assert.assertEquals(Optional.empty(), analysis.getBaseTableDataSource()); + Assert.assertEquals(Optional.empty(), analysis.getJoinBaseTableFilter()); Assert.assertEquals(Optional.of(unionDataSource), analysis.getBaseUnionDataSource()); Assert.assertEquals(unionDataSource, analysis.getBaseDataSource()); Assert.assertEquals(Optional.empty(), analysis.getBaseQuery()); @@ -373,7 +474,8 @@ public class DataSourceAnalysisTest TABLE_FOO, LOOKUP_LOOKYLOO, "1.", - JoinType.INNER + JoinType.INNER, + TrueDimFilter.instance() ) ) ); @@ -386,6 +488,7 @@ 
public class DataSourceAnalysisTest Assert.assertTrue(analysis.isQuery()); Assert.assertEquals(queryDataSource, analysis.getDataSource()); Assert.assertEquals(TABLE_FOO, analysis.getBaseDataSource()); + Assert.assertEquals(TrueDimFilter.instance(), analysis.getJoinBaseTableFilter().orElse(null)); Assert.assertEquals(Optional.of(TABLE_FOO), analysis.getBaseTableDataSource()); Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource()); Assert.assertEquals( @@ -395,7 +498,8 @@ public class DataSourceAnalysisTest TABLE_FOO, LOOKUP_LOOKYLOO, "1.", - JoinType.INNER + JoinType.INNER, + TrueDimFilter.instance() ) ).getQuery() ), @@ -436,6 +540,7 @@ public class DataSourceAnalysisTest Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource()); Assert.assertEquals(Optional.empty(), analysis.getBaseQuery()); Assert.assertEquals(Optional.empty(), analysis.getBaseQuerySegmentSpec()); + Assert.assertEquals(Optional.empty(), analysis.getJoinBaseTableFilter()); Assert.assertEquals( ImmutableList.of( new PreJoinableClause("1.", LOOKUP_LOOKYLOO, JoinType.INNER, joinClause("1.")) @@ -467,6 +572,7 @@ public class DataSourceAnalysisTest Assert.assertEquals(Optional.empty(), analysis.getBaseUnionDataSource()); Assert.assertEquals(Optional.empty(), analysis.getBaseQuery()); Assert.assertEquals(Optional.empty(), analysis.getBaseQuerySegmentSpec()); + Assert.assertEquals(Optional.empty(), analysis.getJoinBaseTableFilter()); Assert.assertEquals( ImmutableList.of( new PreJoinableClause("1.", TABLE_FOO, JoinType.INNER, joinClause("1.")) @@ -484,7 +590,7 @@ public class DataSourceAnalysisTest .withNonnullFields("dataSource") // These fields are not necessary, because they're wholly determined by "dataSource" - .withIgnoredFields("baseDataSource", "baseQuery", "preJoinableClauses") + .withIgnoredFields("baseDataSource", "baseQuery", "preJoinableClauses", "joinBaseTableFilter") .verify(); } @@ -495,7 +601,8 @@ public class DataSourceAnalysisTest final DataSource left, final DataSource right, final String rightPrefix, - final JoinType joinType + final JoinType joinType, + final DimFilter dimFilter ) { return JoinDataSource.create( @@ -504,10 +611,21 @@ public class DataSourceAnalysisTest rightPrefix, joinClause(rightPrefix).getOriginalExpression(), joinType, + dimFilter, ExprMacroTable.nil() ); } + private static JoinDataSource join( + final DataSource left, + final DataSource right, + final String rightPrefix, + final JoinType joinType + ) + { + return join(left, right, rightPrefix, joinType, null); + } + /** * Generate a join clause that joins on a column named "x" on both sides. 
*/ diff --git a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java index 55469624ac8..68e6426551e 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java @@ -35,6 +35,7 @@ import org.apache.druid.segment.VirtualColumn; import org.apache.druid.segment.VirtualColumns; import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ValueType; +import org.apache.druid.segment.filter.Filters; import org.apache.druid.segment.filter.SelectorFilter; import org.apache.druid.segment.join.filter.JoinFilterPreAnalysis; import org.apache.druid.segment.join.lookup.LookupJoinable; @@ -43,6 +44,7 @@ import org.junit.Assert; import org.junit.Test; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -2019,6 +2021,207 @@ public class HashJoinSegmentStorageAdapterTest extends BaseHashJoinSegmentStorag ); } + @Test + public void test_makeCursors_factToCountryLeftWithBaseFilter() + { + List joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.LEFT)); + + JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis( + null, + joinableClauses, + VirtualColumns.EMPTY + ); + JoinTestHelper.verifyCursors( + new HashJoinSegmentStorageAdapter( + factSegment.asStorageAdapter(), + Filters.or(Arrays.asList( + new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(), + new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter() + )), + joinableClauses, + joinFilterPreAnalysis + ).makeCursors( + null, + Intervals.ETERNITY, + VirtualColumns.EMPTY, + Granularities.ALL, + false, + null + ), + ImmutableList.of( + "page", + "countryIsoCode", + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryIsoCode", + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryName", + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryNumber" + ), + ImmutableList.of( + new Object[]{"Didier Leclair", "CA", "CA", "Canada", 1L}, + new Object[]{"Les Argonautes", "CA", "CA", "Canada", 1L}, + new Object[]{"Sarah Michelle Gellar", "CA", "CA", "Canada", 1L}, + new Object[]{"Orange Soda", "MatchNothing", null, null, NULL_COUNTRY} + ) + ); + } + + @Test + public void test_makeCursors_factToCountryInnerWithBaseFilter() + { + List joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.INNER)); + JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis( + null, + joinableClauses, + VirtualColumns.EMPTY + ); + JoinTestHelper.verifyCursors( + new HashJoinSegmentStorageAdapter( + factSegment.asStorageAdapter(), + Filters.or(Arrays.asList( + new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(), + new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter() + )), + joinableClauses, + joinFilterPreAnalysis + ).makeCursors( + null, + Intervals.ETERNITY, + VirtualColumns.EMPTY, + Granularities.ALL, + false, + null + ), + ImmutableList.of( + "page", + "countryIsoCode", + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryIsoCode", + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryName", + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryNumber" + ), + ImmutableList.of( + new Object[]{"Didier Leclair", "CA", "CA", "Canada", 1L}, + new Object[]{"Les Argonautes", "CA", "CA", "Canada", 1L}, + new Object[]{"Sarah Michelle Gellar", 
"CA", "CA", "Canada", 1L} + ) + ); + } + + @Test + public void test_makeCursors_factToCountryRightWithBaseFilter() + { + List joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.RIGHT)); + JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis( + null, + joinableClauses, + VirtualColumns.EMPTY + ); + JoinTestHelper.verifyCursors( + new HashJoinSegmentStorageAdapter( + factSegment.asStorageAdapter(), + Filters.or(Arrays.asList( + new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(), + new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter() + )), + joinableClauses, + joinFilterPreAnalysis + ).makeCursors( + null, + Intervals.ETERNITY, + VirtualColumns.EMPTY, + Granularities.ALL, + false, + null + ), + ImmutableList.of( + "page", + "countryIsoCode", + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryIsoCode", + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryName", + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryNumber" + ), + ImmutableList.of( + new Object[]{"Didier Leclair", "CA", "CA", "Canada", 1L}, + new Object[]{"Les Argonautes", "CA", "CA", "Canada", 1L}, + new Object[]{"Sarah Michelle Gellar", "CA", "CA", "Canada", 1L}, + new Object[]{null, null, "AU", "Australia", 0L}, + new Object[]{null, null, "CL", "Chile", 2L}, + new Object[]{null, null, "DE", "Germany", 3L}, + new Object[]{null, null, "EC", "Ecuador", 4L}, + new Object[]{null, null, "FR", "France", 5L}, + new Object[]{null, null, "GB", "United Kingdom", 6L}, + new Object[]{null, null, "IT", "Italy", 7L}, + new Object[]{null, null, "JP", "Japan", 8L}, + new Object[]{null, null, "KR", "Republic of Korea", 9L}, + new Object[]{null, null, "MX", "Mexico", 10L}, + new Object[]{null, null, "NO", "Norway", 11L}, + new Object[]{null, null, "SV", "El Salvador", 12L}, + new Object[]{null, null, "US", "United States", 13L}, + new Object[]{null, null, "AX", "Atlantis", 14L}, + new Object[]{null, null, "SU", "States United", 15L}, + new Object[]{null, null, "USCA", "Usca", 16L}, + new Object[]{null, null, "MMMM", "Fourems", 205L} + ) + ); + } + + @Test + public void test_makeCursors_factToCountryFullWithBaseFilter() + { + List joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.FULL)); + JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis( + null, + joinableClauses, + VirtualColumns.EMPTY + ); + JoinTestHelper.verifyCursors( + new HashJoinSegmentStorageAdapter( + factSegment.asStorageAdapter(), + Filters.or(Arrays.asList( + new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(), + new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter() + )), + joinableClauses, + joinFilterPreAnalysis + ).makeCursors( + null, + Intervals.ETERNITY, + VirtualColumns.EMPTY, + Granularities.ALL, + false, + null + ), + ImmutableList.of( + "page", + "countryIsoCode", + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryIsoCode", + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryName", + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX + "countryNumber" + ), + ImmutableList.of( + new Object[]{"Didier Leclair", "CA", "CA", "Canada", 1L}, + new Object[]{"Les Argonautes", "CA", "CA", "Canada", 1L}, + new Object[]{"Sarah Michelle Gellar", "CA", "CA", "Canada", 1L}, + new Object[]{"Orange Soda", "MatchNothing", null, null, NullHandling.sqlCompatible() ? 
null : 0L}, + new Object[]{null, null, "AU", "Australia", 0L}, + new Object[]{null, null, "CL", "Chile", 2L}, + new Object[]{null, null, "DE", "Germany", 3L}, + new Object[]{null, null, "EC", "Ecuador", 4L}, + new Object[]{null, null, "FR", "France", 5L}, + new Object[]{null, null, "GB", "United Kingdom", 6L}, + new Object[]{null, null, "IT", "Italy", 7L}, + new Object[]{null, null, "JP", "Japan", 8L}, + new Object[]{null, null, "KR", "Republic of Korea", 9L}, + new Object[]{null, null, "MX", "Mexico", 10L}, + new Object[]{null, null, "NO", "Norway", 11L}, + new Object[]{null, null, "SV", "El Salvador", 12L}, + new Object[]{null, null, "US", "United States", 13L}, + new Object[]{null, null, "AX", "Atlantis", 14L}, + new Object[]{null, null, "SU", "States United", 15L}, + new Object[]{null, null, "USCA", "Usca", 16L}, + new Object[]{null, null, "MMMM", "Fourems", 205L} + ) + ); + } + @Test public void test_determineBaseColumnsWithPreAndPostJoinVirtualColumns() { diff --git a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentTest.java b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentTest.java index 2cdfcadaa48..9a56b3b6bdc 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentTest.java @@ -182,6 +182,7 @@ public class HashJoinSegmentTest extends InitializedNullHandlingTest }; hashJoinSegment = new HashJoinSegment( testWrapper, + null, joinableClauses, null ) @@ -210,6 +211,7 @@ public class HashJoinSegmentTest extends InitializedNullHandlingTest final HashJoinSegment ignored = new HashJoinSegment( ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment), + null, joinableClauses, null ); diff --git a/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java b/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java index 72eb42ffb27..e422423b952 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java @@ -40,6 +40,7 @@ import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.filter.AndFilter; import org.apache.druid.segment.filter.BoundFilter; import org.apache.druid.segment.filter.FalseFilter; +import org.apache.druid.segment.filter.Filters; import org.apache.druid.segment.filter.OrFilter; import org.apache.druid.segment.filter.SelectorFilter; import org.apache.druid.segment.join.filter.JoinFilterAnalyzer; @@ -54,6 +55,7 @@ import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.junit.Assert; import org.junit.Test; +import java.util.Arrays; import java.util.List; import java.util.Set; @@ -2650,4 +2652,29 @@ public class JoinFilterAnalyzerTest extends BaseHashJoinSegmentStorageAdapterTes .withNonnullFields("virtualColumns") .verify(); } + + @Test + public void test_filterPushDown_baseTableFilter() + { + Filter originalFilter = new SelectorFilter("channel", "#en.wikipedia"); + Filter baseTableFilter = new SelectorFilter("countryIsoCode", "CA"); + List joinableClauses = ImmutableList.of( + factToRegion(JoinType.LEFT), + regionToCountry(JoinType.LEFT) + ); + + JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis( + originalFilter, + joinableClauses, + VirtualColumns.EMPTY + ); + + JoinFilterSplit expectedFilterSplit = new JoinFilterSplit( + Filters.and(Arrays.asList(originalFilter, 
baseTableFilter)), + null, + ImmutableSet.of() + ); + JoinFilterSplit actualFilterSplit = JoinFilterAnalyzer.splitFilter(joinFilterPreAnalysis, baseTableFilter); + Assert.assertEquals(expectedFilterSplit, actualFilterSplit); + } } diff --git a/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java b/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java index 3957afb31b8..94067c34412 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java @@ -33,6 +33,8 @@ import org.apache.druid.query.QueryContexts; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.TestQuery; import org.apache.druid.query.extraction.MapLookupExtractor; +import org.apache.druid.query.filter.FalseDimFilter; +import org.apache.druid.query.filter.TrueDimFilter; import org.apache.druid.query.planning.DataSourceAnalysis; import org.apache.druid.query.planning.PreJoinableClause; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; @@ -72,6 +74,7 @@ public class JoinableFactoryWrapperTest public void test_createSegmentMapFn_noClauses() { final Function segmentMapFn = NOOP_JOINABLE_FACTORY_WRAPPER.createSegmentMapFn( + null, ImmutableList.of(), new AtomicLong(), null @@ -95,6 +98,7 @@ public class JoinableFactoryWrapperTest expectedException.expectMessage("dataSource is not joinable"); final Function ignored = NOOP_JOINABLE_FACTORY_WRAPPER.createSegmentMapFn( + null, ImmutableList.of(clause), new AtomicLong(), null @@ -138,6 +142,7 @@ public class JoinableFactoryWrapperTest } }); final Function segmentMapFn = joinableFactoryWrapper.createSegmentMapFn( + null, ImmutableList.of(clause), new AtomicLong(), new TestQuery( @@ -157,6 +162,7 @@ public class JoinableFactoryWrapperTest DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class); DataSource dataSource = new NoopDataSource(); EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.emptyList()); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()); EasyMock.expect(analysis.getDataSource()).andReturn(dataSource); EasyMock.replay(analysis); JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey()); @@ -172,11 +178,11 @@ public class JoinableFactoryWrapperTest @Test public void test_computeJoinDataSourceCacheKey_noHashJoin() { - PreJoinableClause clause1 = makeGlobalPreJoinableClause("dataSource_1", "x == \"j.x\"", "j."); PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_2", "x != \"h.x\"", "h."); DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class); EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Arrays.asList(clause1, clause2)).anyTimes(); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.of(TrueDimFilter.instance())).anyTimes(); EasyMock.replay(analysis); JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey()); Optional cacheKey = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis); @@ -192,6 +198,7 @@ public class JoinableFactoryWrapperTest PreJoinableClause clause2 = makePreJoinableClause(dataSource, "x == \"h.x\"", "h.", JoinType.LEFT); DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class); EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Arrays.asList(clause1, 
clause2)).anyTimes(); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.of(TrueDimFilter.instance())).anyTimes(); EasyMock.replay(analysis); JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey()); Optional cacheKey = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis); @@ -207,6 +214,7 @@ public class JoinableFactoryWrapperTest PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_2", "x == \"h.x\"", "h."); DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class); EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Arrays.asList(clause1, clause2)).anyTimes(); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes(); EasyMock.replay(analysis); JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey()); Optional cacheKey = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis); @@ -218,6 +226,7 @@ public class JoinableFactoryWrapperTest public void test_computeJoinDataSourceCacheKey_keyChangesWithExpression() { DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes(); JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey()); PreJoinableClause clause1 = makeGlobalPreJoinableClause("dataSource_1", "y == \"j.y\"", "j."); @@ -231,6 +240,7 @@ public class JoinableFactoryWrapperTest PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_1", "x == \"j.x\"", "j."); EasyMock.reset(analysis); EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause2)).anyTimes(); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes(); EasyMock.replay(analysis); Optional cacheKey2 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis); Assert.assertTrue(cacheKey2.isPresent()); @@ -242,6 +252,7 @@ public class JoinableFactoryWrapperTest public void test_computeJoinDataSourceCacheKey_keyChangesWithJoinType() { DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes(); JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey()); PreJoinableClause clause1 = makeGlobalPreJoinableClause("dataSource_1", "x == \"j.x\"", "j.", JoinType.LEFT); @@ -255,6 +266,7 @@ public class JoinableFactoryWrapperTest PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_1", "x == \"j.x\"", "j.", JoinType.INNER); EasyMock.reset(analysis); EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause2)).anyTimes(); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes(); EasyMock.replay(analysis); Optional cacheKey2 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis); Assert.assertTrue(cacheKey2.isPresent()); @@ -266,6 +278,7 @@ public class JoinableFactoryWrapperTest public void test_computeJoinDataSourceCacheKey_keyChangesWithPrefix() { DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes(); JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey()); PreJoinableClause clause1 = 
makeGlobalPreJoinableClause("dataSource_1", "abc == xyz", "ab"); @@ -279,6 +292,33 @@ public class JoinableFactoryWrapperTest PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_1", "abc == xyz", "xy"); EasyMock.reset(analysis); EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause2)).anyTimes(); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes(); + EasyMock.replay(analysis); + Optional cacheKey2 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis); + Assert.assertTrue(cacheKey2.isPresent()); + + Assert.assertFalse(Arrays.equals(cacheKey1.get(), cacheKey2.get())); + } + + @Test + public void test_computeJoinDataSourceCacheKey_keyChangesWithBaseFilter() + { + DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.of(TrueDimFilter.instance())).anyTimes(); + JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey()); + + PreJoinableClause clause1 = makeGlobalPreJoinableClause("dataSource_1", "abc == xyz", "ab"); + EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause1)).anyTimes(); + EasyMock.replay(analysis); + + Optional cacheKey1 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis); + Assert.assertTrue(cacheKey1.isPresent()); + Assert.assertNotEquals(0, cacheKey1.get().length); + + PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_1", "abc == xyz", "ab"); + EasyMock.reset(analysis); + EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause2)).anyTimes(); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.of(FalseDimFilter.instance())).anyTimes(); EasyMock.replay(analysis); Optional cacheKey2 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis); Assert.assertTrue(cacheKey2.isPresent()); @@ -290,6 +330,7 @@ public class JoinableFactoryWrapperTest public void test_computeJoinDataSourceCacheKey_keyChangesWithJoinable() { DataSourceAnalysis analysis = EasyMock.mock(DataSourceAnalysis.class); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes(); JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(new JoinableFactoryWithCacheKey()); PreJoinableClause clause1 = makeGlobalPreJoinableClause("dataSource_1", "x == \"j.x\"", "j."); @@ -303,6 +344,8 @@ public class JoinableFactoryWrapperTest PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_2", "x == \"j.x\"", "j."); EasyMock.reset(analysis); EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause2)).anyTimes(); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes(); + EasyMock.replay(analysis); Optional cacheKey2 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis); Assert.assertTrue(cacheKey2.isPresent()); @@ -318,6 +361,7 @@ public class JoinableFactoryWrapperTest PreJoinableClause clause1 = makeGlobalPreJoinableClause("dataSource_1", "x == \"j.x\"", "j."); EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause1)).anyTimes(); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes(); EasyMock.replay(analysis); Optional cacheKey1 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis); @@ -327,6 +371,7 @@ public class 
JoinableFactoryWrapperTest PreJoinableClause clause2 = makeGlobalPreJoinableClause("dataSource_1", "x == \"j.x\"", "j."); EasyMock.reset(analysis); EasyMock.expect(analysis.getPreJoinableClauses()).andReturn(Collections.singletonList(clause2)).anyTimes(); + EasyMock.expect(analysis.getJoinBaseTableFilter()).andReturn(Optional.empty()).anyTimes(); EasyMock.replay(analysis); Optional cacheKey2 = joinableFactoryWrapper.computeJoinDataSourceCacheKey(analysis); Assert.assertTrue(cacheKey2.isPresent()); diff --git a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/SinkQuerySegmentWalker.java b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/SinkQuerySegmentWalker.java index 2c3d38f0a36..7704bb89a29 100644 --- a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/SinkQuerySegmentWalker.java +++ b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/SinkQuerySegmentWalker.java @@ -59,6 +59,7 @@ import org.apache.druid.query.spec.SpecificSegmentQueryRunner; import org.apache.druid.query.spec.SpecificSegmentSpec; import org.apache.druid.segment.SegmentReference; import org.apache.druid.segment.StorageAdapter; +import org.apache.druid.segment.filter.Filters; import org.apache.druid.segment.join.JoinableFactory; import org.apache.druid.segment.join.JoinableFactoryWrapper; import org.apache.druid.segment.realtime.FireHydrant; @@ -172,6 +173,7 @@ public class SinkQuerySegmentWalker implements QuerySegmentWalker // segmentMapFn maps each base Segment into a joined Segment if necessary. final Function segmentMapFn = joinableFactoryWrapper.createSegmentMapFn( + analysis.getJoinBaseTableFilter().map(Filters::toFilter).orElse(null), analysis.getPreJoinableClauses(), cpuTimeAccumulator, analysis.getBaseQuery().orElse(query) diff --git a/server/src/main/java/org/apache/druid/server/LocalQuerySegmentWalker.java b/server/src/main/java/org/apache/druid/server/LocalQuerySegmentWalker.java index 201e51461df..5ae9700f34d 100644 --- a/server/src/main/java/org/apache/druid/server/LocalQuerySegmentWalker.java +++ b/server/src/main/java/org/apache/druid/server/LocalQuerySegmentWalker.java @@ -36,6 +36,7 @@ import org.apache.druid.query.planning.DataSourceAnalysis; import org.apache.druid.segment.ReferenceCountingSegment; import org.apache.druid.segment.SegmentReference; import org.apache.druid.segment.SegmentWrangler; +import org.apache.druid.segment.filter.Filters; import org.apache.druid.segment.join.JoinableFactory; import org.apache.druid.segment.join.JoinableFactoryWrapper; import org.joda.time.Interval; @@ -95,6 +96,7 @@ public class LocalQuerySegmentWalker implements QuerySegmentWalker final AtomicLong cpuAccumulator = new AtomicLong(0L); final Function segmentMapFn = joinableFactoryWrapper.createSegmentMapFn( + analysis.getJoinBaseTableFilter().map(Filters::toFilter).orElse(null), analysis.getPreJoinableClauses(), cpuAccumulator, analysis.getBaseQuery().orElse(query) diff --git a/server/src/main/java/org/apache/druid/server/coordination/ServerManager.java b/server/src/main/java/org/apache/druid/server/coordination/ServerManager.java index da4fc59b392..2ceaf96753a 100644 --- a/server/src/main/java/org/apache/druid/server/coordination/ServerManager.java +++ b/server/src/main/java/org/apache/druid/server/coordination/ServerManager.java @@ -59,6 +59,7 @@ import org.apache.druid.query.spec.SpecificSegmentSpec; import org.apache.druid.segment.ReferenceCountingSegment; import org.apache.druid.segment.SegmentReference; import 
org.apache.druid.segment.StorageAdapter; +import org.apache.druid.segment.filter.Filters; import org.apache.druid.segment.join.JoinableFactory; import org.apache.druid.segment.join.JoinableFactoryWrapper; import org.apache.druid.server.SegmentManager; @@ -200,6 +201,7 @@ public class ServerManager implements QuerySegmentWalker // segmentMapFn maps each base Segment into a joined Segment if necessary. final Function segmentMapFn = joinableFactoryWrapper.createSegmentMapFn( + analysis.getJoinBaseTableFilter().map(Filters::toFilter).orElse(null), analysis.getPreJoinableClauses(), cpuTimeAccumulator, analysis.getBaseQuery().orElse(query) diff --git a/server/src/test/java/org/apache/druid/server/TestClusterQuerySegmentWalker.java b/server/src/test/java/org/apache/druid/server/TestClusterQuerySegmentWalker.java index da6d89f75b0..3ba135f9b6c 100644 --- a/server/src/test/java/org/apache/druid/server/TestClusterQuerySegmentWalker.java +++ b/server/src/test/java/org/apache/druid/server/TestClusterQuerySegmentWalker.java @@ -47,6 +47,7 @@ import org.apache.druid.query.spec.SpecificSegmentSpec; import org.apache.druid.segment.ReferenceCountingSegment; import org.apache.druid.segment.Segment; import org.apache.druid.segment.SegmentReference; +import org.apache.druid.segment.filter.Filters; import org.apache.druid.segment.join.JoinableFactory; import org.apache.druid.segment.join.JoinableFactoryWrapper; import org.apache.druid.timeline.TimelineObjectHolder; @@ -143,6 +144,7 @@ public class TestClusterQuerySegmentWalker implements QuerySegmentWalker } final Function segmentMapFn = joinableFactoryWrapper.createSegmentMapFn( + analysis.getJoinBaseTableFilter().map(Filters::toFilter).orElse(null), analysis.getPreJoinableClauses(), new AtomicLong(), analysis.getBaseQuery().orElse(query) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java index 5500e50d7f0..1b28a9830c8 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java @@ -31,6 +31,7 @@ import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.plan.volcano.RelSubset; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelWriter; +import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.metadata.RelMetadataQuery; @@ -38,6 +39,7 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlKind; import org.apache.druid.java.util.common.IAE; +import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.guava.Sequence; @@ -45,13 +47,16 @@ import org.apache.druid.query.DataSource; import org.apache.druid.query.JoinDataSource; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.TableDataSource; +import org.apache.druid.query.filter.DimFilter; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.join.JoinType; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; +import org.apache.druid.sql.calcite.planner.PlannerContext; import 
org.apache.druid.sql.calcite.table.RowSignatures; +import javax.annotation.Nullable; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -64,6 +69,7 @@ public class DruidJoinQueryRel extends DruidRel { private static final TableDataSource DUMMY_DATA_SOURCE = new TableDataSource("__join__"); + private final Filter leftFilter; private final PartialDruidQuery partialQuery; private final Join joinRel; private RelNode left; @@ -73,6 +79,7 @@ public class DruidJoinQueryRel extends DruidRel RelOptCluster cluster, RelTraitSet traitSet, Join joinRel, + Filter leftFilter, PartialDruidQuery partialQuery, QueryMaker queryMaker ) @@ -81,6 +88,7 @@ public class DruidJoinQueryRel extends DruidRel this.joinRel = joinRel; this.left = joinRel.getLeft(); this.right = joinRel.getRight(); + this.leftFilter = leftFilter; this.partialQuery = partialQuery; } @@ -89,6 +97,7 @@ public class DruidJoinQueryRel extends DruidRel */ public static DruidJoinQueryRel create( final Join joinRel, + final Filter leftFilter, final QueryMaker queryMaker ) { @@ -96,6 +105,7 @@ public class DruidJoinQueryRel extends DruidRel joinRel.getCluster(), joinRel.getTraitSet(), joinRel, + leftFilter, PartialDruidQuery.create(joinRel), queryMaker ); @@ -125,6 +135,7 @@ public class DruidJoinQueryRel extends DruidRel getCluster(), getTraitSet().plusAll(newQueryBuilder.getRelTraits()), joinRel, + leftFilter, newQueryBuilder, getQueryMaker() ); @@ -145,6 +156,9 @@ public class DruidJoinQueryRel extends DruidRel if (computeLeftRequiresSubquery(leftDruidRel)) { leftDataSource = new QueryDataSource(leftQuery.getQuery()); + if (leftFilter != null) { + throw new ISE("Filter on left table is supposed to be null if left child is a query source"); + } } else { leftDataSource = leftQuery.getDataSource(); } @@ -177,6 +191,7 @@ public class DruidJoinQueryRel extends DruidRel prefixSignaturePair.lhs, condition.getExpression(), toDruidJoinType(joinRel.getJoinType()), + getDimFilter(getPlannerContext(), leftSignature, leftFilter), getPlannerContext().getExprMacroTable() ), prefixSignaturePair.rhs, @@ -214,6 +229,7 @@ public class DruidJoinQueryRel extends DruidRel .map(input -> RelOptRule.convert(input, DruidConvention.instance())) .collect(Collectors.toList()) ), + leftFilter, partialQuery, getQueryMaker() ); @@ -252,6 +268,7 @@ public class DruidJoinQueryRel extends DruidRel getCluster(), traitSet, joinRel.copy(joinRel.getTraitSet(), inputs), + leftFilter, getPartialDruidQuery(), getQueryMaker() ); @@ -312,7 +329,7 @@ public class DruidJoinQueryRel extends DruidRel return planner.getCostFactory().makeCost(cost, 0, 0); } - private static JoinType toDruidJoinType(JoinRelType calciteJoinType) + public static JoinType toDruidJoinType(JoinRelType calciteJoinType) { switch (calciteJoinType) { case LEFT: @@ -377,4 +394,28 @@ public class DruidJoinQueryRel extends DruidRel return (DruidRel) Iterables.getFirst(subset.getRels(), null); } } + + @Nullable + private static DimFilter getDimFilter( + final PlannerContext plannerContext, + final RowSignature rowSignature, + @Nullable final Filter filter + ) + { + if (filter == null) { + return null; + } + final RexNode condition = filter.getCondition(); + final DimFilter dimFilter = Expressions.toFilter( + plannerContext, + rowSignature, + null, + condition + ); + if (dimFilter == null) { + throw new CannotBuildQueryException(filter, condition); + } else { + return dimFilter; + } + } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java 
b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java index deeea65745d..a12ba3444ca 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java @@ -19,6 +19,7 @@ package org.apache.druid.sql.calcite.rel; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSortedMap; @@ -42,9 +43,11 @@ import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.ImmutableBitSet; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularity; import org.apache.druid.query.DataSource; +import org.apache.druid.query.JoinDataSource; import org.apache.druid.query.Query; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.aggregation.PostAggregator; @@ -648,6 +651,52 @@ public class DruidQuery return VirtualColumns.create(columns); } + /** + * Returns a pair of the DataSource and the Filtration created from the query filter. If the data source is + * a join datasource, the returned data source may be altered and the left filter of the join datasource may + * be stripped of time filters. + * TODO: should we optimize the base table filter just like we do with query filters? + */ + @VisibleForTesting + static Pair<DataSource, Filtration> getFiltration( + DataSource dataSource, + DimFilter filter, + VirtualColumnRegistry virtualColumnRegistry + ) + { + if (!(dataSource instanceof JoinDataSource)) { + return Pair.of(dataSource, toFiltration(filter, virtualColumnRegistry)); + } + JoinDataSource joinDataSource = (JoinDataSource) dataSource; + if (joinDataSource.getLeftFilter() == null) { + return Pair.of(dataSource, toFiltration(filter, virtualColumnRegistry)); + } + //TODO: We should avoid promoting the time filter to an interval for right outer and full outer joins. This is not + // done now since we apply the intervals to the left base table today irrespective of the join type. + + // If the join is left or inner, we can pull the intervals up to the query. This is done + // so that the broker can prune the segments to query. 
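+ // Illustrative example (values assumed, mirroring DruidQueryTest below): a left filter of + // country = 'IN' AND __time >= 100 AND __time < 200 optimizes into the interval 100/200 plus a + // residual country = 'IN' dim filter; the intervals seed the query filtration built next, while the + // residual dim filter remains as the join's left filter.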
+ Filtration leftFiltration = Filtration.create(joinDataSource.getLeftFilter()) + .optimize(virtualColumnRegistry.getFullRowSignature()); + // Adds the intervals from the join left filter to the query filtration + Filtration queryFiltration = Filtration.create(filter, leftFiltration.getIntervals()) + .optimize(virtualColumnRegistry.getFullRowSignature()); + JoinDataSource newDataSource = JoinDataSource.create( + joinDataSource.getLeft(), + joinDataSource.getRight(), + joinDataSource.getRightPrefix(), + joinDataSource.getConditionAnalysis(), + joinDataSource.getJoinType(), + leftFiltration.getDimFilter() + ); + return Pair.of(newDataSource, queryFiltration); + } + + private static Filtration toFiltration(DimFilter filter, VirtualColumnRegistry virtualColumnRegistry) + { + return Filtration.create(filter).optimize(virtualColumnRegistry.getFullRowSignature()); + } + public DataSource getDataSource() { return dataSource; @@ -793,7 +842,13 @@ return null; } - final Filtration filtration = Filtration.create(filter).optimize(virtualColumnRegistry.getFullRowSignature()); + final Pair<DataSource, Filtration> dataSourceFiltrationPair = getFiltration( + dataSource, + filter, + virtualColumnRegistry + ); + final DataSource newDataSource = dataSourceFiltrationPair.lhs; + final Filtration filtration = dataSourceFiltrationPair.rhs; final List<PostAggregator> postAggregators = new ArrayList<>(grouping.getPostAggregators()); if (sorting != null && sorting.getProjection() != null) { @@ -801,7 +856,7 @@ } return new TimeseriesQuery( - dataSource, + newDataSource, filtration.getQuerySegmentSpec(), descending, getVirtualColumns(false), @@ -872,7 +927,13 @@ return null; } - final Filtration filtration = Filtration.create(filter).optimize(virtualColumnRegistry.getFullRowSignature()); + final Pair<DataSource, Filtration> dataSourceFiltrationPair = getFiltration( + dataSource, + filter, + virtualColumnRegistry + ); + final DataSource newDataSource = dataSourceFiltrationPair.lhs; + final Filtration filtration = dataSourceFiltrationPair.rhs; final List<PostAggregator> postAggregators = new ArrayList<>(grouping.getPostAggregators()); if (sorting.getProjection() != null) { @@ -880,7 +941,7 @@ } return new TopNQuery( - dataSource, + newDataSource, getVirtualColumns(true), dimensionSpec, topNMetricSpec, @@ -911,7 +972,13 @@ return null; } - final Filtration filtration = Filtration.create(filter).optimize(virtualColumnRegistry.getFullRowSignature()); + final Pair<DataSource, Filtration> dataSourceFiltrationPair = getFiltration( + dataSource, + filter, + virtualColumnRegistry + ); + final DataSource newDataSource = dataSourceFiltrationPair.lhs; + final Filtration filtration = dataSourceFiltrationPair.rhs; final DimFilterHavingSpec havingSpec; if (grouping.getHavingFilter() != null) { @@ -930,7 +997,7 @@ } return new GroupByQuery( - dataSource, + newDataSource, filtration.getQuerySegmentSpec(), getVirtualColumns(true), filtration.getDimFilter(), @@ -963,7 +1030,14 @@ throw new ISE("Cannot convert to Scan query without any columns."); } - final Filtration filtration = Filtration.create(filter).optimize(virtualColumnRegistry.getFullRowSignature()); + final Pair<DataSource, Filtration> dataSourceFiltrationPair = getFiltration( + dataSource, + filter, + virtualColumnRegistry + ); + final DataSource newDataSource = dataSourceFiltrationPair.lhs; + final Filtration filtration = dataSourceFiltrationPair.rhs; + final ScanQuery.Order order; long scanOffset = 0L; long scanLimit = 0L; @@ -1008,7 
+1082,7 @@ public class DruidQuery } return new ScanQuery( - dataSource, + newDataSource, filtration.getQuerySegmentSpec(), getVirtualColumns(true), ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java index b9f8f349a04..ac93b6951bf 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java @@ -26,6 +26,7 @@ import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; @@ -97,18 +98,19 @@ public class DruidJoinRule extends RelOptRule final DruidRel newLeft; final DruidRel newRight; + final Filter leftFilter; final List newProjectExprs = new ArrayList<>(); // Already verified to be present in "matches", so just call "get". // Can't be final, because we're going to reassign it up to a couple of times. ConditionAnalysis conditionAnalysis = analyzeCondition(join.getCondition(), join.getLeft().getRowType()).get(); - if (left.getPartialDruidQuery().stage() == PartialDruidQuery.Stage.SELECT_PROJECT - && left.getPartialDruidQuery().getWhereFilter() == null) { + if (left.getPartialDruidQuery().stage() == PartialDruidQuery.Stage.SELECT_PROJECT) { // Swap the left-side projection above the join, so the left side is a simple scan or mapping. This helps us // avoid subqueries. final RelNode leftScan = left.getPartialDruidQuery().getScan(); final Project leftProject = left.getPartialDruidQuery().getSelectProject(); + leftFilter = left.getPartialDruidQuery().getWhereFilter(); // Left-side projection expressions rewritten to be on top of the join. 
newProjectExprs.addAll(leftProject.getProjects()); @@ -121,6 +123,7 @@ public class DruidJoinRule extends RelOptRule } newLeft = left; + leftFilter = null; } if (right.getPartialDruidQuery().stage() == PartialDruidQuery.Stage.SELECT_PROJECT @@ -163,6 +166,7 @@ public class DruidJoinRule extends RelOptRule join.getJoinType(), join.isSemiJoinDone() ), + leftFilter, left.getQueryMaker() ); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java index 5b913aa34ae..3767731e0b0 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java @@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.calcite.plan.RelOptPlanner; import org.apache.calcite.schema.SchemaPlus; +import org.apache.druid.annotations.UsedByJUnitParamsRunner; import org.apache.druid.common.config.NullHandling; import org.apache.druid.hll.VersionOneHyperLogLogCollector; import org.apache.druid.java.util.common.DateTimes; @@ -374,7 +375,8 @@ public class BaseCalciteQueryTest extends CalciteTestBase DataSource right, String rightPrefix, String condition, - JoinType joinType + JoinType joinType, + DimFilter filter ) { return JoinDataSource.create( @@ -383,10 +385,22 @@ public class BaseCalciteQueryTest extends CalciteTestBase rightPrefix, condition, joinType, + filter, CalciteTests.createExprMacroTable() ); } + public static JoinDataSource join( + DataSource left, + DataSource right, + String rightPrefix, + String condition, + JoinType joinType + ) + { + return join(left, right, rightPrefix, condition, joinType, null); + } + public static String equalsCondition(DruidExpression left, DruidExpression right) { return StringUtils.format("(%s == %s)", left.getExpression(), right.getExpression()); @@ -828,4 +842,39 @@ public class BaseCalciteQueryTest extends CalciteTestBase { skipVectorize = true; } + + /** + * This is a provider of query contexts that should be used by join tests. + * It tests various configs that can be passed to join queries. All the configs provided by this provider should + * have the join query engine return the same results. + */ + public static class QueryContextForJoinProvider + { + @UsedByJUnitParamsRunner + public static Object[] provideQueryContexts() + { + return new Object[]{ + // default behavior + QUERY_CONTEXT_DEFAULT, + // filter value re-writes enabled + new ImmutableMap.Builder() + .putAll(QUERY_CONTEXT_DEFAULT) + .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, true) + .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, true) + .build(), + // rewrite values enabled but filter re-writes disabled. 
+ // This should drive the same behavior as the previous config + new ImmutableMap.Builder() + .putAll(QUERY_CONTEXT_DEFAULT) + .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, true) + .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, false) + .build(), + // filter re-writes disabled + new ImmutableMap.Builder() + .putAll(QUERY_CONTEXT_DEFAULT) + .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, false) + .build(), + }; + } + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteCorrelatedQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteCorrelatedQueryTest.java new file mode 100644 index 00000000000..57d343db2be --- /dev/null +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteCorrelatedQueryTest.java @@ -0,0 +1,481 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import junitparams.JUnitParamsRunner; +import junitparams.Parameters; +import org.apache.druid.data.input.InputRow; +import org.apache.druid.data.input.MapBasedInputRow; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.granularity.AllGranularity; +import org.apache.druid.query.QueryDataSource; +import org.apache.druid.query.TableDataSource; +import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.aggregation.FilteredAggregatorFactory; +import org.apache.druid.query.aggregation.LongMaxAggregatorFactory; +import org.apache.druid.query.aggregation.LongSumAggregatorFactory; +import org.apache.druid.query.aggregation.any.LongAnyAggregatorFactory; +import org.apache.druid.query.aggregation.cardinality.CardinalityAggregatorFactory; +import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator; +import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator; +import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; +import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.expression.TestExprMacroTable; +import org.apache.druid.query.groupby.GroupByQuery; +import org.apache.druid.segment.IndexBuilder; +import org.apache.druid.segment.QueryableIndex; +import org.apache.druid.segment.column.ValueType; +import org.apache.druid.segment.incremental.IncrementalIndexSchema; +import org.apache.druid.segment.join.JoinType; +import org.apache.druid.segment.virtual.ExpressionVirtualColumn; +import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory; +import org.apache.druid.sql.calcite.expression.DruidExpression; +import 
org.apache.druid.timeline.DataSegment; +import org.apache.druid.timeline.partition.LinearShardSpec; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.io.File; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +@RunWith(JUnitParamsRunner.class) +public class CalciteCorrelatedQueryTest extends BaseCalciteQueryTest +{ + private static final IncrementalIndexSchema INDEX_SCHEMA = new IncrementalIndexSchema.Builder() + .withMetrics( + new CountAggregatorFactory("cnt") + ) + .withRollup(false) + .withMinTimestamp(DateTimes.of("2020-12-31").getMillis()) + .build(); + private static final List DIMENSIONS = ImmutableList.of("user", "country", "city"); + + @Before + public void setup() throws Exception + { + final QueryableIndex index1 = IndexBuilder + .create() + .tmpDir(new File(temporaryFolder.newFolder(), "1")) + .segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance()) + .schema(INDEX_SCHEMA) + .rows(getRawRows()) + .buildMMappedIndex(); + final DataSegment segment = DataSegment.builder() + .dataSource("visits") + .interval(index1.getDataInterval()) + .version("1") + .shardSpec(new LinearShardSpec(0)) + .size(0) + .build(); + walker.add(segment, index1); + + } + + @Test + @Parameters(source = QueryContextForJoinProvider.class) + public void testCorrelatedSubquery(Map queryContext) throws Exception + { + cannotVectorize(); + + testQuery( + "select country, ANY_VALUE(\n" + + " select avg(\"users\") from (\n" + + " select floor(__time to day), count(distinct user) \"users\" from visits f where f.country = visits.country group by 1\n" + + " )\n" + + " ) as \"DAU\"\n" + + "from visits \n" + + "group by 1", + queryContext, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + join( + new TableDataSource("visits"), + new QueryDataSource( + GroupByQuery.builder() + .setDataSource( + GroupByQuery.builder() + .setDataSource("visits") + .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "timestamp_floor(\"__time\",'P1D',null,'UTC')", + ValueType.LONG, + TestExprMacroTable.INSTANCE + )) + .setDimFilter(not(selector("country", null, null))) + .setDimensions( + new DefaultDimensionSpec( + "v0", + "d0", + ValueType.LONG + ), + new DefaultDimensionSpec( + "country", + "d1" + ) + ) + .setAggregatorSpecs(new CardinalityAggregatorFactory( + "a0:a", + null, + Collections.singletonList( + new DefaultDimensionSpec( + "user", + "user" + )), + false, + true + )) + .setPostAggregatorSpecs(Collections.singletonList(new HyperUniqueFinalizingPostAggregator( + "a0", + "a0:a" + ))) + .setContext(queryContext) + .setGranularity(new AllGranularity()) + .build() + ) + .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) + .setDimensions(new DefaultDimensionSpec("d1", "_d0")) + .setAggregatorSpecs( + new LongSumAggregatorFactory("_a0:sum", "a0"), + new CountAggregatorFactory("_a0:count") + ) + .setPostAggregatorSpecs(Collections.singletonList(new ArithmeticPostAggregator( + "_a0", + "quotient", + Arrays + .asList( + new FieldAccessPostAggregator(null, "_a0:sum"), + new FieldAccessPostAggregator(null, "_a0:count") + ) + ))) + .setGranularity(new AllGranularity()) + .setContext(queryContext) + .build() + ), + "j0.", + equalsCondition( + DruidExpression.fromColumn("country"), + DruidExpression.fromColumn("j0._d0") + ), + JoinType.LEFT + ) + ) + .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) + 
.setDimensions(new DefaultDimensionSpec("country", "d0")) + .setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0")) + .setGranularity(new AllGranularity()) + .setContext(queryContext) + .build() + ), + ImmutableList.of( + new Object[]{"India", 2L}, + new Object[]{"USA", 1L}, + new Object[]{"canada", 3L} + ) + ); + } + + @Test + @Parameters(source = QueryContextForJoinProvider.class) + public void testCorrelatedSubqueryWithLeftFilter(Map queryContext) throws Exception + { + cannotVectorize(); + + testQuery( + "select country, ANY_VALUE(\n" + + " select max(\"users\") from (\n" + + " select floor(__time to day), count(*) \"users\" from visits f where f.country = visits.country group by 1\n" + + " )\n" + + " ) as \"dailyVisits\"\n" + + "from visits \n" + + " where city = 'B' and __time between '2021-01-01 01:00:00' AND '2021-01-02 23:59:59'" + + " group by 1", + queryContext, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + join( + new TableDataSource("visits"), + new QueryDataSource( + GroupByQuery.builder() + .setDataSource( + GroupByQuery.builder() + .setDataSource("visits") + .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "timestamp_floor(\"__time\",'P1D',null,'UTC')", + ValueType.LONG, + TestExprMacroTable.INSTANCE + )) + .setDimFilter(not(selector("country", null, null))) + .setDimensions( + new DefaultDimensionSpec( + "v0", + "d0", + ValueType.LONG + ), + new DefaultDimensionSpec( + "country", + "d1" + ) + ) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .setContext(queryContext) + .setGranularity(new AllGranularity()) + .build() + ) + .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) + .setDimensions(new DefaultDimensionSpec("d1", "_d0")) + .setAggregatorSpecs( + new LongMaxAggregatorFactory("_a0", "a0") + ) + .setGranularity(new AllGranularity()) + .setContext(queryContext) + .build() + ), + "j0.", + equalsCondition( + DruidExpression.fromColumn("country"), + DruidExpression.fromColumn("j0._d0") + ), + JoinType.LEFT, + selector("city", "B", null) + ) + ) + .setQuerySegmentSpec(querySegmentSpec(Intervals.of( + "2021-01-01T01:00:00.000Z/2021-01-02T23:59:59.001Z"))) + .setDimensions(new DefaultDimensionSpec("country", "d0")) + .setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0")) + .setGranularity(new AllGranularity()) + .setContext(queryContext) + .build() + ), + ImmutableList.of( + new Object[]{"canada", 4L} + ) + ); + } + + @Test + @Parameters(source = QueryContextForJoinProvider.class) + public void testCorrelatedSubqueryWithCorrelatedQueryFilter(Map queryContext) throws Exception + { + cannotVectorize(); + + testQuery( + "select country, ANY_VALUE(\n" + + " select max(\"users\") from (\n" + + " select floor(__time to day), count(user) \"users\" from visits f where f.country = visits.country and f.city = 'A' group by 1\n" + + " )\n" + + " ) as \"dailyVisits\"\n" + + "from visits \n" + + " where city = 'B'" + + " group by 1", + queryContext, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + join( + new TableDataSource("visits"), + new QueryDataSource( + GroupByQuery.builder() + .setDataSource( + GroupByQuery.builder() + .setDataSource("visits") + .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "timestamp_floor(\"__time\",'P1D',null,'UTC')", + ValueType.LONG, + TestExprMacroTable.INSTANCE + )) + .setDimensions( + new DefaultDimensionSpec( + "v0", + "d0", + ValueType.LONG + 
), + new DefaultDimensionSpec( + "country", + "d1" + ) + ) + .setAggregatorSpecs(new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + not(selector("user", null, null)) + )) + .setDimFilter(and( + selector("city", "A", null), + not(selector("country", null, null)) + )) + .setContext(queryContext) + .setGranularity(new AllGranularity()) + .build() + ) + .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) + .setDimensions(new DefaultDimensionSpec("d1", "_d0")) + .setAggregatorSpecs( + new LongMaxAggregatorFactory("_a0", "a0") + ) + .setGranularity(new AllGranularity()) + .setContext(queryContext) + .build() + ), + "j0.", + equalsCondition( + DruidExpression.fromColumn("country"), + DruidExpression.fromColumn("j0._d0") + ), + JoinType.LEFT, + selector("city", "B", null) + ) + ) + .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) + .setDimensions(new DefaultDimensionSpec("country", "d0")) + .setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0")) + .setGranularity(new AllGranularity()) + .setContext(queryContext) + .build() + ), + ImmutableList.of( + new Object[]{"canada", 2L} + ) + ); + } + + @Test + @Parameters(source = QueryContextForJoinProvider.class) + public void testCorrelatedSubqueryWithCorrelatedQueryFilter_Scan(Map queryContext) throws Exception + { + cannotVectorize(); + + testQuery( + "select country, ANY_VALUE(\n" + + " select max(\"users\") from (\n" + + " select floor(__time to day), count(user) \"users\" from visits f where f.country = visits.country and f.city = 'A' group by 1\n" + + " )\n" + + " ) as \"dailyVisits\"\n" + + "from visits \n" + + " where city = 'B'" + + " group by 1", + queryContext, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + join( + new TableDataSource("visits"), + new QueryDataSource( + GroupByQuery.builder() + .setDataSource( + GroupByQuery.builder() + .setDataSource("visits") + .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "timestamp_floor(\"__time\",'P1D',null,'UTC')", + ValueType.LONG, + TestExprMacroTable.INSTANCE + )) + .setDimensions( + new DefaultDimensionSpec( + "v0", + "d0", + ValueType.LONG + ), + new DefaultDimensionSpec( + "country", + "d1" + ) + ) + .setAggregatorSpecs(new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + not(selector("user", null, null)) + )) + .setDimFilter(and( + selector("city", "A", null), + not(selector("country", null, null)) + )) + .setContext(queryContext) + .setGranularity(new AllGranularity()) + .build() + ) + .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) + .setDimensions(new DefaultDimensionSpec("d1", "_d0")) + .setAggregatorSpecs( + new LongMaxAggregatorFactory("_a0", "a0") + ) + .setGranularity(new AllGranularity()) + .setContext(queryContext) + .build() + ), + "j0.", + equalsCondition( + DruidExpression.fromColumn("country"), + DruidExpression.fromColumn("j0._d0") + ), + JoinType.LEFT, + selector("city", "B", null) + ) + ) + .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) + .setDimensions(new DefaultDimensionSpec("country", "d0")) + .setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0")) + .setGranularity(new AllGranularity()) + .setContext(queryContext) + .build() + ), + ImmutableList.of( + new Object[]{"canada", 2L} + ) + ); + } + + private List getRawRows() + { + return ImmutableList.of( + toRow("2021-01-01T01:00:00Z", ImmutableMap.of("user", "alice", "country", "canada", "city", "A")), + toRow("2021-01-01T02:00:00Z", 
ImmutableMap.of("user", "alice", "country", "canada", "city", "B")), + toRow("2021-01-01T03:00:00Z", ImmutableMap.of("user", "bob", "country", "canada", "city", "A")), + toRow("2021-01-01T04:00:00Z", ImmutableMap.of("user", "alice", "country", "India", "city", "Y")), + toRow("2021-01-02T01:00:00Z", ImmutableMap.of("user", "alice", "country", "canada", "city", "A")), + toRow("2021-01-02T02:00:00Z", ImmutableMap.of("user", "bob", "country", "canada", "city", "A")), + toRow("2021-01-02T03:00:00Z", ImmutableMap.of("user", "foo", "country", "canada", "city", "B")), + toRow("2021-01-02T04:00:00Z", ImmutableMap.of("user", "bar", "country", "canada", "city", "B")), + toRow("2021-01-02T05:00:00Z", ImmutableMap.of("user", "alice", "country", "India", "city", "X")), + toRow("2021-01-02T06:00:00Z", ImmutableMap.of("user", "bob", "country", "India", "city", "X")), + toRow("2021-01-02T07:00:00Z", ImmutableMap.of("user", "foo", "country", "India", "city", "X")), + toRow("2021-01-03T01:00:00Z", ImmutableMap.of("user", "foo", "country", "USA", "city", "M")) + ); + } + + private MapBasedInputRow toRow(String time, Map event) + { + return new MapBasedInputRow(DateTimes.ISO_DATE_OPTIONAL_TIME.parse(time), DIMENSIONS, event); + } +} diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 0dc242a5206..c4279f22838 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -25,7 +25,6 @@ import com.google.common.collect.ImmutableMap; import junitparams.JUnitParamsRunner; import junitparams.Parameters; import org.apache.calcite.plan.RelOptPlanner; -import org.apache.druid.annotations.UsedByJUnitParamsRunner; import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.IAE; @@ -85,7 +84,6 @@ import org.apache.druid.query.filter.BoundDimFilter; import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.LikeDimFilter; -import org.apache.druid.query.filter.NotDimFilter; import org.apache.druid.query.filter.RegexDimFilter; import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.query.groupby.GroupByQuery; @@ -5365,14 +5363,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .dataSource( join( join( - new QueryDataSource( - newScanQueryBuilder().dataSource(CalciteTests.DATASOURCE1) - .intervals(querySegmentSpec(Filtration.eternity())) - .columns("dim1", "dim2") - .filters(selector("dim2", "a", null)) - .context(QUERY_CONTEXT_DEFAULT) - .build() - ), + new TableDataSource(CalciteTests.DATASOURCE1), new QueryDataSource( newScanQueryBuilder().dataSource(CalciteTests.DATASOURCE3) .intervals(querySegmentSpec(Filtration.eternity())) @@ -5382,7 +5373,8 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ), "j0.", "(\"dim2\" == \"j0.dim2\")", - JoinType.INNER + JoinType.INNER, + bound("dim2", "a", "a", false, false, null, null) ), new QueryDataSource( newScanQueryBuilder().dataSource(CalciteTests.DATASOURCE1) @@ -13227,15 +13219,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .builder() .setDataSource( join( - new QueryDataSource( - newScanQueryBuilder() - .dataSource(CalciteTests.DATASOURCE1) - .intervals(querySegmentSpec(Intervals.of("2001-01-02T00:00:00.000Z/146140482-04-24T15:36:27.903Z"))) - 
.columns("dim1") - .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) - .context(queryContext) - .build() - ), + new TableDataSource(CalciteTests.DATASOURCE1), new QueryDataSource( newScanQueryBuilder() .dataSource(CalciteTests.DATASOURCE1) @@ -13254,7 +13238,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ) ) .setGranularity(Granularities.ALL) - .setInterval(querySegmentSpec(Filtration.eternity())) + .setInterval(querySegmentSpec(Intervals.of("2001-01-02T00:00:00.000Z/146140482-04-24T15:36:27.903Z"))) .setDimFilter(selector("dim1", "def", null)) .setDimensions( dimensions( @@ -16037,24 +16021,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest newScanQueryBuilder() .dataSource( join( - new QueryDataSource( - newScanQueryBuilder() - .dataSource(CalciteTests.DATASOURCE1) - .intervals( - querySegmentSpec( - Intervals.utc( - DateTimes.of("1999-01-01").getMillis(), - JodaUtils.MAX_INSTANT - ) - ) - ) - .filters(new SelectorDimFilter("dim1", "10.1", null)) - .virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING)) - .columns(ImmutableList.of("__time", "v0")) - .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) - .context(queryContext) - .build() - ), + new TableDataSource(CalciteTests.DATASOURCE1), new QueryDataSource( newScanQueryBuilder() .dataSource(CalciteTests.DATASOURCE1) @@ -16074,14 +16041,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .build() ), "j0.", - equalsCondition(DruidExpression.fromColumn("v0"), DruidExpression.fromColumn("j0.v0")), - JoinType.LEFT + equalsCondition(DruidExpression.fromExpression("'10.1'"), DruidExpression.fromColumn("j0.v0")), + JoinType.LEFT, + selector("dim1", "10.1", null) ) ) - .intervals(querySegmentSpec(Filtration.eternity())) - .virtualColumns(expressionVirtualColumn("_v0", "\'10.1\'", ValueType.STRING)) - .columns("__time", "_v0") - .filters(new SelectorDimFilter("v0", "10.1", null)) + .intervals(querySegmentSpec( + Intervals.utc( + DateTimes.of("1999-01-01").getMillis(), + JodaUtils.MAX_INSTANT + ) + )) + .virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING)) + .columns("__time", "v0") .context(queryContext) .build() ), @@ -16106,17 +16078,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest newScanQueryBuilder() .dataSource( join( - new QueryDataSource( - newScanQueryBuilder() - .dataSource(CalciteTests.DATASOURCE1) - .intervals(querySegmentSpec(Filtration.eternity())) - .filters(new SelectorDimFilter("dim1", "10.1", null)) - .virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING)) - .columns(ImmutableList.of("__time", "v0")) - .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) - .context(queryContext) - .build() - ), + new TableDataSource(CalciteTests.DATASOURCE1), new QueryDataSource( newScanQueryBuilder() .dataSource(CalciteTests.DATASOURCE1) @@ -16128,14 +16090,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .build() ), "j0.", - equalsCondition(DruidExpression.fromColumn("v0"), DruidExpression.fromColumn("j0.dim1")), - JoinType.LEFT + equalsCondition(DruidExpression.fromExpression("'10.1'"), DruidExpression.fromColumn("j0.dim1")), + JoinType.LEFT, + selector("dim1", "10.1", null) ) ) .intervals(querySegmentSpec(Filtration.eternity())) - .virtualColumns(expressionVirtualColumn("_v0", "\'10.1\'", ValueType.STRING)) - .columns("__time", "_v0") - .filters(new SelectorDimFilter("v0", "10.1", null)) + .virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", 
ValueType.STRING)) + .columns("__time", "v0") .context(queryContext) .build() ), @@ -16160,17 +16122,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest newScanQueryBuilder() .dataSource( join( - new QueryDataSource( - newScanQueryBuilder() - .dataSource(CalciteTests.DATASOURCE1) - .intervals(querySegmentSpec(Filtration.eternity())) - .filters(new SelectorDimFilter("dim1", "10.1", null)) - .virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING)) - .columns(ImmutableList.of("__time", "v0")) - .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) - .context(queryContext) - .build() - ), + new TableDataSource(CalciteTests.DATASOURCE1), new QueryDataSource( newScanQueryBuilder() .dataSource(CalciteTests.DATASOURCE1) @@ -16182,13 +16134,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .build() ), "j0.", - equalsCondition(DruidExpression.fromColumn("v0"), DruidExpression.fromColumn("j0.dim1")), - JoinType.LEFT + equalsCondition(DruidExpression.fromExpression("'10.1'"), DruidExpression.fromColumn("j0.dim1")), + JoinType.LEFT, + selector("dim1", "10.1", null) ) ) .intervals(querySegmentSpec(Filtration.eternity())) - .virtualColumns(expressionVirtualColumn("_v0", "\'10.1\'", ValueType.STRING)) - .columns("__time", "_v0") + .virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING)) + .columns("__time", "v0") .context(queryContext) .build() ), @@ -16213,17 +16166,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest newScanQueryBuilder() .dataSource( join( - new QueryDataSource( - newScanQueryBuilder() - .dataSource(CalciteTests.DATASOURCE1) - .intervals(querySegmentSpec(Filtration.eternity())) - .filters(new SelectorDimFilter("dim1", "10.1", null)) - .virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING)) - .columns(ImmutableList.of("__time", "v0")) - .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) - .context(queryContext) - .build() - ), + new TableDataSource(CalciteTests.DATASOURCE1), new QueryDataSource( newScanQueryBuilder() .dataSource(CalciteTests.DATASOURCE1) @@ -16235,14 +16178,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .build() ), "j0.", - equalsCondition(DruidExpression.fromColumn("v0"), DruidExpression.fromColumn("j0.dim1")), - JoinType.INNER + equalsCondition(DruidExpression.fromExpression("'10.1'"), DruidExpression.fromColumn("j0.dim1")), + JoinType.INNER, + selector("dim1", "10.1", null) ) ) .intervals(querySegmentSpec(Filtration.eternity())) - .virtualColumns(expressionVirtualColumn("_v0", "\'10.1\'", ValueType.STRING)) - .columns("__time", "_v0") - .filters(new NotDimFilter(new SelectorDimFilter("v0", null, null))) + .virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING)) + .columns("__time", "v0") .context(queryContext) .build() ), @@ -16267,17 +16210,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest newScanQueryBuilder() .dataSource( join( - new QueryDataSource( - newScanQueryBuilder() - .dataSource(CalciteTests.DATASOURCE1) - .intervals(querySegmentSpec(Filtration.eternity())) - .filters(new SelectorDimFilter("dim1", "10.1", null)) - .virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING)) - .columns(ImmutableList.of("__time", "v0")) - .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) - .context(queryContext) - .build() - ), + new TableDataSource(CalciteTests.DATASOURCE1), new QueryDataSource( newScanQueryBuilder() .dataSource(CalciteTests.DATASOURCE1) @@ 
-16289,13 +16222,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .build() ), "j0.", - equalsCondition(DruidExpression.fromColumn("v0"), DruidExpression.fromColumn("j0.dim1")), - JoinType.INNER + equalsCondition(DruidExpression.fromExpression("'10.1'"), DruidExpression.fromColumn("j0.dim1")), + JoinType.INNER, + selector("dim1", "10.1", null) ) ) .intervals(querySegmentSpec(Filtration.eternity())) - .virtualColumns(expressionVirtualColumn("_v0", "\'10.1\'", ValueType.STRING)) - .columns("__time", "_v0") + .virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING)) + .columns("__time", "v0") .context(queryContext) .build() ), @@ -16710,38 +16644,4 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ); } - /** - * This is a provider of query contexts that should be used by join tests. - * It tests various configs that can be passed to join queries. All the configs provided by this provider should - * have the join query engine return the same results. - */ - public static class QueryContextForJoinProvider - { - @UsedByJUnitParamsRunner - public static Object[] provideQueryContexts() - { - return new Object[] { - // default behavior - QUERY_CONTEXT_DEFAULT, - // filter value re-writes enabled - new ImmutableMap.Builder() - .putAll(QUERY_CONTEXT_DEFAULT) - .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, true) - .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, true) - .build(), - // rewrite values enabled but filter re-writes disabled. - // This should be drive the same behavior as the previous config - new ImmutableMap.Builder() - .putAll(QUERY_CONTEXT_DEFAULT) - .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, true) - .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, false) - .build(), - // filter re-writes disabled - new ImmutableMap.Builder() - .putAll(QUERY_CONTEXT_DEFAULT) - .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, false) - .build(), - }; - } - } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/rel/DruidQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/rel/DruidQueryTest.java new file mode 100644 index 00000000000..fe4a41a3684 --- /dev/null +++ b/sql/src/test/java/org/apache/druid/sql/calcite/rel/DruidQueryTest.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.sql.calcite.rel; + +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.Pair; +import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.query.DataSource; +import org.apache.druid.query.JoinDataSource; +import org.apache.druid.query.TableDataSource; +import org.apache.druid.query.filter.AndDimFilter; +import org.apache.druid.query.filter.BoundDimFilter; +import org.apache.druid.query.filter.DimFilter; +import org.apache.druid.query.filter.SelectorDimFilter; +import org.apache.druid.query.ordering.StringComparators; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.join.JoinType; +import org.apache.druid.sql.calcite.filtration.Filtration; +import org.joda.time.Interval; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Collections; + +public class DruidQueryTest +{ + + static { + NullHandling.initializeForTests(); + } + + private final DimFilter selectorFilter = new SelectorDimFilter("column", "value", null); + private final DimFilter otherFilter = new SelectorDimFilter("column_2", "value_2", null); + private final DimFilter filterWithInterval = new AndDimFilter( + selectorFilter, + new BoundDimFilter("__time", "100", "200", false, true, null, null, StringComparators.NUMERIC) + ); + + @Test + public void test_filtration_noJoinAndInterval() + { + DataSource dataSource = new TableDataSource("test"); + Pair pair = DruidQuery.getFiltration( + dataSource, + selectorFilter, + VirtualColumnRegistry.create(RowSignature.empty()) + ); + verify(pair, dataSource, selectorFilter, Intervals.ETERNITY); + } + + @Test + public void test_filtration_intervalInQueryFilter() + { + DataSource dataSource = new TableDataSource("test"); + Pair pair = DruidQuery.getFiltration( + dataSource, + filterWithInterval, + VirtualColumnRegistry.create(RowSignature.empty()) + ); + verify(pair, dataSource, selectorFilter, Intervals.utc(100, 200)); + } + + @Test + public void test_filtration_joinDataSource_intervalInQueryFilter() + { + DataSource dataSource = join(JoinType.INNER, otherFilter); + Pair pair = DruidQuery.getFiltration( + dataSource, + filterWithInterval, + VirtualColumnRegistry.create(RowSignature.empty()) + ); + verify(pair, dataSource, selectorFilter, Intervals.utc(100, 200)); + } + + @Test + public void test_filtration_joinDataSource_intervalInBaseTableFilter_inner() + { + DataSource dataSource = join(JoinType.INNER, filterWithInterval); + DataSource expectedDataSource = join(JoinType.INNER, selectorFilter); + Pair pair = DruidQuery.getFiltration( + dataSource, + otherFilter, + VirtualColumnRegistry.create(RowSignature.empty()) + ); + verify(pair, expectedDataSource, otherFilter, Intervals.utc(100, 200)); + } + + @Test + public void test_filtration_joinDataSource_intervalInBaseTableFilter_left() + { + DataSource dataSource = join(JoinType.LEFT, filterWithInterval); + DataSource expectedDataSource = join(JoinType.LEFT, selectorFilter); + Pair pair = DruidQuery.getFiltration( + dataSource, + otherFilter, + VirtualColumnRegistry.create(RowSignature.empty()) + ); + verify(pair, expectedDataSource, otherFilter, Intervals.utc(100, 200)); + } + + @Test + public void test_filtration_joinDataSource_intervalInBaseTableFilter_right() + { + DataSource dataSource = join(JoinType.RIGHT, filterWithInterval); + DataSource expectedDataSource = join(JoinType.RIGHT, selectorFilter); + Pair pair = DruidQuery.getFiltration( + 
dataSource, + otherFilter, + VirtualColumnRegistry.create(RowSignature.empty()) + ); + verify(pair, expectedDataSource, otherFilter, Intervals.utc(100, 200)); + } + + @Test + public void test_filtration_joinDataSource_intervalInBaseTableFilter_full() + { + DataSource dataSource = join(JoinType.FULL, filterWithInterval); + DataSource expectedDataSource = join(JoinType.FULL, selectorFilter); + Pair pair = DruidQuery.getFiltration( + dataSource, + otherFilter, + VirtualColumnRegistry.create(RowSignature.empty()) + ); + verify(pair, expectedDataSource, otherFilter, Intervals.utc(100, 200)); + } + + @Test + public void test_filtration_intervalsInBothFilters() + { + DataSource dataSource = join(JoinType.INNER, filterWithInterval); + DataSource expectedDataSource = join(JoinType.INNER, selectorFilter); + DimFilter queryFilter = new AndDimFilter( + otherFilter, + new BoundDimFilter("__time", "150", "250", false, true, null, null, StringComparators.NUMERIC) + + ); + Pair pair = DruidQuery.getFiltration( + dataSource, + queryFilter, + VirtualColumnRegistry.create(RowSignature.empty()) + ); + verify(pair, expectedDataSource, otherFilter, Intervals.utc(150, 200)); + } + + private JoinDataSource join(JoinType joinType, DimFilter filter) + { + return JoinDataSource.create( + new TableDataSource("left"), + new TableDataSource("right"), + "r.", + "c == \"r.c\"", + joinType, + filter, + ExprMacroTable.nil() + ); + } + + private void verify( + Pair pair, + DataSource dataSource, + DimFilter columnFilter, + Interval interval + ) + { + Assert.assertEquals(dataSource, pair.lhs); + Assert.assertEquals("dim-filter: " + pair.rhs.getDimFilter(), columnFilter, pair.rhs.getDimFilter()); + Assert.assertEquals(Collections.singletonList(interval), pair.rhs.getIntervals()); + } +}
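
A minimal usage sketch of the new leftFilter parameter, mirroring the join helper in DruidQueryTest above. The table names, the "r." prefix, and the filter value are illustrative assumptions, not taken from the patch:

import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.JoinDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.filter.SelectorDimFilter;
import org.apache.druid.segment.join.JoinType;

// A non-null leftFilter is accepted only when the left side is direct table access
// (TableDataSource); JoinDataSource's constructor rejects it for any other left input.
JoinDataSource joined = JoinDataSource.create(
    new TableDataSource("left"),                  // base (left) table
    new TableDataSource("right"),                 // right side of the join
    "r.",                                         // prefix applied to right-side columns
    "c == \"r.c\"",                               // join condition expression
    JoinType.INNER,
    new SelectorDimFilter("country", "IN", null), // leftFilter: pushed down to the base table
    ExprMacroTable.nil()
);

At query time, DataSourceAnalysis.getJoinBaseTableFilter() exposes this filter, and the segment walkers changed above hand it (via Filters::toFilter) to JoinableFactoryWrapper.createSegmentMapFn so it is applied while scanning the base table.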