From 4327892b84dcb2e96e58017d5238afcc80c81a5b Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Mon, 9 Dec 2019 20:38:15 -0800 Subject: [PATCH] modify multi-value expression transformation behavior to not treat re-use of the same input as a candidate for cartesian mapping (#8957) --- .../org/apache/druid/math/expr/Parser.java | 21 +++++++++++-------- .../apache/druid/math/expr/ParserTest.java | 13 +++++++++--- .../druid/query/MultiValuedDimensionTest.java | 7 ++++--- .../druid/sql/calcite/CalciteQueryTest.java | 4 ++-- 4 files changed, 28 insertions(+), 17 deletions(-) diff --git a/core/src/main/java/org/apache/druid/math/expr/Parser.java b/core/src/main/java/org/apache/druid/math/expr/Parser.java index ce0c35f0663..64d5440497e 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Parser.java +++ b/core/src/main/java/org/apache/druid/math/expr/Parser.java @@ -223,8 +223,6 @@ public class Parser */ private static Expr applyUnapplied(Expr expr, List unappliedBindings) { - final Map toReplace = new HashMap<>(); - // filter to get list of IdentifierExpr that are backed by the unapplied bindings final List args = expr.analyzeInputs() .getFreeVariables() @@ -236,18 +234,23 @@ public class Parser // construct lambda args from list of args to apply. Identifiers in a lambda body have artificial 'binding' values // that is the same as the 'identifier', because the bindings are supplied by the wrapping apply function + // replacements are done by binding rather than identifier because repeats of the same input should not result + // in a cartesian product + final Map toReplace = new HashMap<>(); for (IdentifierExpr applyFnArg : args) { - IdentifierExpr lambdaRewrite = new IdentifierExpr(applyFnArg.getIdentifier()); - lambdaArgs.add(lambdaRewrite); - toReplace.put(applyFnArg, lambdaRewrite); + if (!toReplace.containsKey(applyFnArg.getBinding())) { + IdentifierExpr lambdaRewrite = new IdentifierExpr(applyFnArg.getBinding()); + lambdaArgs.add(lambdaRewrite); + toReplace.put(applyFnArg.getBinding(), lambdaRewrite); + } } // rewrite identifiers in the expression which will become the lambda body, so they match the lambda identifiers we // are constructing Expr newExpr = expr.visit(childExpr -> { if (childExpr instanceof IdentifierExpr) { - if (toReplace.containsKey(childExpr)) { - return toReplace.get(childExpr); + if (toReplace.containsKey(((IdentifierExpr) childExpr).getBinding())) { + return toReplace.get(((IdentifierExpr) childExpr).getBinding()); } } return childExpr; @@ -257,13 +260,13 @@ public class Parser // wrap an expression in either map or cartesian_map to apply any unapplied identifiers final LambdaExpr lambdaExpr = new LambdaExpr(lambdaArgs, newExpr); final ApplyFunction fn; - if (args.size() == 1) { + if (lambdaArgs.size() == 1) { fn = new ApplyFunction.MapFunction(); } else { fn = new ApplyFunction.CartesianMapFunction(); } - final Expr magic = new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, ImmutableList.copyOf(args)); + final Expr magic = new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, ImmutableList.copyOf(lambdaArgs)); return magic; } diff --git a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java index 27a9eb4abde..d4361b6e822 100644 --- a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java @@ -400,6 +400,13 @@ public class ParserTest extends InitializedNullHandlingTest "(cast [x, LONG_ARRAY])", ImmutableList.of("x") ); + + validateApplyUnapplied( + "case_searched((x == 'b'),'b',(x == 'g'),'g','Other')", + "(case_searched [(== x b), b, (== x g), g, Other])", + "(map ([x] -> (case_searched [(== x b), b, (== x g), g, Other])), [x])", + ImmutableList.of("x") + ); } @Test @@ -424,14 +431,14 @@ public class ParserTest extends InitializedNullHandlingTest validateApplyUnapplied( "x + x", "(+ x x)", - "(cartesian_map ([x, x_0] -> (+ x x_0)), [x, x])", + "(map ([x] -> (+ x x)), [x])", ImmutableList.of("x") ); validateApplyUnapplied( "x + x + x", "(+ (+ x x) x)", - "(cartesian_map ([x, x_0, x_1] -> (+ (+ x x_0) x_1)), [x, x, x])", + "(map ([x] -> (+ (+ x x) x)), [x])", ImmutableList.of("x") ); @@ -439,7 +446,7 @@ public class ParserTest extends InitializedNullHandlingTest validateApplyUnapplied( "x + x + x + y + y + y + y + z + z + z", "(+ (+ (+ (+ (+ (+ (+ (+ (+ x x) x) y) y) y) y) z) z) z)", - "(cartesian_map ([x, x_0, x_1, y, y_2, y_3, y_4, z, z_5, z_6] -> (+ (+ (+ (+ (+ (+ (+ (+ (+ x x_0) x_1) y) y_2) y_3) y_4) z) z_5) z_6)), [x, x, x, y, y, y, y, z, z, z])", + "(cartesian_map ([x, y, z] -> (+ (+ (+ (+ (+ (+ (+ (+ (+ x x) x) y) y) y) y) z) z) z)), [x, y, z])", ImmutableList.of("x", "y", "z") ); } diff --git a/processing/src/test/java/org/apache/druid/query/MultiValuedDimensionTest.java b/processing/src/test/java/org/apache/druid/query/MultiValuedDimensionTest.java index da3d4aac541..c97e8f67fa8 100644 --- a/processing/src/test/java/org/apache/druid/query/MultiValuedDimensionTest.java +++ b/processing/src/test/java/org/apache/druid/query/MultiValuedDimensionTest.java @@ -67,6 +67,7 @@ import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory; import org.apache.druid.segment.writeout.SegmentWriteOutMediumFactory; import org.apache.druid.segment.writeout.TmpFileSegmentWriteOutMediumFactory; +import org.apache.druid.testing.InitializedNullHandlingTest; import org.apache.druid.timeline.SegmentId; import org.junit.After; import org.junit.Before; @@ -89,7 +90,7 @@ import java.util.Map; /** */ @RunWith(Parameterized.class) -public class MultiValuedDimensionTest +public class MultiValuedDimensionTest extends InitializedNullHandlingTest { @Parameterized.Parameters(name = "groupby: {0} forceHashAggregation: {2} ({1})") public static Collection constructorFeeder() @@ -609,8 +610,8 @@ public class MultiValuedDimensionTest List expectedResults = Arrays.asList( GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t3t3", "count", 4L), GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t5t5", "count", 4L), - GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t2t1", "count", 2L), - GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t1t2", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t4t4", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t2t2", "count", 2L), GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t7t7", "count", 2L) ); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 0f9ffb764e7..09a38d56542 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -8652,8 +8652,8 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .build() ), ImmutableList.of( - new Object[]{"[\"a-lol-a\",\"a-lol-b\",\"b-lol-a\",\"b-lol-b\"]"}, - new Object[]{"[\"b-lol-b\",\"b-lol-c\",\"c-lol-b\",\"c-lol-c\"]"}, + new Object[]{"[\"a-lol-a\",\"b-lol-b\"]"}, + new Object[]{"[\"b-lol-b\",\"c-lol-c\"]"}, new Object[]{"[\"d-lol-d\"]"}, new Object[]{"[\"-lol-\"]"}, new Object[]{nullVal},