From bb2663628961adf173557b07df07a2df039ffe86 Mon Sep 17 00:00:00 2001 From: Navis Ryu Date: Tue, 15 Nov 2016 07:14:10 +0900 Subject: [PATCH] Constant flattening in math expression (#3090) * Constant flatteing in math expression * Addressed comments and fixed some bugs * Addressed comments --- .../main/java/io/druid/math/expr/Evals.java | 41 ++++++ .../main/java/io/druid/math/expr/Expr.java | 24 ++-- .../java/io/druid/math/expr/ExprEval.java | 34 ++++- .../io/druid/math/expr/ExprListenerImpl.java | 2 +- .../java/io/druid/math/expr/Function.java | 5 + .../main/java/io/druid/math/expr/Parser.java | 78 +++++++++- .../java/io/druid/math/expr/ParserTest.java | 133 ++++++------------ 7 files changed, 213 insertions(+), 104 deletions(-) diff --git a/common/src/main/java/io/druid/math/expr/Evals.java b/common/src/main/java/io/druid/math/expr/Evals.java index 2ba6eff8c99..f61488ff0c7 100644 --- a/common/src/main/java/io/druid/math/expr/Evals.java +++ b/common/src/main/java/io/druid/math/expr/Evals.java @@ -20,11 +20,17 @@ package io.druid.math.expr; import io.druid.common.guava.GuavaUtils; +import io.druid.java.util.common.logger.Logger; + +import java.util.Arrays; +import java.util.List; /** */ public class Evals { + private static final Logger log = new Logger(Evals.class); + public static Number toNumber(Object value) { if (value == null) { @@ -40,4 +46,39 @@ public class Evals } return longValue; } + + public static boolean isConstant(Expr expr) + { + return expr instanceof ConstantExpr; + } + + public static boolean isAllConstants(Expr... exprs) + { + return isAllConstants(Arrays.asList(exprs)); + } + + public static boolean isAllConstants(List exprs) + { + for (Expr expr : exprs) { + if (!(expr instanceof ConstantExpr)) { + return false; + } + } + return true; + } + + // for binary operator not providing constructor of form (String, Expr, Expr), + // you should create it explicitly in here + public static Expr binaryOp(BinaryOpExprBase binary, Expr left, Expr right) + { + try { + return binary.getClass() + .getDeclaredConstructor(String.class, Expr.class, Expr.class) + .newInstance(binary.op, left, right); + } + catch (Exception e) { + log.warn(e, "failed to rewrite expression " + binary); + return binary; // best effort.. keep it working + } + } } diff --git a/common/src/main/java/io/druid/math/expr/Expr.java b/common/src/main/java/io/druid/math/expr/Expr.java index 8071ac6b8f8..5f904d94488 100644 --- a/common/src/main/java/io/druid/math/expr/Expr.java +++ b/common/src/main/java/io/druid/math/expr/Expr.java @@ -55,9 +55,9 @@ abstract class ConstantExpr implements Expr class LongExpr extends ConstantExpr { - private final long value; + private final Long value; - public LongExpr(long value) + public LongExpr(Long value) { this.value = value; } @@ -71,7 +71,7 @@ class LongExpr extends ConstantExpr @Override public ExprEval eval(ObjectBinding bindings) { - return ExprEval.of(value); + return ExprEval.ofLong(value); } } @@ -99,9 +99,9 @@ class StringExpr extends ConstantExpr class DoubleExpr extends ConstantExpr { - private final double value; + private final Double value; - public DoubleExpr(double value) + public DoubleExpr(Double value) { this.value = value; } @@ -115,11 +115,11 @@ class DoubleExpr extends ConstantExpr @Override public ExprEval eval(ObjectBinding bindings) { - return ExprEval.of(value); + return ExprEval.ofDouble(value); } } -class IdentifierExpr extends ConstantExpr +class IdentifierExpr implements Expr { private final String value; @@ -139,6 +139,12 @@ class IdentifierExpr extends ConstantExpr { return ExprEval.bestEffortOf(bindings.get(value)); } + + @Override + public void visit(Visitor visitor) + { + visitor.visit(this); + } } class FunctionExpr implements Expr @@ -161,7 +167,7 @@ class FunctionExpr implements Expr @Override public ExprEval eval(ObjectBinding bindings) { - return Parser.func.get(name.toLowerCase()).apply(args, bindings); + return Parser.getFunction(name).apply(args, bindings); } @Override @@ -252,6 +258,8 @@ class UnaryNotExpr extends UnaryExpr } } +// all concrete subclass of this should have constructor with the form of (String, Expr, Expr) +// if it's not possible, just be sure Evals.binaryOp() can handle that abstract class BinaryOpExprBase implements Expr { protected final String op; diff --git a/common/src/main/java/io/druid/math/expr/ExprEval.java b/common/src/main/java/io/druid/math/expr/ExprEval.java index 1012a24f08b..9a766076785 100644 --- a/common/src/main/java/io/druid/math/expr/ExprEval.java +++ b/common/src/main/java/io/druid/math/expr/ExprEval.java @@ -26,14 +26,24 @@ import io.druid.java.util.common.IAE; */ public abstract class ExprEval { + public static ExprEval ofLong(Number longValue) + { + return new LongExprEval(longValue); + } + public static ExprEval of(long longValue) { return new LongExprEval(longValue); } - public static ExprEval of(double longValue) + public static ExprEval ofDouble(Number doubleValue) { - return new DoubleExprEval(longValue); + return new DoubleExprEval(doubleValue); + } + + public static ExprEval of(double doubleValue) + { + return new DoubleExprEval(doubleValue); } public static ExprEval of(String stringValue) @@ -108,6 +118,8 @@ public abstract class ExprEval public abstract ExprEval castTo(ExprType castTo); + public abstract Expr toExpr(); + private static abstract class NumericExprEval extends ExprEval { private NumericExprEval(Number value) @@ -166,6 +178,12 @@ public abstract class ExprEval } throw new IAE("invalid type " + castTo); } + + @Override + public Expr toExpr() + { + return new DoubleExpr(value == null ? null : value.doubleValue()); + } } private static class LongExprEval extends NumericExprEval @@ -200,6 +218,12 @@ public abstract class ExprEval } throw new IAE("invalid type " + castTo); } + + @Override + public Expr toExpr() + { + return new LongExpr(value == null ? null : value.longValue()); + } } private static class StringExprEval extends ExprEval @@ -258,5 +282,11 @@ public abstract class ExprEval } throw new IAE("invalid type " + castTo); } + + @Override + public Expr toExpr() + { + return new StringExpr(value); + } } } diff --git a/common/src/main/java/io/druid/math/expr/ExprListenerImpl.java b/common/src/main/java/io/druid/math/expr/ExprListenerImpl.java index c2d2fdadbcc..cfb9414e101 100644 --- a/common/src/main/java/io/druid/math/expr/ExprListenerImpl.java +++ b/common/src/main/java/io/druid/math/expr/ExprListenerImpl.java @@ -285,7 +285,7 @@ public class ExprListenerImpl extends ExprBaseListener public void exitFunctionExpr(ExprParser.FunctionExprContext ctx) { String fnName = ctx.getChild(0).getText(); - if (!Parser.func.containsKey(fnName)) { + if (!Parser.hasFunction(fnName)) { throw new RuntimeException("function " + fnName + " is not defined."); } diff --git a/common/src/main/java/io/druid/math/expr/Function.java b/common/src/main/java/io/druid/math/expr/Function.java index a27e3c55fb1..2ae94bc79ac 100644 --- a/common/src/main/java/io/druid/math/expr/Function.java +++ b/common/src/main/java/io/druid/math/expr/Function.java @@ -24,6 +24,7 @@ import org.joda.time.DateTime; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; import org.joda.time.format.ISODateTimeFormat; +import com.google.common.base.Supplier; import java.util.List; @@ -35,6 +36,10 @@ interface Function ExprEval apply(List args, Expr.ObjectBinding bindings); + // optional interface to be used when function should be created per reference in expression + interface FunctionFactory extends Supplier, Function { + } + abstract class SingleParam implements Function { @Override diff --git a/common/src/main/java/io/druid/math/expr/Parser.java b/common/src/main/java/io/druid/math/expr/Parser.java index 80ab1adb8ad..6bffe83888a 100644 --- a/common/src/main/java/io/druid/math/expr/Parser.java +++ b/common/src/main/java/io/druid/math/expr/Parser.java @@ -21,11 +21,13 @@ package io.druid.math.expr; import com.google.common.base.Supplier; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.common.collect.Lists; +import io.druid.java.util.common.IAE; import io.druid.java.util.common.logger.Logger; import io.druid.math.expr.antlr.ExprLexer; import io.druid.math.expr.antlr.ExprParser; @@ -41,16 +43,20 @@ import java.util.Set; public class Parser { - static final Logger log = new Logger(Parser.class); - static final Map func; + private static final Logger log = new Logger(Parser.class); + private static final Map> func; static { - Map functionMap = Maps.newHashMap(); + Map> functionMap = Maps.newHashMap(); for (Class clazz : Function.class.getClasses()) { if (!Modifier.isAbstract(clazz.getModifiers()) && Function.class.isAssignableFrom(clazz)) { try { Function function = (Function)clazz.newInstance(); - functionMap.put(function.name().toLowerCase(), function); + if (function instanceof Function.FunctionFactory) { + functionMap.put(function.name().toLowerCase(), (Supplier) function); + } else { + functionMap.put(function.name().toLowerCase(), Suppliers.ofInstance(function)); + } } catch (Exception e) { log.info("failed to instantiate " + clazz.getName() + ".. ignoring", e); @@ -60,7 +66,25 @@ public class Parser func = ImmutableMap.copyOf(functionMap); } + public static Function getFunction(String name) { + Supplier supplier = func.get(name.toLowerCase()); + if (supplier == null) { + throw new IAE("Invalid function name '%s'", name); + } + return supplier.get(); + } + + public static boolean hasFunction(String name) + { + return func.containsKey(name.toLowerCase()); + } + public static Expr parse(String in) + { + return parse(in, true); + } + + public static Expr parse(String in, boolean withFlatten) { ExprLexer lexer = new ExprLexer(new ANTLRInputStream(in)); CommonTokenStream tokens = new CommonTokenStream(lexer); @@ -70,7 +94,51 @@ public class Parser ParseTreeWalker walker = new ParseTreeWalker(); ExprListenerImpl listener = new ExprListenerImpl(parseTree); walker.walk(listener, parseTree); - return listener.getAST(); + return withFlatten ? flatten(listener.getAST()) : listener.getAST(); + } + + public static Expr flatten(Expr expr) + { + if (expr instanceof BinaryOpExprBase) { + BinaryOpExprBase binary = (BinaryOpExprBase) expr; + Expr left = flatten(binary.left); + Expr right = flatten(binary.right); + if (Evals.isAllConstants(left, right)) { + expr = expr.eval(null).toExpr(); + } else if (left != binary.left || right != binary.right) { + return Evals.binaryOp(binary, left, right); + } + } else if (expr instanceof UnaryExpr) { + UnaryExpr unary = (UnaryExpr) expr; + Expr eval = flatten(unary.expr); + if (eval instanceof ConstantExpr) { + expr = expr.eval(null).toExpr(); + } else if (eval != unary.expr) { + if (expr instanceof UnaryMinusExpr) { + expr = new UnaryMinusExpr(eval); + } else if (expr instanceof UnaryNotExpr) { + expr = new UnaryNotExpr(eval); + } else { + expr = unary; // unknown type.. + } + } + } else if (expr instanceof FunctionExpr) { + FunctionExpr functionExpr = (FunctionExpr) expr; + List args = functionExpr.args; + boolean flattened = false; + List flattening = Lists.newArrayListWithCapacity(args.size()); + for (int i = 0; i < args.size(); i++) { + Expr flatten = flatten(args.get(i)); + flattened |= flatten != args.get(i); + flattening.add(flatten); + } + if (Evals.isAllConstants(flattening)) { + expr = expr.eval(null).toExpr(); + } else if (flattened) { + expr = new FunctionExpr(functionExpr.name, flattening); + } + } + return expr; } public static List findRequiredBindings(String in) diff --git a/common/src/test/java/io/druid/math/expr/ParserTest.java b/common/src/test/java/io/druid/math/expr/ParserTest.java index b06f01a4eb7..a8d68392d63 100644 --- a/common/src/test/java/io/druid/math/expr/ParserTest.java +++ b/common/src/test/java/io/druid/math/expr/ParserTest.java @@ -49,25 +49,11 @@ public class ParserTest @Test public void testSimpleUnaryOps2() { - String actual = Parser.parse("-1").toString(); - String expected = "-1"; - Assert.assertEquals(expected, actual); - - actual = Parser.parse("--1").toString(); - expected = "--1"; - Assert.assertEquals(expected, actual); - - actual = Parser.parse("-1+2").toString(); - expected = "(+ -1 2)"; - Assert.assertEquals(expected, actual); - - actual = Parser.parse("-1*2").toString(); - expected = "(* -1 2)"; - Assert.assertEquals(expected, actual); - - actual = Parser.parse("-1^2").toString(); - expected = "(^ -1 2)"; - Assert.assertEquals(expected, actual); + validate("-1", "-1", "-1"); + validate("--1", "--1", "1"); + validate("-1+2", "(+ -1 2)", "1"); + validate("-1*2", "(* -1 2)", "-2"); + validate("-1^2", "(^ -1 2)", "1"); } private void validateParser(String expression, String expected, String identifiers) @@ -103,6 +89,8 @@ public class ParserTest validateParser("x+y-z", "(- (+ x y) z)", "[x, y, z]"); validateParser("x-y+z", "(+ (- x y) z)", "[x, y, z]"); validateParser("x-y-z", "(- (- x y) z)", "[x, y, z]"); + + validateParser("x-y-x", "(- (- x y) x)", "[x, y]"); } @Test @@ -116,99 +104,68 @@ public class ParserTest @Test public void testSimpleMultiplicativeOp2() { - String actual = Parser.parse("1*2*3").toString(); - String expected = "(* (* 1 2) 3)"; - Assert.assertEquals(expected, actual); + validate("1*2*3", "(* (* 1 2) 3)", "6"); + validate("1*2/3", "(/ (* 1 2) 3)", "0"); + validate("1/2*3", "(* (/ 1 2) 3)", "0"); + validate("1/2/3", "(/ (/ 1 2) 3)", "0"); - actual = Parser.parse("1*2/3").toString(); - expected = "(/ (* 1 2) 3)"; - Assert.assertEquals(expected, actual); + validate("1.0*2*3", "(* (* 1.0 2) 3)", "6.0"); + validate("1.0*2/3", "(/ (* 1.0 2) 3)", "0.6666666666666666"); + validate("1.0/2*3", "(* (/ 1.0 2) 3)", "1.5"); + validate("1.0/2/3", "(/ (/ 1.0 2) 3)", "0.16666666666666666"); - actual = Parser.parse("1/2*3").toString(); - expected = "(* (/ 1 2) 3)"; - Assert.assertEquals(expected, actual); + // partial + validate("1.0*2*x", "(* (* 1.0 2) x)", "(* 2.0 x)"); + validate("1.0*2/x", "(/ (* 1.0 2) x)", "(/ 2.0 x)"); + validate("1.0/2*x", "(* (/ 1.0 2) x)", "(* 0.5 x)"); + validate("1.0/2/x", "(/ (/ 1.0 2) x)", "(/ 0.5 x)"); - actual = Parser.parse("1/2/3").toString(); - expected = "(/ (/ 1 2) 3)"; - Assert.assertEquals(expected, actual); + // not working yet + validate("1.0*x*3", "(* (* 1.0 x) 3)", "(* (* 1.0 x) 3)"); } @Test public void testSimpleCarrot1() { - String actual = Parser.parse("1^2").toString(); - String expected = "(^ 1 2)"; - Assert.assertEquals(expected, actual); + validate("1^2", "(^ 1 2)", "1"); } @Test public void testSimpleCarrot2() { - String actual = Parser.parse("1^2^3").toString(); - String expected = "(^ 1 (^ 2 3))"; - Assert.assertEquals(expected, actual); + validate("1^2^3", "(^ 1 (^ 2 3))", "1"); } @Test public void testMixed() { - String actual = Parser.parse("1+2*3").toString(); - String expected = "(+ 1 (* 2 3))"; - Assert.assertEquals(expected, actual); + validate("1+2*3", "(+ 1 (* 2 3))", "7"); + validate("1+(2*3)", "(+ 1 (* 2 3))", "7"); + validate("(1+2)*3", "(* (+ 1 2) 3)", "9"); - actual = Parser.parse("1+(2*3)").toString(); - Assert.assertEquals(expected, actual); + validate("1*2+3", "(+ (* 1 2) 3)", "5"); + validate("(1*2)+3", "(+ (* 1 2) 3)", "5"); + validate("1*(2+3)", "(* 1 (+ 2 3))", "5"); - actual = Parser.parse("(1+2)*3").toString(); - expected = "(* (+ 1 2) 3)"; - Assert.assertEquals(expected, actual); + validate("1+2^3", "(+ 1 (^ 2 3))", "9"); + validate("1+(2^3)", "(+ 1 (^ 2 3))", "9"); + validate("(1+2)^3", "(^ (+ 1 2) 3)", "27"); + validate("1^2+3", "(+ (^ 1 2) 3)", "4"); + validate("(1^2)+3", "(+ (^ 1 2) 3)", "4"); + validate("1^(2+3)", "(^ 1 (+ 2 3))", "1"); - actual = Parser.parse("1*2+3").toString(); - expected = "(+ (* 1 2) 3)"; - Assert.assertEquals(expected, actual); + validate("1^2*3+4", "(+ (* (^ 1 2) 3) 4)", "7"); + validate("-1^2*-3+-4", "(+ (* (^ -1 2) -3) -4)", "-7"); - actual = Parser.parse("(1*2)+3").toString(); - Assert.assertEquals(expected, actual); + validate("max(3, 4)", "(max [3, 4])", "4"); + validate("min(1, max(3, 4))", "(min [1, (max [3, 4])])", "1"); + } - actual = Parser.parse("1*(2+3)").toString(); - expected = "(* 1 (+ 2 3))"; - Assert.assertEquals(expected, actual); - - - actual = Parser.parse("1+2^3").toString(); - expected = "(+ 1 (^ 2 3))"; - Assert.assertEquals(expected, actual); - - actual = Parser.parse("1+(2^3)").toString(); - expected = "(+ 1 (^ 2 3))"; - Assert.assertEquals(expected, actual); - - actual = Parser.parse("(1+2)^3").toString(); - expected = "(^ (+ 1 2) 3)"; - Assert.assertEquals(expected, actual); - - - actual = Parser.parse("1^2+3").toString(); - expected = "(+ (^ 1 2) 3)"; - Assert.assertEquals(expected, actual); - - actual = Parser.parse("(1^2)+3").toString(); - expected = "(+ (^ 1 2) 3)"; - Assert.assertEquals(expected, actual); - - actual = Parser.parse("1^(2+3)").toString(); - expected = "(^ 1 (+ 2 3))"; - Assert.assertEquals(expected, actual); - - - actual = Parser.parse("1^2*3+4").toString(); - expected = "(+ (* (^ 1 2) 3) 4)"; - Assert.assertEquals(expected, actual); - - actual = Parser.parse("-1^-2*-3+-4").toString(); - expected = "(+ (* (^ -1 -2) -3) -4)"; - Assert.assertEquals(expected, actual); + private void validate(String expression, String withoutFlatten, String withFlatten) + { + Assert.assertEquals(withoutFlatten, Parser.parse(expression, false).toString()); + Assert.assertEquals(withFlatten, Parser.parse(expression, true).toString()); } @Test