diff --git a/processing/src/main/java/org/apache/druid/math/expr/ApplyFunction.java b/processing/src/main/java/org/apache/druid/math/expr/ApplyFunction.java index 6403f2b28c0..e2b1537fb0a 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/ApplyFunction.java +++ b/processing/src/main/java/org/apache/druid/math/expr/ApplyFunction.java @@ -442,11 +442,15 @@ public interface ApplyFunction extends NamedFunction Object[] array = arrayEval.asArray(); if (array == null) { - return ExprEval.of(null); + return ExprEval.ofArray(arrayEval.asArrayType(), null); } SettableLambdaBinding lambdaBinding = new SettableLambdaBinding(arrayEval.elementType(), lambdaExpr, bindings); Object[] filtered = filter(arrayEval.asArray(), lambdaExpr, lambdaBinding).toArray(); + // return null array expr if nothing is left in filtered + if (filtered.length == 0) { + return ExprEval.ofArray(arrayEval.asArrayType(), null); + } return ExprEval.ofArray(arrayEval.asArrayType(), filtered); } diff --git a/processing/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java b/processing/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java index 272d219bd5b..2eb0aadf0bd 100644 --- a/processing/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java +++ b/processing/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java @@ -93,6 +93,10 @@ public class ApplyFunctionTest extends InitializedNullHandlingTest assertExpr("filter((x) -> x > 2, [1, 2, 3, 4, 5])", new Long[] {3L, 4L, 5L}); assertExpr("filter((x) -> x > 2, b)", new Long[] {3L, 4L, 5L}); + + String dummyNull = null; + assertExpr("filter((x) -> array_contains([], x), ['a', 'b'])", dummyNull); + assertExpr("filter((x) -> array_contains([], x), null)", dummyNull); } @Test diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MultiValueStringOperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MultiValueStringOperatorConversions.java index da07083774c..0325dfd16a9 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MultiValueStringOperatorConversions.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MultiValueStringOperatorConversions.java @@ -461,20 +461,6 @@ public class MultiValueStringOperatorConversions return null; } - Expr expr = plannerContext.parseExpression(druidExpressions.get(1).getExpression()); - // the right expression must be a literal array for this to work, since we need the values of the column - if (!expr.isLiteral()) { - return null; - } - Object[] lit = expr.eval(InputBindings.nilBindings()).asArray(); - if (lit == null || lit.length == 0) { - return null; - } - HashSet literals = Sets.newHashSetWithExpectedSize(lit.length); - for (Object o : lit) { - literals.add(Evals.asString(o)); - } - final DruidExpression.ExpressionGenerator builder = (args) -> { final StringBuilder expressionBuilder; if (isAllowList()) { @@ -490,7 +476,17 @@ public class MultiValueStringOperatorConversions return expressionBuilder.toString(); }; - if (druidExpressions.get(0).isSimpleExtraction()) { + Expr expr = plannerContext.parseExpression(druidExpressions.get(1).getExpression()); + if (druidExpressions.get(0).isSimpleExtraction() && expr.isLiteral()) { + Object[] lit = expr.eval(InputBindings.nilBindings()).asArray(); + if (lit == null || lit.length == 0) { + return null; + } + HashSet literals = Sets.newHashSetWithExpectedSize(lit.length); + for (Object o : lit) { + literals.add(Evals.asString(o)); + } + DruidExpression druidExpression = DruidExpression.ofVirtualColumn( Calcites.getColumnTypeForRelDataType(rexNode.getType()), builder, diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java index 259d0373f5e..3ec994c2e3e 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java @@ -1336,6 +1336,41 @@ public class CalciteMultiValueStringQueryTest extends BaseCalciteQueryTest ); } + @Test + public void testMultiValueListFilterNonLiteral() + { + // Cannot vectorize due to usage of expressions. + cannotVectorize(); + + testQuery( + "SELECT MV_FILTER_ONLY(dim3, ARRAY[dim2]) FROM druid.numfoo", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns( + new ExpressionVirtualColumn( + "v0", + "filter((x) -> array_contains(array(\"dim2\"), x), \"dim3\")", + ColumnType.STRING, + TestExprMacroTable.INSTANCE + ) + ) + .columns("v0") + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"a"}, + new Object[]{NullHandling.defaultStringValue()}, + new Object[]{NullHandling.defaultStringValue()}, + new Object[]{NullHandling.defaultStringValue()}, + new Object[]{NullHandling.defaultStringValue()}, + new Object[]{NullHandling.defaultStringValue()} + ) + ); + } + @Test public void testMultiValueListFilterDeny() { @@ -1391,6 +1426,41 @@ public class CalciteMultiValueStringQueryTest extends BaseCalciteQueryTest ); } + @Test + public void testMultiValueListFilterDenyNonLiteral() + { + // Cannot vectorize due to usage of expressions. + cannotVectorize(); + + testQuery( + "SELECT MV_FILTER_NONE(dim3, ARRAY[dim2]) FROM druid.numfoo", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns( + new ExpressionVirtualColumn( + "v0", + "filter((x) -> !array_contains(array(\"dim2\"), x), \"dim3\")", + ColumnType.STRING, + TestExprMacroTable.INSTANCE + ) + ) + .columns("v0") + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"b"}, + new Object[]{"[\"b\",\"c\"]"}, + new Object[]{"d"}, + new Object[]{""}, + new Object[]{NullHandling.defaultStringValue()}, + new Object[]{NullHandling.defaultStringValue()} + ) + ); + } + @Test public void testMultiValueListFilterComposed() {