Constant flattening in math expression (#3090)

* Constant flatteing in math expression

* Addressed comments and fixed some bugs

* Addressed comments
This commit is contained in:
Navis Ryu 2016-11-15 07:14:10 +09:00 committed by Fangjin Yang
parent 9ad34a3f03
commit bb26636289
7 changed files with 213 additions and 104 deletions

View File

@ -20,11 +20,17 @@
package io.druid.math.expr; package io.druid.math.expr;
import io.druid.common.guava.GuavaUtils; import io.druid.common.guava.GuavaUtils;
import io.druid.java.util.common.logger.Logger;
import java.util.Arrays;
import java.util.List;
/** /**
*/ */
public class Evals public class Evals
{ {
private static final Logger log = new Logger(Evals.class);
public static Number toNumber(Object value) public static Number toNumber(Object value)
{ {
if (value == null) { if (value == null) {
@ -40,4 +46,39 @@ public class Evals
} }
return longValue; return longValue;
} }
public static boolean isConstant(Expr expr)
{
return expr instanceof ConstantExpr;
}
public static boolean isAllConstants(Expr... exprs)
{
return isAllConstants(Arrays.asList(exprs));
}
public static boolean isAllConstants(List<Expr> exprs)
{
for (Expr expr : exprs) {
if (!(expr instanceof ConstantExpr)) {
return false;
}
}
return true;
}
// for binary operator not providing constructor of form <init>(String, Expr, Expr),
// you should create it explicitly in here
public static Expr binaryOp(BinaryOpExprBase binary, Expr left, Expr right)
{
try {
return binary.getClass()
.getDeclaredConstructor(String.class, Expr.class, Expr.class)
.newInstance(binary.op, left, right);
}
catch (Exception e) {
log.warn(e, "failed to rewrite expression " + binary);
return binary; // best effort.. keep it working
}
}
} }

View File

@ -55,9 +55,9 @@ abstract class ConstantExpr implements Expr
class LongExpr extends ConstantExpr class LongExpr extends ConstantExpr
{ {
private final long value; private final Long value;
public LongExpr(long value) public LongExpr(Long value)
{ {
this.value = value; this.value = value;
} }
@ -71,7 +71,7 @@ class LongExpr extends ConstantExpr
@Override @Override
public ExprEval eval(ObjectBinding bindings) public ExprEval eval(ObjectBinding bindings)
{ {
return ExprEval.of(value); return ExprEval.ofLong(value);
} }
} }
@ -99,9 +99,9 @@ class StringExpr extends ConstantExpr
class DoubleExpr extends ConstantExpr class DoubleExpr extends ConstantExpr
{ {
private final double value; private final Double value;
public DoubleExpr(double value) public DoubleExpr(Double value)
{ {
this.value = value; this.value = value;
} }
@ -115,11 +115,11 @@ class DoubleExpr extends ConstantExpr
@Override @Override
public ExprEval eval(ObjectBinding bindings) public ExprEval eval(ObjectBinding bindings)
{ {
return ExprEval.of(value); return ExprEval.ofDouble(value);
} }
} }
class IdentifierExpr extends ConstantExpr class IdentifierExpr implements Expr
{ {
private final String value; private final String value;
@ -139,6 +139,12 @@ class IdentifierExpr extends ConstantExpr
{ {
return ExprEval.bestEffortOf(bindings.get(value)); return ExprEval.bestEffortOf(bindings.get(value));
} }
@Override
public void visit(Visitor visitor)
{
visitor.visit(this);
}
} }
class FunctionExpr implements Expr class FunctionExpr implements Expr
@ -161,7 +167,7 @@ class FunctionExpr implements Expr
@Override @Override
public ExprEval eval(ObjectBinding bindings) public ExprEval eval(ObjectBinding bindings)
{ {
return Parser.func.get(name.toLowerCase()).apply(args, bindings); return Parser.getFunction(name).apply(args, bindings);
} }
@Override @Override
@ -252,6 +258,8 @@ class UnaryNotExpr extends UnaryExpr
} }
} }
// all concrete subclass of this should have constructor with the form of <init>(String, Expr, Expr)
// if it's not possible, just be sure Evals.binaryOp() can handle that
abstract class BinaryOpExprBase implements Expr abstract class BinaryOpExprBase implements Expr
{ {
protected final String op; protected final String op;

View File

@ -26,14 +26,24 @@ import io.druid.java.util.common.IAE;
*/ */
public abstract class ExprEval<T> public abstract class ExprEval<T>
{ {
public static ExprEval ofLong(Number longValue)
{
return new LongExprEval(longValue);
}
public static ExprEval of(long longValue) public static ExprEval of(long longValue)
{ {
return new LongExprEval(longValue); return new LongExprEval(longValue);
} }
public static ExprEval of(double longValue) public static ExprEval ofDouble(Number doubleValue)
{ {
return new DoubleExprEval(longValue); return new DoubleExprEval(doubleValue);
}
public static ExprEval of(double doubleValue)
{
return new DoubleExprEval(doubleValue);
} }
public static ExprEval of(String stringValue) public static ExprEval of(String stringValue)
@ -108,6 +118,8 @@ public abstract class ExprEval<T>
public abstract ExprEval castTo(ExprType castTo); public abstract ExprEval castTo(ExprType castTo);
public abstract Expr toExpr();
private static abstract class NumericExprEval extends ExprEval<Number> { private static abstract class NumericExprEval extends ExprEval<Number> {
private NumericExprEval(Number value) private NumericExprEval(Number value)
@ -166,6 +178,12 @@ public abstract class ExprEval<T>
} }
throw new IAE("invalid type " + castTo); throw new IAE("invalid type " + castTo);
} }
@Override
public Expr toExpr()
{
return new DoubleExpr(value == null ? null : value.doubleValue());
}
} }
private static class LongExprEval extends NumericExprEval private static class LongExprEval extends NumericExprEval
@ -200,6 +218,12 @@ public abstract class ExprEval<T>
} }
throw new IAE("invalid type " + castTo); throw new IAE("invalid type " + castTo);
} }
@Override
public Expr toExpr()
{
return new LongExpr(value == null ? null : value.longValue());
}
} }
private static class StringExprEval extends ExprEval<String> private static class StringExprEval extends ExprEval<String>
@ -258,5 +282,11 @@ public abstract class ExprEval<T>
} }
throw new IAE("invalid type " + castTo); throw new IAE("invalid type " + castTo);
} }
@Override
public Expr toExpr()
{
return new StringExpr(value);
}
} }
} }

