diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java index d07586ec2fb..ec6f8c7ee30 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java @@ -29,12 +29,14 @@ import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; import javax.annotation.Nullable; +import java.util.ArrayDeque; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Queue; import java.util.stream.Collectors; /** @@ -198,13 +200,15 @@ public class VirtualColumnRegistry public void visitAllSubExpressions(DruidExpression.DruidExpressionShuttle shuttle) { - for (Map.Entry entry : virtualColumnsByName.entrySet()) { + final Queue> toVisit = new ArrayDeque<>(virtualColumnsByName.entrySet()); + while (!toVisit.isEmpty()) { + final Map.Entry entry = toVisit.poll(); final String key = entry.getKey(); final ExpressionAndTypeHint wrapped = entry.getValue(); - virtualColumnsByExpression.remove(wrapped); final List newArgs = shuttle.visitAll(wrapped.getExpression().getArguments()); final ExpressionAndTypeHint newWrapped = wrap(wrapped.getExpression().withArguments(newArgs), wrapped.getTypeHint()); virtualColumnsByName.put(key, newWrapped); + virtualColumnsByExpression.remove(wrapped); virtualColumnsByExpression.put(newWrapped, key); } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java index 28fd4d463f6..a62402cba40 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java @@ -1428,6 +1428,67 @@ public class CalciteMultiValueStringQueryTest extends BaseCalciteQueryTest ); } + @Test + public void testMultiValueListFilterComposedMultipleExpressions() throws Exception + { + // Cannot vectorize due to usage of expressions. + cannotVectorize(); + + testQuery( + "SELECT MV_LENGTH(MV_FILTER_ONLY(dim3, ARRAY['b'])), MV_LENGTH(dim3), SUM(cnt) FROM druid.numfoo GROUP BY 1,2 ORDER BY 3 DESC", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + expressionVirtualColumn( + "v0", + "array_length(\"v2\")", + ColumnType.LONG + ), + expressionVirtualColumn( + "v1", + "array_length(\"dim3\")", + ColumnType.LONG + ), + new ListFilteredVirtualColumn( + "v2", + DefaultDimensionSpec.of("dim3"), + ImmutableSet.of("b"), + true + ) + ) + .setDimensions( + dimensions( + new DefaultDimensionSpec("v0", "_d0", ColumnType.LONG), + new DefaultDimensionSpec("v1", "_d1", ColumnType.LONG) + ) + ) + .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setLimitSpec(new DefaultLimitSpec( + ImmutableList.of(new OrderByColumnSpec( + "a0", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + )), + Integer.MAX_VALUE + )) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + useDefault ? ImmutableList.of( + new Object[]{0, 0, 3L}, + new Object[]{1, 2, 2L}, + new Object[]{0, 1, 1L} + ) : ImmutableList.of( + new Object[]{null, null, 2L}, + new Object[]{null, 1, 2L}, + new Object[]{1, 2, 2L} + ) + ); + } + @Test public void testFilterOnMultiValueListFilterNoMatch() throws Exception {