Constant flattening in math expression (#3090)

* Constant flatteing in math expression

* Addressed comments and fixed some bugs

* Addressed comments
This commit is contained in:
Navis Ryu 2016-11-15 07:14:10 +09:00 committed by Fangjin Yang
parent 9ad34a3f03
commit bb26636289
7 changed files with 213 additions and 104 deletions

View File

@ -20,11 +20,17 @@
package io.druid.math.expr;
import io.druid.common.guava.GuavaUtils;
import io.druid.java.util.common.logger.Logger;
import java.util.Arrays;
import java.util.List;
/**
*/
public class Evals
{
private static final Logger log = new Logger(Evals.class);
public static Number toNumber(Object value)
{
if (value == null) {
@ -40,4 +46,39 @@ public class Evals
}
return longValue;
}
public static boolean isConstant(Expr expr)
{
return expr instanceof ConstantExpr;
}
public static boolean isAllConstants(Expr... exprs)
{
return isAllConstants(Arrays.asList(exprs));
}
public static boolean isAllConstants(List<Expr> exprs)
{
for (Expr expr : exprs) {
if (!(expr instanceof ConstantExpr)) {
return false;
}
}
return true;
}
// for binary operator not providing constructor of form <init>(String, Expr, Expr),
// you should create it explicitly in here
public static Expr binaryOp(BinaryOpExprBase binary, Expr left, Expr right)
{
try {
return binary.getClass()
.getDeclaredConstructor(String.class, Expr.class, Expr.class)
.newInstance(binary.op, left, right);
}
catch (Exception e) {
log.warn(e, "failed to rewrite expression " + binary);
return binary; // best effort.. keep it working
}
}
}

View File

@ -55,9 +55,9 @@ abstract class ConstantExpr implements Expr
class LongExpr extends ConstantExpr
{
private final long value;
private final Long value;
public LongExpr(long value)
public LongExpr(Long value)
{
this.value = value;
}
@ -71,7 +71,7 @@ class LongExpr extends ConstantExpr
@Override
public ExprEval eval(ObjectBinding bindings)
{
return ExprEval.of(value);
return ExprEval.ofLong(value);
}
}
@ -99,9 +99,9 @@ class StringExpr extends ConstantExpr
class DoubleExpr extends ConstantExpr
{
private final double value;
private final Double value;
public DoubleExpr(double value)
public DoubleExpr(Double value)
{
this.value = value;
}
@ -115,11 +115,11 @@ class DoubleExpr extends ConstantExpr
@Override
public ExprEval eval(ObjectBinding bindings)
{
return ExprEval.of(value);
return ExprEval.ofDouble(value);
}
}
class IdentifierExpr extends ConstantExpr
class IdentifierExpr implements Expr
{
private final String value;
@ -139,6 +139,12 @@ class IdentifierExpr extends ConstantExpr
{
return ExprEval.bestEffortOf(bindings.get(value));
}
@Override
public void visit(Visitor visitor)
{
visitor.visit(this);
}
}
class FunctionExpr implements Expr
@ -161,7 +167,7 @@ class FunctionExpr implements Expr
@Override
public ExprEval eval(ObjectBinding bindings)
{
return Parser.func.get(name.toLowerCase()).apply(args, bindings);
return Parser.getFunction(name).apply(args, bindings);
}
@Override
@ -252,6 +258,8 @@ class UnaryNotExpr extends UnaryExpr
}
}
// all concrete subclass of this should have constructor with the form of <init>(String, Expr, Expr)
// if it's not possible, just be sure Evals.binaryOp() can handle that
abstract class BinaryOpExprBase implements Expr
{
protected final String op;

View File

@ -26,14 +26,24 @@ import io.druid.java.util.common.IAE;
*/
public abstract class ExprEval<T>
{
public static ExprEval ofLong(Number longValue)
{
return new LongExprEval(longValue);
}
public static ExprEval of(long longValue)
{
return new LongExprEval(longValue);
}
public static ExprEval of(double longValue)
public static ExprEval ofDouble(Number doubleValue)
{
return new DoubleExprEval(longValue);
return new DoubleExprEval(doubleValue);
}
public static ExprEval of(double doubleValue)
{
return new DoubleExprEval(doubleValue);
}
public static ExprEval of(String stringValue)
@ -108,6 +118,8 @@ public abstract class ExprEval<T>
public abstract ExprEval castTo(ExprType castTo);
public abstract Expr toExpr();
private static abstract class NumericExprEval extends ExprEval<Number> {
private NumericExprEval(Number value)
@ -166,6 +178,12 @@ public abstract class ExprEval<T>
}
throw new IAE("invalid type " + castTo);
}
@Override
public Expr toExpr()
{
return new DoubleExpr(value == null ? null : value.doubleValue());
}
}
private static class LongExprEval extends NumericExprEval
@ -200,6 +218,12 @@ public abstract class ExprEval<T>
}
throw new IAE("invalid type " + castTo);
}
@Override
public Expr toExpr()
{
return new LongExpr(value == null ? null : value.longValue());
}
}
private static class StringExprEval extends ExprEval<String>
@ -258,5 +282,11 @@ public abstract class ExprEval<T>
}
throw new IAE("invalid type " + castTo);
}
@Override
public Expr toExpr()
{
return new StringExpr(value);
}
}
}