View File

@ -285,7 +285,7 @@ public class ExprListenerImpl extends ExprBaseListener
public void exitFunctionExpr(ExprParser.FunctionExprContext ctx) public void exitFunctionExpr(ExprParser.FunctionExprContext ctx)
{ {
String fnName = ctx.getChild(0).getText(); String fnName = ctx.getChild(0).getText();
if (!Parser.func.containsKey(fnName)) { if (!Parser.hasFunction(fnName)) {
throw new RuntimeException("function " + fnName + " is not defined."); throw new RuntimeException("function " + fnName + " is not defined.");
} }

View File

@ -24,6 +24,7 @@ import org.joda.time.DateTime;
import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter; import org.joda.time.format.DateTimeFormatter;
import org.joda.time.format.ISODateTimeFormat; import org.joda.time.format.ISODateTimeFormat;
import com.google.common.base.Supplier;
import java.util.List; import java.util.List;
@ -35,6 +36,10 @@ interface Function
ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings); ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings);
// optional interface to be used when function should be created per reference in expression
interface FunctionFactory extends Supplier<Function>, Function {
}
abstract class SingleParam implements Function abstract class SingleParam implements Function
{ {
@Override @Override

View File

@ -21,11 +21,13 @@ package io.druid.math.expr;
import com.google.common.base.Supplier; import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps; import com.google.common.collect.Maps;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import io.druid.java.util.common.IAE;
import io.druid.java.util.common.logger.Logger; import io.druid.java.util.common.logger.Logger;
import io.druid.math.expr.antlr.ExprLexer; import io.druid.math.expr.antlr.ExprLexer;
import io.druid.math.expr.antlr.ExprParser; import io.druid.math.expr.antlr.ExprParser;
@ -41,16 +43,20 @@ import java.util.Set;
public class Parser public class Parser
{ {
static final Logger log = new Logger(Parser.class); private static final Logger log = new Logger(Parser.class);
static final Map<String, Function> func; private static final Map<String, Supplier<Function>> func;
static { static {
Map<String, Function> functionMap = Maps.newHashMap(); Map<String, Supplier<Function>> functionMap = Maps.newHashMap();
for (Class clazz : Function.class.getClasses()) { for (Class clazz : Function.class.getClasses()) {
if (!Modifier.isAbstract(clazz.getModifiers()) && Function.class.isAssignableFrom(clazz)) { if (!Modifier.isAbstract(clazz.getModifiers()) && Function.class.isAssignableFrom(clazz)) {
try { try {
Function function = (Function)clazz.newInstance(); Function function = (Function)clazz.newInstance();
functionMap.put(function.name().toLowerCase(), function); if (function instanceof Function.FunctionFactory) {
functionMap.put(function.name().toLowerCase(), (Supplier<Function>) function);
} else {
functionMap.put(function.name().toLowerCase(), Suppliers.ofInstance(function));
}
} }
catch (Exception e) { catch (Exception e) {
log.info("failed to instantiate " + clazz.getName() + ".. ignoring", e); log.info("failed to instantiate " + clazz.getName() + ".. ignoring", e);
@ -60,7 +66,25 @@ public class Parser
func = ImmutableMap.copyOf(functionMap); func = ImmutableMap.copyOf(functionMap);
} }
public static Function getFunction(String name) {
Supplier<Function> supplier = func.get(name.toLowerCase());
if (supplier == null) {
throw new IAE("Invalid function name '%s'", name);
}
return supplier.get();
}
public static boolean hasFunction(String name)
{
return func.containsKey(name.toLowerCase());
}
public static Expr parse(String in) public static Expr parse(String in)
{
return parse(in, true);
}
public static Expr parse(String in, boolean withFlatten)
{ {
ExprLexer lexer = new ExprLexer(new ANTLRInputStream(in)); ExprLexer lexer = new ExprLexer(new ANTLRInputStream(in));
CommonTokenStream tokens = new CommonTokenStream(lexer); CommonTokenStream tokens = new CommonTokenStream(lexer);
@ -70,7 +94,51 @@ public class Parser
ParseTreeWalker walker = new ParseTreeWalker(); ParseTreeWalker walker = new ParseTreeWalker();
ExprListenerImpl listener = new ExprListenerImpl(parseTree); ExprListenerImpl listener = new ExprListenerImpl(parseTree);
walker.walk(listener, parseTree); walker.walk(listener, parseTree);
return listener.getAST(); return withFlatten ? flatten(listener.getAST()) : listener.getAST();
}
public static Expr flatten(Expr expr)
{
if (expr instanceof BinaryOpExprBase) {
BinaryOpExprBase binary = (BinaryOpExprBase) expr;
Expr left = flatten(binary.left);
Expr right = flatten(binary.right);
if (Evals.isAllConstants(left, right)) {
expr = expr.eval(null).toExpr();
} else if (left != binary.left || right != binary.right) {
return Evals.binaryOp(binary, left, right);
}
} else if (expr instanceof UnaryExpr) {
UnaryExpr unary = (UnaryExpr) expr;
Expr eval = flatten(unary.expr);
if (eval instanceof ConstantExpr) {
expr = expr.eval(null).toExpr();
} else if (eval != unary.expr) {
if (expr instanceof UnaryMinusExpr) {
expr = new UnaryMinusExpr(eval);
} else if (expr instanceof UnaryNotExpr) {
expr = new UnaryNotExpr(eval);
} else {
expr = unary; // unknown type..
}
}
} else if (expr instanceof FunctionExpr) {
FunctionExpr functionExpr = (FunctionExpr) expr;
List<Expr> args = functionExpr.args;
boolean flattened = false;
List<Expr> flattening = Lists.newArrayListWithCapacity(args.size());
for (int i = 0; i < args.size(); i++) {
Expr flatten = flatten(args.get(i));
flattened |= flatten != args.get(i);
flattening.add(flatten);
}
if (Evals.isAllConstants(flattening)) {
expr = expr.eval(null).toExpr();
} else if (flattened) {
expr = new FunctionExpr(functionExpr.name, flattening);
}
}
return expr;
} }
public static List<String> findRequiredBindings(String in) public static List<String> findRequiredBindings(String in)

View File

@ -49,25 +49,11 @@ public class ParserTest
@Test @Test
public void testSimpleUnaryOps2() public void testSimpleUnaryOps2()
{ {
String actual = Parser.parse("-1").toString(); validate("-1", "-1", "-1");
String expected = "-1"; validate("--1", "--1", "1");
Assert.assertEquals(expected, actual); validate("-1+2", "(+ -1 2)", "1");
validate("-1*2", "(* -1 2)", "-2");
actual = Parser.parse("--1").toString(); validate("-1^2", "(^ -1 2)", "1");
expected = "--1";
Assert.assertEquals(expected, actual);
actual = Parser.parse("-1+2").toString();
expected = "(+ -1 2)";
Assert.assertEquals(expected, actual);
actual = Parser.parse("-1*2").toString();
expected = "(* -1 2)";
Assert.assertEquals(expected, actual);
actual = Parser.parse("-1^2").toString();
expected = "(^ -1 2)";
Assert.assertEquals(expected, actual);
} }
private void validateParser(String expression, String expected, String identifiers) private void validateParser(String expression, String expected, String identifiers)
@ -103,6 +89,8 @@ public class ParserTest
validateParser("x+y-z", "(- (+ x y) z)", "[x, y, z]"); validateParser("x+y-z", "(- (+ x y) z)", "[x, y, z]");
validateParser("x-y+z", "(+ (- x y) z)", "[x, y, z]"); validateParser("x-y+z", "(+ (- x y) z)", "[x, y, z]");
validateParser("x-y-z", "(- (- x y) z)", "[x, y, z]"); validateParser("x-y-z", "(- (- x y) z)", "[x, y, z]");
validateParser("x-y-x", "(- (- x y) x)", "[x, y]");
} }
@Test @Test
@ -116,99 +104,68 @@ public class ParserTest
@Test @Test
public void testSimpleMultiplicativeOp2() public void testSimpleMultiplicativeOp2()
{ {
String actual = Parser.parse("1*2*3").toString(); validate("1*2*3", "(* (* 1 2) 3)", "6");
String expected = "(* (* 1 2) 3)"; validate("1*2/3", "(/ (* 1 2) 3)", "0");
Assert.assertEquals(expected, actual); validate("1/2*3", "(* (/ 1 2) 3)", "0");
validate("1/2/3", "(/ (/ 1 2) 3)", "0");
actual = Parser.parse("1*2/3").toString(); validate("1.0*2*3", "(* (* 1.0 2) 3)", "6.0");
expected = "(/ (* 1 2) 3)"; validate("1.0*2/3", "(/ (* 1.0 2) 3)", "0.6666666666666666");
Assert.assertEquals(expected, actual); validate("1.0/2*3", "(* (/ 1.0 2) 3)", "1.5");
validate("1.0/2/3", "(/ (/ 1.0 2) 3)", "0.16666666666666666");
actual = Parser.parse("1/2*3").toString(); // partial
expected = "(* (/ 1 2) 3)"; validate("1.0*2*x", "(* (* 1.0 2) x)", "(* 2.0 x)");
Assert.assertEquals(expected, actual); validate("1.0*2/x", "(/ (* 1.0 2) x)", "(/ 2.0 x)");
validate("1.0/2*x", "(* (/ 1.0 2) x)", "(* 0.5 x)");
validate("1.0/2/x", "(/ (/ 1.0 2) x)", "(/ 0.5 x)");
actual = Parser.parse("1/2/3").toString(); // not working yet
expected = "(/ (/ 1 2) 3)"; validate("1.0*x*3", "(* (* 1.0 x) 3)", "(* (* 1.0 x) 3)");
Assert.assertEquals(expected, actual);
} }
@Test @Test
public void testSimpleCarrot1() public void testSimpleCarrot1()
{ {
String actual = Parser.parse("1^2").toString(); validate("1^2", "(^ 1 2)", "1");
String expected = "(^ 1 2)";
Assert.assertEquals(expected, actual);
} }
@Test @Test
public void testSimpleCarrot2() public void testSimpleCarrot2()
{ {
String actual = Parser.parse("1^2^3").toString(); validate("1^2^3", "(^ 1 (^ 2 3))", "1");
String expected = "(^ 1 (^ 2 3))";
Assert.assertEquals(expected, actual);
} }
@Test @Test
public void testMixed() public void testMixed()
{ {
String actual = Parser.parse("1+2*3").toString(); validate("1+2*3", "(+ 1 (* 2 3))", "7");
String expected = "(+ 1 (* 2 3))"; validate("1+(2*3)", "(+ 1 (* 2 3))", "7");
Assert.assertEquals(expected, actual); validate("(1+2)*3", "(* (+ 1 2) 3)", "9");
actual = Parser.parse("1+(2*3)").toString(); validate("1*2+3", "(+ (* 1 2) 3)", "5");
Assert.assertEquals(expected, actual); validate("(1*2)+3", "(+ (* 1 2) 3)", "5");
validate("1*(2+3)", "(* 1 (+ 2 3))", "5");
actual = Parser.parse("(1+2)*3").toString(); validate("1+2^3", "(+ 1 (^ 2 3))", "9");
expected = "(* (+ 1 2) 3)"; validate("1+(2^3)", "(+ 1 (^ 2 3))", "9");
Assert.assertEquals(expected, actual); validate("(1+2)^3", "(^ (+ 1 2) 3)", "27");
validate("1^2+3", "(+ (^ 1 2) 3)", "4");
validate("(1^2)+3", "(+ (^ 1 2) 3)", "4");
validate("1^(2+3)", "(^ 1 (+ 2 3))", "1");
actual = Parser.parse("1*2+3").toString(); validate("1^2*3+4", "(+ (* (^ 1 2) 3) 4)", "7");
expected = "(+ (* 1 2) 3)"; validate("-1^2*-3+-4", "(+ (* (^ -1 2) -3) -4)", "-7");
Assert.assertEquals(expected, actual);
actual = Parser.parse("(1*2)+3").toString(); validate("max(3, 4)", "(max [3, 4])", "4");
Assert.assertEquals(expected, actual); validate("min(1, max(3, 4))", "(min [1, (max [3, 4])])", "1");
}
actual = Parser.parse("1*(2+3)").toString(); private void validate(String expression, String withoutFlatten, String withFlatten)
expected = "(* 1 (+ 2 3))"; {
Assert.assertEquals(expected, actual); Assert.assertEquals(withoutFlatten, Parser.parse(expression, false).toString());
Assert.assertEquals(withFlatten, Parser.parse(expression, true).toString());
actual = Parser.parse("1+2^3").toString();
expected = "(+ 1 (^ 2 3))";
Assert.assertEquals(expected, actual);
actual = Parser.parse("1+(2^3)").toString();
expected = "(+ 1 (^ 2 3))";
Assert.assertEquals(expected, actual);
actual = Parser.parse("(1+2)^3").toString();
expected = "(^ (+ 1 2) 3)";
Assert.assertEquals(expected, actual);
actual = Parser.parse("1^2+3").toString();
expected = "(+ (^ 1 2) 3)";
Assert.assertEquals(expected, actual);
actual = Parser.parse("(1^2)+3").toString();
expected = "(+ (^ 1 2) 3)";
Assert.assertEquals(expected, actual);
actual = Parser.parse("1^(2+3)").toString();
expected = "(^ 1 (+ 2 3))";
Assert.assertEquals(expected, actual);
actual = Parser.parse("1^2*3+4").toString();
expected = "(+ (* (^ 1 2) 3) 4)";
Assert.assertEquals(expected, actual);
actual = Parser.parse("-1^-2*-3+-4").toString();
expected = "(+ (* (^ -1 -2) -3) -4)";
Assert.assertEquals(expected, actual);
} }
@Test @Test