diff --git a/processing/src/main/java/org/apache/druid/segment/UnnestStorageAdapter.java b/processing/src/main/java/org/apache/druid/segment/UnnestStorageAdapter.java index 74ea48ad4eb..00a119388a6 100644 --- a/processing/src/main/java/org/apache/druid/segment/UnnestStorageAdapter.java +++ b/processing/src/main/java/org/apache/druid/segment/UnnestStorageAdapter.java @@ -173,7 +173,7 @@ public class UnnestStorageAdapter implements StorageAdapter @Nullable public Filter getUnnestFilter() { - return unnestFilter.toFilter(); + return unnestFilter == null ? null : unnestFilter.toFilter(); } @Override @@ -390,6 +390,9 @@ public class UnnestStorageAdapter implements StorageAdapter } // add the entire query filter to unnest filter to be used in Value matcher filterSplitter.addPostFilterWithPreFilterIfRewritePossible(queryFilter, true); + } else { + // case where the outer filter has reference to the outputcolumn + filterSplitter.addPostFilterWithPreFilterIfRewritePossible(queryFilter, false); } } else { // normal case without any filter on unnested column diff --git a/processing/src/test/java/org/apache/druid/segment/UnnestStorageAdapterTest.java b/processing/src/test/java/org/apache/druid/segment/UnnestStorageAdapterTest.java index f34dcc2b6b2..757a60d59e6 100644 --- a/processing/src/test/java/org/apache/druid/segment/UnnestStorageAdapterTest.java +++ b/processing/src/test/java/org/apache/druid/segment/UnnestStorageAdapterTest.java @@ -63,7 +63,6 @@ public class UnnestStorageAdapterTest extends InitializedNullHandlingTest private static IncrementalIndexStorageAdapter INCREMENTAL_INDEX_STORAGE_ADAPTER; private static UnnestStorageAdapter UNNEST_STORAGE_ADAPTER; private static UnnestStorageAdapter UNNEST_STORAGE_ADAPTER1; - private static UnnestStorageAdapter UNNEST_STORAGE_ADAPTER2; private static List ADAPTERS; private static String COLUMNNAME = "multi-string1"; private static String OUTPUT_COLUMN_NAME = "unnested-multi-string1"; @@ -101,13 +100,6 @@ public class UnnestStorageAdapterTest extends InitializedNullHandlingTest null ); - UNNEST_STORAGE_ADAPTER2 = new UnnestStorageAdapter( - INCREMENTAL_INDEX_STORAGE_ADAPTER, - new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME, "\"" + COLUMNNAME + "\"", null, ExprMacroTable.nil()), - new SelectorDimFilter(OUTPUT_COLUMN_NAME, "1", null) - ); - - ADAPTERS = ImmutableList.of( UNNEST_STORAGE_ADAPTER, UNNEST_STORAGE_ADAPTER1 @@ -344,14 +336,55 @@ public class UnnestStorageAdapterTest extends InitializedNullHandlingTest final Filter postFilter = ((PostJoinCursor) cursor).getPostJoinFilter(); Assert.assertEquals(unnestStorageAdapter.getUnnestFilter(), postFilter); - ColumnSelectorFactory factory = cursor.getColumnSelectorFactory(); - DimensionSelector dimSelector = factory.makeDimensionSelector(DefaultDimensionSpec.of(OUTPUT_COLUMN_NAME)); int count = 0; while (!cursor.isDone()) { - Object dimSelectorVal = dimSelector.getObject(); - if (dimSelectorVal == null) { - Assert.assertNull(dimSelectorVal); - } + cursor.advance(); + count++; + } + Assert.assertEquals(1, count); + return null; + }); + } + + + @Test + public void test_pushdown_filters_unnested_dimension_outside() + { + final UnnestStorageAdapter unnestStorageAdapter = new UnnestStorageAdapter( + new TestStorageAdapter(INCREMENTAL_INDEX), + new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME, "\"" + COLUMNNAME + "\"", null, ExprMacroTable.nil()), + null + ); + + final VirtualColumn vc = unnestStorageAdapter.getUnnestColumn(); + + final String inputColumn = unnestStorageAdapter.getUnnestInputIfDirectAccess(vc); + + final Filter expectedPushDownFilter = + new SelectorDimFilter(inputColumn, "1", null).toFilter(); + + + final Filter queryFilter = new SelectorDimFilter(OUTPUT_COLUMN_NAME, "1", null).toFilter(); + final Sequence cursorSequence = unnestStorageAdapter.makeCursors( + queryFilter, + unnestStorageAdapter.getInterval(), + VirtualColumns.EMPTY, + Granularities.ALL, + false, + null + ); + + final TestStorageAdapter base = (TestStorageAdapter) unnestStorageAdapter.getBaseAdapter(); + final Filter pushDownFilter = base.getPushDownFilter(); + + Assert.assertEquals(expectedPushDownFilter, pushDownFilter); + cursorSequence.accumulate(null, (accumulated, cursor) -> { + Assert.assertEquals(cursor.getClass(), PostJoinCursor.class); + final Filter postFilter = ((PostJoinCursor) cursor).getPostJoinFilter(); + Assert.assertEquals(queryFilter, postFilter); + + int count = 0; + while (!cursor.isDone()) { cursor.advance(); count++; } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidFilterUnnestRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidFilterUnnestRule.java index c732caaa2ae..93bf287cd8f 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidFilterUnnestRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidFilterUnnestRule.java @@ -24,6 +24,7 @@ import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.core.Project; import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.sql.calcite.rel.DruidUnnestRel; public class DruidFilterUnnestRule extends RelOptRule @@ -94,9 +95,15 @@ public class DruidFilterUnnestRule extends RelOptRule public boolean matches(RelOptRuleCall call) { final Project rightP = call.rel(0); - final SqlKind rightProjectKind = rightP.getChildExps().get(0).getKind(); - // allow rule to trigger only if there's a string CAST or numeric literal cast - return rightP.getProjects().size() == 1 && (rightProjectKind == SqlKind.CAST || rightProjectKind == SqlKind.LITERAL); + if (rightP.getChildExps().size() > 0) { + final SqlKind rightProjectKind = rightP.getChildExps().get(0).getKind(); + final SqlTypeName projectType = rightP.getChildExps().get(0).getType().getSqlTypeName(); + final SqlTypeName unnestDataType = call.rel(1).getRowType().getFieldList().get(0).getType().getSqlTypeName(); + // allow rule to trigger only if project involves a cast on the same row type + return rightP.getProjects().size() == 1 && ((rightProjectKind == SqlKind.CAST || rightProjectKind == SqlKind.LITERAL) + && projectType == unnestDataType); + } + return false; } @Override diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java index 8bba413f04c..da9eb754c43 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java @@ -35,6 +35,7 @@ import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.UnnestDataSource; import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; @@ -4101,4 +4102,178 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest ) ); } + + @Test + public void testUnnestWithCountOnColumn() + { + skipVectorize(); + cannotVectorize(); + testQuery( + "SELECT count(*) d3 FROM druid.numfoo, UNNEST(MV_TO_ARRAY(dim3)) as unnested (d3)", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(UnnestDataSource.create( + new TableDataSource(CalciteTests.DATASOURCE3), + expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING), + null + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .context(QUERY_CONTEXT_UNNEST) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .build() + ), + ImmutableList.of( + new Object[]{8L} + ) + ); + } + + @Test + public void testUnnestWithGroupByHavingSelector() + { + skipVectorize(); + cannotVectorize(); + testQuery( + "SELECT d3, COUNT(*) FROM druid.numfoo, UNNEST(MV_TO_ARRAY(dim3)) AS unnested(d3) GROUP BY d3 HAVING d3='b'", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(UnnestDataSource.create( + new TableDataSource(CalciteTests.DATASOURCE3), + expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING), + null + )) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setContext(QUERY_CONTEXT_UNNEST) + .setDimensions(new DefaultDimensionSpec("j0.unnest", "_d0", ColumnType.STRING)) + .setGranularity(Granularities.ALL) + .setDimFilter(selector("j0.unnest", "b", null)) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .setContext(QUERY_CONTEXT_UNNEST) + .build() + ), + ImmutableList.of( + new Object[]{"b", 2L} + ) + ); + } + + @Test + public void testUnnestWithSumOnUnnestedVirtualColumn() + { + skipVectorize(); + cannotVectorize(); + testQuery( + "select sum(c) col from druid.numfoo, unnest(ARRAY[m1,m2]) as u(c)", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(UnnestDataSource.create( + new TableDataSource(CalciteTests.DATASOURCE3), + expressionVirtualColumn("j0.unnest", "array(\"m1\",\"m2\")", ColumnType.FLOAT_ARRAY), + null + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .context(QUERY_CONTEXT_UNNEST) + .aggregators(aggregators(new DoubleSumAggregatorFactory("a0", "j0.unnest"))) + .build() + ), + ImmutableList.of( + new Object[]{42.0} + ) + ); + } + + @Test + public void testUnnestWithSumOnUnnestedColumn() + { + skipVectorize(); + cannotVectorize(); + testQuery( + "select sum(c) col from druid.numfoo, unnest(mv_to_array(dim3)) as u(c)", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(UnnestDataSource.create( + new TableDataSource(CalciteTests.DATASOURCE3), + expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING), + null + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns(expressionVirtualColumn("v0", "CAST(\"j0.unnest\", 'DOUBLE')", ColumnType.DOUBLE)) + .context(QUERY_CONTEXT_UNNEST) + .aggregators(aggregators(new DoubleSumAggregatorFactory("a0", "v0"))) + .build() + ), + useDefault ? + ImmutableList.of( + new Object[]{0.0} + ) : + ImmutableList.of( + new Object[]{null} + ) + ); + } + + @Test + public void testUnnestWithGroupByHavingWithWhereOnAggCol() + { + skipVectorize(); + cannotVectorize(); + testQuery( + "SELECT d3, COUNT(*) FROM druid.numfoo, UNNEST(MV_TO_ARRAY(dim3)) AS unnested(d3) WHERE d3 IN ('a','c') GROUP BY d3 HAVING COUNT(*) = 1", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(UnnestDataSource.create( + new TableDataSource(CalciteTests.DATASOURCE3), + expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING), + new InDimFilter("j0.unnest", ImmutableSet.of("a", "c"), null) + )) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setContext(QUERY_CONTEXT_UNNEST) + .setDimensions(new DefaultDimensionSpec("j0.unnest", "_d0", ColumnType.STRING)) + .setGranularity(Granularities.ALL) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .setHavingSpec(new DimFilterHavingSpec(selector("a0", "1", null), true)) + .setContext(QUERY_CONTEXT_UNNEST) + .build() + ), + ImmutableList.of( + new Object[]{"a", 1L}, + new Object[]{"c", 1L} + ) + ); + } + + @Test + public void testUnnestWithGroupByHavingWithWhereOnUnnestCol() + { + skipVectorize(); + cannotVectorize(); + testQuery( + "SELECT d3, COUNT(*) FROM druid.numfoo, UNNEST(MV_TO_ARRAY(dim3)) AS unnested(d3) WHERE d3 IN ('a','c') GROUP BY d3 HAVING d3='a'", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(UnnestDataSource.create( + new TableDataSource(CalciteTests.DATASOURCE3), + expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING), + new InDimFilter("j0.unnest", ImmutableSet.of("a", "c"), null) + )) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setContext(QUERY_CONTEXT_UNNEST) + .setDimensions(new DefaultDimensionSpec("j0.unnest", "_d0", ColumnType.STRING)) + .setGranularity(Granularities.ALL) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .setDimFilter(selector("j0.unnest", "a", null)) + .setContext(QUERY_CONTEXT_UNNEST) + .build() + ), + ImmutableList.of( + new Object[]{"a", 1L} + ) + ); + } }