View File

@ -285,7 +285,7 @@ public class ExprListenerImpl extends ExprBaseListener
public void exitFunctionExpr(ExprParser.FunctionExprContext ctx)
{
String fnName = ctx.getChild(0).getText();
if (!Parser.func.containsKey(fnName)) {
if (!Parser.hasFunction(fnName)) {
throw new RuntimeException("function " + fnName + " is not defined.");
}

View File

@ -24,6 +24,7 @@ import org.joda.time.DateTime;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import org.joda.time.format.ISODateTimeFormat;
import com.google.common.base.Supplier;
import java.util.List;
@ -35,6 +36,10 @@ interface Function
ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings);
// optional interface to be used when function should be created per reference in expression
interface FunctionFactory extends Supplier<Function>, Function {
}
abstract class SingleParam implements Function
{
@Override

View File

@ -21,11 +21,13 @@ package io.druid.math.expr;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.collect.Lists;
import io.druid.java.util.common.IAE;
import io.druid.java.util.common.logger.Logger;
import io.druid.math.expr.antlr.ExprLexer;
import io.druid.math.expr.antlr.ExprParser;
@ -41,16 +43,20 @@ import java.util.Set;
public class Parser
{
static final Logger log = new Logger(Parser.class);
static final Map<String, Function> func;
private static final Logger log = new Logger(Parser.class);
private static final Map<String, Supplier<Function>> func;
static {
Map<String, Function> functionMap = Maps.newHashMap();
Map<String, Supplier<Function>> functionMap = Maps.newHashMap();
for (Class clazz : Function.class.getClasses()) {
if (!Modifier.isAbstract(clazz.getModifiers()) && Function.class.isAssignableFrom(clazz)) {
try {
Function function = (Function)clazz.newInstance();
functionMap.put(function.name().toLowerCase(), function);
if (function instanceof Function.FunctionFactory) {
functionMap.put(function.name().toLowerCase(), (Supplier<Function>) function);
} else {
functionMap.put(function.name().toLowerCase(), Suppliers.ofInstance(function));
}
}
catch (Exception e) {
log.info("failed to instantiate " + clazz.getName() + ".. ignoring", e);
@ -60,7 +66,25 @@ public class Parser
func = ImmutableMap.copyOf(functionMap);
}
public static Function getFunction(String name) {
Supplier<Function> supplier = func.get(name.toLowerCase());
if (supplier == null) {
throw new IAE("Invalid function name '%s'", name);
}
return supplier.get();
}
public static boolean hasFunction(String name)
{
return func.containsKey(name.toLowerCase());
}
public static Expr parse(String in)
{
return parse(in, true);
}
public static Expr parse(String in, boolean withFlatten)
{
ExprLexer lexer = new ExprLexer(new ANTLRInputStream(in));
CommonTokenStream tokens = new CommonTokenStream(lexer);
@ -70,7 +94,51 @@ public class Parser
ParseTreeWalker walker = new ParseTreeWalker();
ExprListenerImpl listener = new ExprListenerImpl(parseTree);
walker.walk(listener, parseTree);
return listener.getAST();
return withFlatten ? flatten(listener.getAST()) : listener.getAST();
}
public static Expr flatten(Expr expr)
{
if (expr instanceof BinaryOpExprBase) {
BinaryOpExprBase binary = (BinaryOpExprBase) expr;
Expr left = flatten(binary.left);
Expr right = flatten(binary.right);
if (Evals.isAllConstants(left, right)) {
expr = expr.eval(null).toExpr();
} else if (left != binary.left || right != binary.right) {
return Evals.binaryOp(binary, left, right);
}
} else if (expr instanceof UnaryExpr) {
UnaryExpr unary = (UnaryExpr) expr;
Expr eval = flatten(unary.expr);
if (eval instanceof ConstantExpr) {
expr = expr.eval(null).toExpr();
} else if (eval != unary.expr) {
if (expr instanceof UnaryMinusExpr) {
expr = new UnaryMinusExpr(eval);
} else if (expr instanceof UnaryNotExpr) {
expr = new UnaryNotExpr(eval);
} else {
expr = unary; // unknown type..
}
}
} else if (expr instanceof FunctionExpr) {
FunctionExpr functionExpr = (FunctionExpr) expr;
List<Expr> args = functionExpr.args;
boolean flattened = false;
List<Expr> flattening = Lists.newArrayListWithCapacity(args.size());
for (int i = 0; i < args.size(); i++) {
Expr flatten = flatten(args.get(i));
flattened |= flatten != args.get(i);
flattening.add(flatten);
}
if (Evals.isAllConstants(flattening)) {
expr = expr.eval(null).toExpr();
} else if (flattened) {
expr = new FunctionExpr(functionExpr.name, flattening);
}
}
return expr;
}
public static List<String> findRequiredBindings(String in)

View File

@ -49,25 +49,11 @@ public class ParserTest
@Test
public void testSimpleUnaryOps2()
{
String actual = Parser.parse("-1").toString();
String expected = "-1";
Assert.assertEquals(expected, actual);
actual = Parser.parse("--1").toString();
expected = "--1";
Assert.assertEquals(expected, actual);
actual = Parser.parse("-1+2").toString();
expected = "(+ -1 2)";
Assert.assertEquals(expected, actual);
actual = Parser.parse("-1*2").toString();
expected = "(* -1 2)";
Assert.assertEquals(expected, actual);
actual = Parser.parse("-1^2").toString();
expected = "(^ -1 2)";
Assert.assertEquals(expected, actual);
validate("-1", "-1", "-1");
validate("--1", "--1", "1");
validate("-1+2", "(+ -1 2)", "1");
validate("-1*2", "(* -1 2)", "-2");
validate("-1^2", "(^ -1 2)", "1");
}
private void validateParser(String expression, String expected, String identifiers)
@ -103,6 +89,8 @@ public class ParserTest
validateParser("x+y-z", "(- (+ x y) z)", "[x, y, z]");
validateParser("x-y+z", "(+ (- x y) z)", "[x, y, z]");
validateParser("x-y-z", "(- (- x y) z)", "[x, y, z]");
validateParser("x-y-x", "(- (- x y) x)", "[x, y]");
}
@Test
@ -116,99 +104,68 @@ public class ParserTest
@Test
public void testSimpleMultiplicativeOp2()
{
String actual = Parser.parse("1*2*3").toString();
String expected = "(* (* 1 2) 3)";
Assert.assertEquals(expected, actual);
validate("1*2*3", "(* (* 1 2) 3)", "6");
validate("1*2/3", "(/ (* 1 2) 3)", "0");
validate("1/2*3", "(* (/ 1 2) 3)", "0");
validate("1/2/3", "(/ (/ 1 2) 3)", "0");
actual = Parser.parse("1*2/3").toString();
expected = "(/ (* 1 2) 3)";
Assert.assertEquals(expected, actual);
validate("1.0*2*3", "(* (* 1.0 2) 3)", "6.0");
validate("1.0*2/3", "(/ (* 1.0 2) 3)", "0.6666666666666666");
validate("1.0/2*3", "(* (/ 1.0 2) 3)", "1.5");
validate("1.0/2/3", "(/ (/ 1.0 2) 3)", "0.16666666666666666");
actual = Parser.parse("1/2*3").toString();
expected = "(* (/ 1 2) 3)";
Assert.assertEquals(expected, actual);
// partial
validate("1.0*2*x", "(* (* 1.0 2) x)", "(* 2.0 x)");
validate("1.0*2/x", "(/ (* 1.0 2) x)", "(/ 2.0 x)");
validate("1.0/2*x", "(* (/ 1.0 2) x)", "(* 0.5 x)");
validate("1.0/2/x", "(/ (/ 1.0 2) x)", "(/ 0.5 x)");
actual = Parser.parse("1/2/3").toString();
expected = "(/ (/ 1 2) 3)";
Assert.assertEquals(expected, actual);
// not working yet
validate("1.0*x*3", "(* (* 1.0 x) 3)", "(* (* 1.0 x) 3)");
}
@Test
public void testSimpleCarrot1()
{
String actual = Parser.parse("1^2").toString();
String expected = "(^ 1 2)";
Assert.assertEquals(expected, actual);
validate("1^2", "(^ 1 2)", "1");
}
@Test
public void testSimpleCarrot2()
{
String actual = Parser.parse("1^2^3").toString();
String expected = "(^ 1 (^ 2 3))";
Assert.assertEquals(expected, actual);
validate("1^2^3", "(^ 1 (^ 2 3))", "1");
}
@Test
public void testMixed()
{
String actual = Parser.parse("1+2*3").toString();
String expected = "(+ 1 (* 2 3))";
Assert.assertEquals(expected, actual);
validate("1+2*3", "(+ 1 (* 2 3))", "7");
validate("1+(2*3)", "(+ 1 (* 2 3))", "7");
validate("(1+2)*3", "(* (+ 1 2) 3)", "9");
actual = Parser.parse("1+(2*3)").toString();
Assert.assertEquals(expected, actual);
validate("1*2+3", "(+ (* 1 2) 3)", "5");
validate("(1*2)+3", "(+ (* 1 2) 3)", "5");
validate("1*(2+3)", "(* 1 (+ 2 3))", "5");
actual = Parser.parse("(1+2)*3").toString();
expected = "(* (+ 1 2) 3)";
Assert.assertEquals(expected, actual);
validate("1+2^3", "(+ 1 (^ 2 3))", "9");
validate("1+(2^3)", "(+ 1 (^ 2 3))", "9");
validate("(1+2)^3", "(^ (+ 1 2) 3)", "27");
validate("1^2+3", "(+ (^ 1 2) 3)", "4");
validate("(1^2)+3", "(+ (^ 1 2) 3)", "4");
validate("1^(2+3)", "(^ 1 (+ 2 3))", "1");
actual = Parser.parse("1*2+3").toString();
expected = "(+ (* 1 2) 3)";
Assert.assertEquals(expected, actual);
validate("1^2*3+4", "(+ (* (^ 1 2) 3) 4)", "7");
validate("-1^2*-3+-4", "(+ (* (^ -1 2) -3) -4)", "-7");
actual = Parser.parse("(1*2)+3").toString();
Assert.assertEquals(expected, actual);
validate("max(3, 4)", "(max [3, 4])", "4");
validate("min(1, max(3, 4))", "(min [1, (max [3, 4])])", "1");
}
actual = Parser.parse("1*(2+3)").toString();
expected = "(* 1 (+ 2 3))";
Assert.assertEquals(expected, actual);
actual = Parser.parse("1+2^3").toString();
expected = "(+ 1 (^ 2 3))";
Assert.assertEquals(expected, actual);
actual = Parser.parse("1+(2^3)").toString();
expected = "(+ 1 (^ 2 3))";
Assert.assertEquals(expected, actual);
actual = Parser.parse("(1+2)^3").toString();
expected = "(^ (+ 1 2) 3)";
Assert.assertEquals(expected, actual);
actual = Parser.parse("1^2+3").toString();
expected = "(+ (^ 1 2) 3)";
Assert.assertEquals(expected, actual);
actual = Parser.parse("(1^2)+3").toString();
expected = "(+ (^ 1 2) 3)";
Assert.assertEquals(expected, actual);
actual = Parser.parse("1^(2+3)").toString();
expected = "(^ 1 (+ 2 3))";
Assert.assertEquals(expected, actual);
actual = Parser.parse("1^2*3+4").toString();
expected = "(+ (* (^ 1 2) 3) 4)";
Assert.assertEquals(expected, actual);
actual = Parser.parse("-1^-2*-3+-4").toString();
expected = "(+ (* (^ -1 -2) -3) -4)";
Assert.assertEquals(expected, actual);
private void validate(String expression, String withoutFlatten, String withFlatten)
{
Assert.assertEquals(withoutFlatten, Parser.parse(expression, false).toString());
Assert.assertEquals(withFlatten, Parser.parse(expression, true).toString());
}
@Test