modify multi-value expression transformation behavior to not treat re-use of the same input as a candidate for cartesian mapping (#8957)

Clint Wylie 2019-12-09 20:38:15 -08:00 committed by Jonathan Wei
parent 0330744793
commit 4327892b84
4 changed files with 28 additions and 17 deletions
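
In short: when an expression references the same unapplied multi-value input more than once, the transformation now reuses a single lambda argument for that input instead of treating each reference as a separate cartesian axis. As the updated ParserTest expectations below show, an unapplied "x + x" now plans to (map ([x] -> (+ x x)), [x]) rather than (cartesian_map ([x, x_0] -> (+ x x_0)), [x, x]); cartesian_map is still used when several distinct inputs are involved, as in the x/y/z case.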


@@ -223,8 +223,6 @@ public class Parser
*/
private static Expr applyUnapplied(Expr expr, List<String> unappliedBindings)
{
final Map<IdentifierExpr, IdentifierExpr> toReplace = new HashMap<>();
// filter to get list of IdentifierExpr that are backed by the unapplied bindings
final List<IdentifierExpr> args = expr.analyzeInputs()
.getFreeVariables()
@@ -236,18 +234,23 @@ public class Parser
// construct lambda args from list of args to apply. Identifiers in a lambda body have artificial 'binding' values
// that is the same as the 'identifier', because the bindings are supplied by the wrapping apply function
// replacements are done by binding rather than identifier because repeats of the same input should not result
// in a cartesian product
final Map<String, IdentifierExpr> toReplace = new HashMap<>();
for (IdentifierExpr applyFnArg : args) {
IdentifierExpr lambdaRewrite = new IdentifierExpr(applyFnArg.getIdentifier());
lambdaArgs.add(lambdaRewrite);
toReplace.put(applyFnArg, lambdaRewrite);
if (!toReplace.containsKey(applyFnArg.getBinding())) {
IdentifierExpr lambdaRewrite = new IdentifierExpr(applyFnArg.getBinding());
lambdaArgs.add(lambdaRewrite);
toReplace.put(applyFnArg.getBinding(), lambdaRewrite);
}
}
// rewrite identifiers in the expression which will become the lambda body, so they match the lambda identifiers we
// are constructing
Expr newExpr = expr.visit(childExpr -> {
if (childExpr instanceof IdentifierExpr) {
if (toReplace.containsKey(childExpr)) {
return toReplace.get(childExpr);
if (toReplace.containsKey(((IdentifierExpr) childExpr).getBinding())) {
return toReplace.get(((IdentifierExpr) childExpr).getBinding());
}
}
return childExpr;
@@ -257,13 +260,13 @@ public class Parser
// wrap an expression in either map or cartesian_map to apply any unapplied identifiers
final LambdaExpr lambdaExpr = new LambdaExpr(lambdaArgs, newExpr);
final ApplyFunction fn;
if (args.size() == 1) {
if (lambdaArgs.size() == 1) {
fn = new ApplyFunction.MapFunction();
} else {
fn = new ApplyFunction.CartesianMapFunction();
}
final Expr magic = new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, ImmutableList.copyOf(args));
final Expr magic = new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, ImmutableList.copyOf(lambdaArgs));
return magic;
}
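
The heart of the change is keying the replacement map by IdentifierExpr.getBinding() instead of by the IdentifierExpr instance, so repeated references to the same input collapse to a single lambda argument. Below is a minimal standalone sketch of that idea, not Druid code; the class name, string-based plan output, and helper method are illustrative assumptions only:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class UnappliedBindingSketch
{
  // Given the bindings referenced by an unapplied sub-expression, build the lambda
  // argument list the way the patched applyUnapplied does: keyed by binding name,
  // so repeated references to the same input collapse to one argument.
  static String plan(List<String> referencedBindings)
  {
    final Map<String, String> toReplace = new LinkedHashMap<>();
    final List<String> lambdaArgs = new ArrayList<>();
    for (String binding : referencedBindings) {
      if (!toReplace.containsKey(binding)) {
        toReplace.put(binding, binding);
        lambdaArgs.add(binding);
      }
    }
    // a single distinct input only needs map(); several distinct inputs still need cartesian_map()
    final String fn = lambdaArgs.size() == 1 ? "map" : "cartesian_map";
    final String argList = String.join(", ", lambdaArgs);
    return "(" + fn + " ([" + argList + "] -> ...), [" + argList + "])";
  }

  public static void main(String[] args)
  {
    // "x + x" references the unapplied input x twice -> single-argument map
    System.out.println(plan(Arrays.asList("x", "x")));
    // "x + y" references two distinct unapplied inputs -> cartesian_map
    System.out.println(plan(Arrays.asList("x", "y")));
  }
}

Running the sketch prints (map ([x] -> ...), [x]) for the doubled input and (cartesian_map ([x, y] -> ...), [x, y]) for two distinct inputs, matching the shape of the updated ParserTest expectations below.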


@@ -400,6 +400,13 @@ public class ParserTest extends InitializedNullHandlingTest
"(cast [x, LONG_ARRAY])",
ImmutableList.of("x")
);
validateApplyUnapplied(
"case_searched((x == 'b'),'b',(x == 'g'),'g','Other')",
"(case_searched [(== x b), b, (== x g), g, Other])",
"(map ([x] -> (case_searched [(== x b), b, (== x g), g, Other])), [x])",
ImmutableList.of("x")
);
}
@Test
@@ -424,14 +431,14 @@ public class ParserTest extends InitializedNullHandlingTest
validateApplyUnapplied(
"x + x",
"(+ x x)",
"(cartesian_map ([x, x_0] -> (+ x x_0)), [x, x])",
"(map ([x] -> (+ x x)), [x])",
ImmutableList.of("x")
);
validateApplyUnapplied(
"x + x + x",
"(+ (+ x x) x)",
"(cartesian_map ([x, x_0, x_1] -> (+ (+ x x_0) x_1)), [x, x, x])",
"(map ([x] -> (+ (+ x x) x)), [x])",
ImmutableList.of("x")
);
@@ -439,7 +446,7 @@ public class ParserTest extends InitializedNullHandlingTest
validateApplyUnapplied(
"x + x + x + y + y + y + y + z + z + z",
"(+ (+ (+ (+ (+ (+ (+ (+ (+ x x) x) y) y) y) y) z) z) z)",
"(cartesian_map ([x, x_0, x_1, y, y_2, y_3, y_4, z, z_5, z_6] -> (+ (+ (+ (+ (+ (+ (+ (+ (+ x x_0) x_1) y) y_2) y_3) y_4) z) z_5) z_6)), [x, x, x, y, y, y, y, z, z, z])",
"(cartesian_map ([x, y, z] -> (+ (+ (+ (+ (+ (+ (+ (+ (+ x x) x) y) y) y) y) z) z) z)), [x, y, z])",
ImmutableList.of("x", "y", "z")
);
}


@@ -67,6 +67,7 @@ import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.segment.writeout.SegmentWriteOutMediumFactory;
import org.apache.druid.segment.writeout.TmpFileSegmentWriteOutMediumFactory;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.apache.druid.timeline.SegmentId;
import org.junit.After;
import org.junit.Before;
@@ -89,7 +90,7 @@ import java.util.Map;
/**
*/
@RunWith(Parameterized.class)
public class MultiValuedDimensionTest
public class MultiValuedDimensionTest extends InitializedNullHandlingTest
{
@Parameterized.Parameters(name = "groupby: {0} forceHashAggregation: {2} ({1})")
public static Collection<?> constructorFeeder()
@@ -609,8 +610,8 @@ public class MultiValuedDimensionTest
List<ResultRow> expectedResults = Arrays.asList(
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t3t3", "count", 4L),
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t5t5", "count", 4L),
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t2t1", "count", 2L),
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t1t2", "count", 2L),
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t4t4", "count", 2L),
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t2t2", "count", 2L),
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t7t7", "count", 2L)
);


@@ -8652,8 +8652,8 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.build()
),
ImmutableList.of(
new Object[]{"[\"a-lol-a\",\"a-lol-b\",\"b-lol-a\",\"b-lol-b\"]"},
new Object[]{"[\"b-lol-b\",\"b-lol-c\",\"c-lol-b\",\"c-lol-c\"]"},
new Object[]{"[\"a-lol-a\",\"b-lol-b\"]"},
new Object[]{"[\"b-lol-b\",\"c-lol-c\"]"},
new Object[]{"[\"d-lol-d\"]"},
new Object[]{"[\"-lol-\"]"},
new Object[]{nullVal},