Fixing regression issues on unnest (#13976)

* select sum(c) on an unnested column no longer returns a 'Type mismatch' error and works properly
* Makes sure an inner join query works properly
* HAVING on an unnested column with a GROUP BY now works correctly
* count(*) on an unnested query now works correctly (the covered query shapes are listed below)
Soumyava 2023-03-30 20:36:43 -07:00 committed by GitHub
parent eb31207402
commit 1eeecf5fb2
4 changed files with 236 additions and 18 deletions
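
For reference, the query shapes exercised by the new tests in CalciteArraysQueryTest (further down in this diff) are:

SELECT count(*) d3 FROM druid.numfoo, UNNEST(MV_TO_ARRAY(dim3)) as unnested (d3)
SELECT d3, COUNT(*) FROM druid.numfoo, UNNEST(MV_TO_ARRAY(dim3)) AS unnested(d3) GROUP BY d3 HAVING d3='b'
select sum(c) col from druid.numfoo, unnest(ARRAY[m1,m2]) as u(c)
select sum(c) col from druid.numfoo, unnest(mv_to_array(dim3)) as u(c)
SELECT d3, COUNT(*) FROM druid.numfoo, UNNEST(MV_TO_ARRAY(dim3)) AS unnested(d3) WHERE d3 IN ('a','c') GROUP BY d3 HAVING COUNT(*) = 1
SELECT d3, COUNT(*) FROM druid.numfoo, UNNEST(MV_TO_ARRAY(dim3)) AS unnested(d3) WHERE d3 IN ('a','c') GROUP BY d3 HAVING d3='a'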

File: UnnestStorageAdapter.java

@@ -173,7 +173,7 @@ public class UnnestStorageAdapter implements StorageAdapter
@Nullable
public Filter getUnnestFilter()
{
return unnestFilter.toFilter();
return unnestFilter == null ? null : unnestFilter.toFilter();
}
@Override
@@ -390,6 +390,9 @@ public class UnnestStorageAdapter implements StorageAdapter
}
// add the entire query filter to unnest filter to be used in Value matcher
filterSplitter.addPostFilterWithPreFilterIfRewritePossible(queryFilter, true);
} else {
// case where the outer filter has reference to the outputcolumn
filterSplitter.addPostFilterWithPreFilterIfRewritePossible(queryFilter, false);
}
} else {
// normal case without any filter on unnested column

File: UnnestStorageAdapterTest.java

@@ -63,7 +63,6 @@ public class UnnestStorageAdapterTest extends InitializedNullHandlingTest
private static IncrementalIndexStorageAdapter INCREMENTAL_INDEX_STORAGE_ADAPTER;
private static UnnestStorageAdapter UNNEST_STORAGE_ADAPTER;
private static UnnestStorageAdapter UNNEST_STORAGE_ADAPTER1;
private static UnnestStorageAdapter UNNEST_STORAGE_ADAPTER2;
private static List<StorageAdapter> ADAPTERS;
private static String COLUMNNAME = "multi-string1";
private static String OUTPUT_COLUMN_NAME = "unnested-multi-string1";
@@ -101,13 +100,6 @@ public class UnnestStorageAdapterTest extends InitializedNullHandlingTest
null
);
UNNEST_STORAGE_ADAPTER2 = new UnnestStorageAdapter(
INCREMENTAL_INDEX_STORAGE_ADAPTER,
new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME, "\"" + COLUMNNAME + "\"", null, ExprMacroTable.nil()),
new SelectorDimFilter(OUTPUT_COLUMN_NAME, "1", null)
);
ADAPTERS = ImmutableList.of(
UNNEST_STORAGE_ADAPTER,
UNNEST_STORAGE_ADAPTER1
@@ -344,14 +336,55 @@ public class UnnestStorageAdapterTest extends InitializedNullHandlingTest
final Filter postFilter = ((PostJoinCursor) cursor).getPostJoinFilter();
Assert.assertEquals(unnestStorageAdapter.getUnnestFilter(), postFilter);
ColumnSelectorFactory factory = cursor.getColumnSelectorFactory();
DimensionSelector dimSelector = factory.makeDimensionSelector(DefaultDimensionSpec.of(OUTPUT_COLUMN_NAME));
int count = 0;
while (!cursor.isDone()) {
Object dimSelectorVal = dimSelector.getObject();
if (dimSelectorVal == null) {
Assert.assertNull(dimSelectorVal);
}
cursor.advance();
count++;
}
Assert.assertEquals(1, count);
return null;
});
}
@Test
public void test_pushdown_filters_unnested_dimension_outside()
{
final UnnestStorageAdapter unnestStorageAdapter = new UnnestStorageAdapter(
new TestStorageAdapter(INCREMENTAL_INDEX),
new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME, "\"" + COLUMNNAME + "\"", null, ExprMacroTable.nil()),
null
);
final VirtualColumn vc = unnestStorageAdapter.getUnnestColumn();
final String inputColumn = unnestStorageAdapter.getUnnestInputIfDirectAccess(vc);
final Filter expectedPushDownFilter =
new SelectorDimFilter(inputColumn, "1", null).toFilter();
final Filter queryFilter = new SelectorDimFilter(OUTPUT_COLUMN_NAME, "1", null).toFilter();
final Sequence<Cursor> cursorSequence = unnestStorageAdapter.makeCursors(
queryFilter,
unnestStorageAdapter.getInterval(),
VirtualColumns.EMPTY,
Granularities.ALL,
false,
null
);
final TestStorageAdapter base = (TestStorageAdapter) unnestStorageAdapter.getBaseAdapter();
final Filter pushDownFilter = base.getPushDownFilter();
Assert.assertEquals(expectedPushDownFilter, pushDownFilter);
cursorSequence.accumulate(null, (accumulated, cursor) -> {
Assert.assertEquals(cursor.getClass(), PostJoinCursor.class);
final Filter postFilter = ((PostJoinCursor) cursor).getPostJoinFilter();
Assert.assertEquals(queryFilter, postFilter);
int count = 0;
while (!cursor.isDone()) {
cursor.advance();
count++;
}

File: DruidFilterUnnestRule.java

@@ -24,6 +24,7 @@ import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.sql.calcite.rel.DruidUnnestRel;
public class DruidFilterUnnestRule extends RelOptRule
@@ -94,9 +95,15 @@ public class DruidFilterUnnestRule extends RelOptRule
public boolean matches(RelOptRuleCall call)
{
final Project rightP = call.rel(0);
final SqlKind rightProjectKind = rightP.getChildExps().get(0).getKind();
// allow rule to trigger only if there's a string CAST or numeric literal cast
return rightP.getProjects().size() == 1 && (rightProjectKind == SqlKind.CAST || rightProjectKind == SqlKind.LITERAL);
if (rightP.getChildExps().size() > 0) {
final SqlKind rightProjectKind = rightP.getChildExps().get(0).getKind();
final SqlTypeName projectType = rightP.getChildExps().get(0).getType().getSqlTypeName();
final SqlTypeName unnestDataType = call.rel(1).getRowType().getFieldList().get(0).getType().getSqlTypeName();
// allow rule to trigger only if project involves a cast on the same row type
return rightP.getProjects().size() == 1 && ((rightProjectKind == SqlKind.CAST || rightProjectKind == SqlKind.LITERAL)
&& projectType == unnestDataType);
}
return false;
}
@Override
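
For context, this type check guards the sum-on-an-unnested-string case from the commit message (see testUnnestWithSumOnUnnestedColumn further down): in

select sum(c) col from druid.numfoo, unnest(mv_to_array(dim3)) as u(c)

the single project is a CAST of the VARCHAR unnest output to DOUBLE, so the project type and the unnest row type differ, the rule does not fire, and the cast is instead planned as the virtual column CAST("j0.unnest", 'DOUBLE'), as asserted in the expected query of that test.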

File: CalciteArraysQueryTest.java

@@ -35,6 +35,7 @@ import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.UnnestDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
@@ -4101,4 +4102,178 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest
)
);
}
@Test
public void testUnnestWithCountOnColumn()
{
skipVectorize();
cannotVectorize();
testQuery(
"SELECT count(*) d3 FROM druid.numfoo, UNNEST(MV_TO_ARRAY(dim3)) as unnested (d3)",
QUERY_CONTEXT_UNNEST,
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(UnnestDataSource.create(
new TableDataSource(CalciteTests.DATASOURCE3),
expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING),
null
))
.intervals(querySegmentSpec(Filtration.eternity()))
.context(QUERY_CONTEXT_UNNEST)
.aggregators(aggregators(new CountAggregatorFactory("a0")))
.build()
),
ImmutableList.of(
new Object[]{8L}
)
);
}
@Test
public void testUnnestWithGroupByHavingSelector()
{
skipVectorize();
cannotVectorize();
testQuery(
"SELECT d3, COUNT(*) FROM druid.numfoo, UNNEST(MV_TO_ARRAY(dim3)) AS unnested(d3) GROUP BY d3 HAVING d3='b'",
QUERY_CONTEXT_UNNEST,
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(UnnestDataSource.create(
new TableDataSource(CalciteTests.DATASOURCE3),
expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING),
null
))
.setInterval(querySegmentSpec(Filtration.eternity()))
.setContext(QUERY_CONTEXT_UNNEST)
.setDimensions(new DefaultDimensionSpec("j0.unnest", "_d0", ColumnType.STRING))
.setGranularity(Granularities.ALL)
.setDimFilter(selector("j0.unnest", "b", null))
.setAggregatorSpecs(new CountAggregatorFactory("a0"))
.setContext(QUERY_CONTEXT_UNNEST)
.build()
),
ImmutableList.of(
new Object[]{"b", 2L}
)
);
}
@Test
public void testUnnestWithSumOnUnnestedVirtualColumn()
{
skipVectorize();
cannotVectorize();
testQuery(
"select sum(c) col from druid.numfoo, unnest(ARRAY[m1,m2]) as u(c)",
QUERY_CONTEXT_UNNEST,
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(UnnestDataSource.create(
new TableDataSource(CalciteTests.DATASOURCE3),
expressionVirtualColumn("j0.unnest", "array(\"m1\",\"m2\")", ColumnType.FLOAT_ARRAY),
null
))
.intervals(querySegmentSpec(Filtration.eternity()))
.context(QUERY_CONTEXT_UNNEST)
.aggregators(aggregators(new DoubleSumAggregatorFactory("a0", "j0.unnest")))
.build()
),
ImmutableList.of(
new Object[]{42.0}
)
);
}
@Test
public void testUnnestWithSumOnUnnestedColumn()
{
skipVectorize();
cannotVectorize();
testQuery(
"select sum(c) col from druid.numfoo, unnest(mv_to_array(dim3)) as u(c)",
QUERY_CONTEXT_UNNEST,
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(UnnestDataSource.create(
new TableDataSource(CalciteTests.DATASOURCE3),
expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING),
null
))
.intervals(querySegmentSpec(Filtration.eternity()))
.virtualColumns(expressionVirtualColumn("v0", "CAST(\"j0.unnest\", 'DOUBLE')", ColumnType.DOUBLE))
.context(QUERY_CONTEXT_UNNEST)
.aggregators(aggregators(new DoubleSumAggregatorFactory("a0", "v0")))
.build()
),
useDefault ?
ImmutableList.of(
new Object[]{0.0}
) :
ImmutableList.of(
new Object[]{null}
)
);
}
@Test
public void testUnnestWithGroupByHavingWithWhereOnAggCol()
{
skipVectorize();
cannotVectorize();
testQuery(
"SELECT d3, COUNT(*) FROM druid.numfoo, UNNEST(MV_TO_ARRAY(dim3)) AS unnested(d3) WHERE d3 IN ('a','c') GROUP BY d3 HAVING COUNT(*) = 1",
QUERY_CONTEXT_UNNEST,
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(UnnestDataSource.create(
new TableDataSource(CalciteTests.DATASOURCE3),
expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING),
new InDimFilter("j0.unnest", ImmutableSet.of("a", "c"), null)
))
.setInterval(querySegmentSpec(Filtration.eternity()))
.setContext(QUERY_CONTEXT_UNNEST)
.setDimensions(new DefaultDimensionSpec("j0.unnest", "_d0", ColumnType.STRING))
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(new CountAggregatorFactory("a0"))
.setHavingSpec(new DimFilterHavingSpec(selector("a0", "1", null), true))
.setContext(QUERY_CONTEXT_UNNEST)
.build()
),
ImmutableList.of(
new Object[]{"a", 1L},
new Object[]{"c", 1L}
)
);
}
@Test
public void testUnnestWithGroupByHavingWithWhereOnUnnestCol()
{
skipVectorize();
cannotVectorize();
testQuery(
"SELECT d3, COUNT(*) FROM druid.numfoo, UNNEST(MV_TO_ARRAY(dim3)) AS unnested(d3) WHERE d3 IN ('a','c') GROUP BY d3 HAVING d3='a'",
QUERY_CONTEXT_UNNEST,
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(UnnestDataSource.create(
new TableDataSource(CalciteTests.DATASOURCE3),
expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING),
new InDimFilter("j0.unnest", ImmutableSet.of("a", "c"), null)
))
.setInterval(querySegmentSpec(Filtration.eternity()))
.setContext(QUERY_CONTEXT_UNNEST)
.setDimensions(new DefaultDimensionSpec("j0.unnest", "_d0", ColumnType.STRING))
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(new CountAggregatorFactory("a0"))
.setDimFilter(selector("j0.unnest", "a", null))
.setContext(QUERY_CONTEXT_UNNEST)
.build()
),
ImmutableList.of(
new Object[]{"a", 1L}
)
);
}
}