mirror of https://github.com/apache/druid.git
modify multi-value expression transformation behavior to not treat re-use of the same input as a candidate for cartesian mapping (#8957)
This commit is contained in:
parent
0330744793
commit
4327892b84
|
@ -223,8 +223,6 @@ public class Parser
|
|||
*/
|
||||
private static Expr applyUnapplied(Expr expr, List<String> unappliedBindings)
|
||||
{
|
||||
final Map<IdentifierExpr, IdentifierExpr> toReplace = new HashMap<>();
|
||||
|
||||
// filter to get list of IdentifierExpr that are backed by the unapplied bindings
|
||||
final List<IdentifierExpr> args = expr.analyzeInputs()
|
||||
.getFreeVariables()
|
||||
|
@ -236,18 +234,23 @@ public class Parser
|
|||
|
||||
// construct lambda args from list of args to apply. Identifiers in a lambda body have artificial 'binding' values
|
||||
// that is the same as the 'identifier', because the bindings are supplied by the wrapping apply function
|
||||
// replacements are done by binding rather than identifier because repeats of the same input should not result
|
||||
// in a cartesian product
|
||||
final Map<String, IdentifierExpr> toReplace = new HashMap<>();
|
||||
for (IdentifierExpr applyFnArg : args) {
|
||||
IdentifierExpr lambdaRewrite = new IdentifierExpr(applyFnArg.getIdentifier());
|
||||
lambdaArgs.add(lambdaRewrite);
|
||||
toReplace.put(applyFnArg, lambdaRewrite);
|
||||
if (!toReplace.containsKey(applyFnArg.getBinding())) {
|
||||
IdentifierExpr lambdaRewrite = new IdentifierExpr(applyFnArg.getBinding());
|
||||
lambdaArgs.add(lambdaRewrite);
|
||||
toReplace.put(applyFnArg.getBinding(), lambdaRewrite);
|
||||
}
|
||||
}
|
||||
|
||||
// rewrite identifiers in the expression which will become the lambda body, so they match the lambda identifiers we
|
||||
// are constructing
|
||||
Expr newExpr = expr.visit(childExpr -> {
|
||||
if (childExpr instanceof IdentifierExpr) {
|
||||
if (toReplace.containsKey(childExpr)) {
|
||||
return toReplace.get(childExpr);
|
||||
if (toReplace.containsKey(((IdentifierExpr) childExpr).getBinding())) {
|
||||
return toReplace.get(((IdentifierExpr) childExpr).getBinding());
|
||||
}
|
||||
}
|
||||
return childExpr;
|
||||
|
@ -257,13 +260,13 @@ public class Parser
|
|||
// wrap an expression in either map or cartesian_map to apply any unapplied identifiers
|
||||
final LambdaExpr lambdaExpr = new LambdaExpr(lambdaArgs, newExpr);
|
||||
final ApplyFunction fn;
|
||||
if (args.size() == 1) {
|
||||
if (lambdaArgs.size() == 1) {
|
||||
fn = new ApplyFunction.MapFunction();
|
||||
} else {
|
||||
fn = new ApplyFunction.CartesianMapFunction();
|
||||
}
|
||||
|
||||
final Expr magic = new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, ImmutableList.copyOf(args));
|
||||
final Expr magic = new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, ImmutableList.copyOf(lambdaArgs));
|
||||
return magic;
|
||||
}
|
||||
|
||||
|
|
|
@ -400,6 +400,13 @@ public class ParserTest extends InitializedNullHandlingTest
|
|||
"(cast [x, LONG_ARRAY])",
|
||||
ImmutableList.of("x")
|
||||
);
|
||||
|
||||
validateApplyUnapplied(
|
||||
"case_searched((x == 'b'),'b',(x == 'g'),'g','Other')",
|
||||
"(case_searched [(== x b), b, (== x g), g, Other])",
|
||||
"(map ([x] -> (case_searched [(== x b), b, (== x g), g, Other])), [x])",
|
||||
ImmutableList.of("x")
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -424,14 +431,14 @@ public class ParserTest extends InitializedNullHandlingTest
|
|||
validateApplyUnapplied(
|
||||
"x + x",
|
||||
"(+ x x)",
|
||||
"(cartesian_map ([x, x_0] -> (+ x x_0)), [x, x])",
|
||||
"(map ([x] -> (+ x x)), [x])",
|
||||
ImmutableList.of("x")
|
||||
);
|
||||
|
||||
validateApplyUnapplied(
|
||||
"x + x + x",
|
||||
"(+ (+ x x) x)",
|
||||
"(cartesian_map ([x, x_0, x_1] -> (+ (+ x x_0) x_1)), [x, x, x])",
|
||||
"(map ([x] -> (+ (+ x x) x)), [x])",
|
||||
ImmutableList.of("x")
|
||||
);
|
||||
|
||||
|
@ -439,7 +446,7 @@ public class ParserTest extends InitializedNullHandlingTest
|
|||
validateApplyUnapplied(
|
||||
"x + x + x + y + y + y + y + z + z + z",
|
||||
"(+ (+ (+ (+ (+ (+ (+ (+ (+ x x) x) y) y) y) y) z) z) z)",
|
||||
"(cartesian_map ([x, x_0, x_1, y, y_2, y_3, y_4, z, z_5, z_6] -> (+ (+ (+ (+ (+ (+ (+ (+ (+ x x_0) x_1) y) y_2) y_3) y_4) z) z_5) z_6)), [x, x, x, y, y, y, y, z, z, z])",
|
||||
"(cartesian_map ([x, y, z] -> (+ (+ (+ (+ (+ (+ (+ (+ (+ x x) x) y) y) y) y) z) z) z)), [x, y, z])",
|
||||
ImmutableList.of("x", "y", "z")
|
||||
);
|
||||
}
|
||||
|
|
|
@ -67,6 +67,7 @@ import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
|
|||
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
|
||||
import org.apache.druid.segment.writeout.SegmentWriteOutMediumFactory;
|
||||
import org.apache.druid.segment.writeout.TmpFileSegmentWriteOutMediumFactory;
|
||||
import org.apache.druid.testing.InitializedNullHandlingTest;
|
||||
import org.apache.druid.timeline.SegmentId;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
@ -89,7 +90,7 @@ import java.util.Map;
|
|||
/**
|
||||
*/
|
||||
@RunWith(Parameterized.class)
|
||||
public class MultiValuedDimensionTest
|
||||
public class MultiValuedDimensionTest extends InitializedNullHandlingTest
|
||||
{
|
||||
@Parameterized.Parameters(name = "groupby: {0} forceHashAggregation: {2} ({1})")
|
||||
public static Collection<?> constructorFeeder()
|
||||
|
@ -609,8 +610,8 @@ public class MultiValuedDimensionTest
|
|||
List<ResultRow> expectedResults = Arrays.asList(
|
||||
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t3t3", "count", 4L),
|
||||
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t5t5", "count", 4L),
|
||||
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t2t1", "count", 2L),
|
||||
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t1t2", "count", 2L),
|
||||
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t4t4", "count", 2L),
|
||||
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t2t2", "count", 2L),
|
||||
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t7t7", "count", 2L)
|
||||
);
|
||||
|
||||
|
|
|
@ -8652,8 +8652,8 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{"[\"a-lol-a\",\"a-lol-b\",\"b-lol-a\",\"b-lol-b\"]"},
|
||||
new Object[]{"[\"b-lol-b\",\"b-lol-c\",\"c-lol-b\",\"c-lol-c\"]"},
|
||||
new Object[]{"[\"a-lol-a\",\"b-lol-b\"]"},
|
||||
new Object[]{"[\"b-lol-b\",\"c-lol-c\"]"},
|
||||
new Object[]{"[\"d-lol-d\"]"},
|
||||
new Object[]{"[\"-lol-\"]"},
|
||||
new Object[]{nullVal},
|
||||
|
|
Loading…
Reference in New Issue