mirror of https://github.com/apache/druid.git
Constant flattening in math expression (#3090)
* Constant flatteing in math expression * Addressed comments and fixed some bugs * Addressed comments
This commit is contained in:
parent
9ad34a3f03
commit
bb26636289
|
@ -20,11 +20,17 @@
|
|||
package io.druid.math.expr;
|
||||
|
||||
import io.druid.common.guava.GuavaUtils;
|
||||
import io.druid.java.util.common.logger.Logger;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
*/
|
||||
public class Evals
|
||||
{
|
||||
private static final Logger log = new Logger(Evals.class);
|
||||
|
||||
public static Number toNumber(Object value)
|
||||
{
|
||||
if (value == null) {
|
||||
|
@ -40,4 +46,39 @@ public class Evals
|
|||
}
|
||||
return longValue;
|
||||
}
|
||||
|
||||
public static boolean isConstant(Expr expr)
|
||||
{
|
||||
return expr instanceof ConstantExpr;
|
||||
}
|
||||
|
||||
public static boolean isAllConstants(Expr... exprs)
|
||||
{
|
||||
return isAllConstants(Arrays.asList(exprs));
|
||||
}
|
||||
|
||||
public static boolean isAllConstants(List<Expr> exprs)
|
||||
{
|
||||
for (Expr expr : exprs) {
|
||||
if (!(expr instanceof ConstantExpr)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// for binary operator not providing constructor of form <init>(String, Expr, Expr),
|
||||
// you should create it explicitly in here
|
||||
public static Expr binaryOp(BinaryOpExprBase binary, Expr left, Expr right)
|
||||
{
|
||||
try {
|
||||
return binary.getClass()
|
||||
.getDeclaredConstructor(String.class, Expr.class, Expr.class)
|
||||
.newInstance(binary.op, left, right);
|
||||
}
|
||||
catch (Exception e) {
|
||||
log.warn(e, "failed to rewrite expression " + binary);
|
||||
return binary; // best effort.. keep it working
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,9 +55,9 @@ abstract class ConstantExpr implements Expr
|
|||
|
||||
class LongExpr extends ConstantExpr
|
||||
{
|
||||
private final long value;
|
||||
private final Long value;
|
||||
|
||||
public LongExpr(long value)
|
||||
public LongExpr(Long value)
|
||||
{
|
||||
this.value = value;
|
||||
}
|
||||
|
@ -71,7 +71,7 @@ class LongExpr extends ConstantExpr
|
|||
@Override
|
||||
public ExprEval eval(ObjectBinding bindings)
|
||||
{
|
||||
return ExprEval.of(value);
|
||||
return ExprEval.ofLong(value);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -99,9 +99,9 @@ class StringExpr extends ConstantExpr
|
|||
|
||||
class DoubleExpr extends ConstantExpr
|
||||
{
|
||||
private final double value;
|
||||
private final Double value;
|
||||
|
||||
public DoubleExpr(double value)
|
||||
public DoubleExpr(Double value)
|
||||
{
|
||||
this.value = value;
|
||||
}
|
||||
|
@ -115,11 +115,11 @@ class DoubleExpr extends ConstantExpr
|
|||
@Override
|
||||
public ExprEval eval(ObjectBinding bindings)
|
||||
{
|
||||
return ExprEval.of(value);
|
||||
return ExprEval.ofDouble(value);
|
||||
}
|
||||
}
|
||||
|
||||
class IdentifierExpr extends ConstantExpr
|
||||
class IdentifierExpr implements Expr
|
||||
{
|
||||
private final String value;
|
||||
|
||||
|
@ -139,6 +139,12 @@ class IdentifierExpr extends ConstantExpr
|
|||
{
|
||||
return ExprEval.bestEffortOf(bindings.get(value));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(Visitor visitor)
|
||||
{
|
||||
visitor.visit(this);
|
||||
}
|
||||
}
|
||||
|
||||
class FunctionExpr implements Expr
|
||||
|
@ -161,7 +167,7 @@ class FunctionExpr implements Expr
|
|||
@Override
|
||||
public ExprEval eval(ObjectBinding bindings)
|
||||
{
|
||||
return Parser.func.get(name.toLowerCase()).apply(args, bindings);
|
||||
return Parser.getFunction(name).apply(args, bindings);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -252,6 +258,8 @@ class UnaryNotExpr extends UnaryExpr
|
|||
}
|
||||
}
|
||||
|
||||
// all concrete subclass of this should have constructor with the form of <init>(String, Expr, Expr)
|
||||
// if it's not possible, just be sure Evals.binaryOp() can handle that
|
||||
abstract class BinaryOpExprBase implements Expr
|
||||
{
|
||||
protected final String op;
|
||||
|
|
|
@ -26,14 +26,24 @@ import io.druid.java.util.common.IAE;
|
|||
*/
|
||||
public abstract class ExprEval<T>
|
||||
{
|
||||
public static ExprEval ofLong(Number longValue)
|
||||
{
|
||||
return new LongExprEval(longValue);
|
||||
}
|
||||
|
||||
public static ExprEval of(long longValue)
|
||||
{
|
||||
return new LongExprEval(longValue);
|
||||
}
|
||||
|
||||
public static ExprEval of(double longValue)
|
||||
public static ExprEval ofDouble(Number doubleValue)
|
||||
{
|
||||
return new DoubleExprEval(longValue);
|
||||
return new DoubleExprEval(doubleValue);
|
||||
}
|
||||
|
||||
public static ExprEval of(double doubleValue)
|
||||
{
|
||||
return new DoubleExprEval(doubleValue);
|
||||
}
|
||||
|
||||
public static ExprEval of(String stringValue)
|
||||
|
@ -108,6 +118,8 @@ public abstract class ExprEval<T>
|
|||
|
||||
public abstract ExprEval castTo(ExprType castTo);
|
||||
|
||||
public abstract Expr toExpr();
|
||||
|
||||
private static abstract class NumericExprEval extends ExprEval<Number> {
|
||||
|
||||
private NumericExprEval(Number value)
|
||||
|
@ -166,6 +178,12 @@ public abstract class ExprEval<T>
|
|||
}
|
||||
throw new IAE("invalid type " + castTo);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expr toExpr()
|
||||
{
|
||||
return new DoubleExpr(value == null ? null : value.doubleValue());
|
||||
}
|
||||
}
|
||||
|
||||
private static class LongExprEval extends NumericExprEval
|
||||
|
@ -200,6 +218,12 @@ public abstract class ExprEval<T>
|
|||
}
|
||||
throw new IAE("invalid type " + castTo);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expr toExpr()
|
||||
{
|
||||
return new LongExpr(value == null ? null : value.longValue());
|
||||
}
|
||||
}
|
||||
|
||||
private static class StringExprEval extends ExprEval<String>
|
||||
|
@ -258,5 +282,11 @@ public abstract class ExprEval<T>
|
|||
}
|
||||
throw new IAE("invalid type " + castTo);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expr toExpr()
|
||||
{
|
||||
return new StringExpr(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -285,7 +285,7 @@ public class ExprListenerImpl extends ExprBaseListener
|
|||
public void exitFunctionExpr(ExprParser.FunctionExprContext ctx)
|
||||
{
|
||||
String fnName = ctx.getChild(0).getText();
|
||||
if (!Parser.func.containsKey(fnName)) {
|
||||
if (!Parser.hasFunction(fnName)) {
|
||||
throw new RuntimeException("function " + fnName + " is not defined.");
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.joda.time.DateTime;
|
|||
import org.joda.time.format.DateTimeFormat;
|
||||
import org.joda.time.format.DateTimeFormatter;
|
||||
import org.joda.time.format.ISODateTimeFormat;
|
||||
import com.google.common.base.Supplier;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
@ -35,6 +36,10 @@ interface Function
|
|||
|
||||
ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings);
|
||||
|
||||
// optional interface to be used when function should be created per reference in expression
|
||||
interface FunctionFactory extends Supplier<Function>, Function {
|
||||
}
|
||||
|
||||
abstract class SingleParam implements Function
|
||||
{
|
||||
@Override
|
||||
|
|
|
@ -21,11 +21,13 @@ package io.druid.math.expr;
|
|||
|
||||
|
||||
import com.google.common.base.Supplier;
|
||||
import com.google.common.base.Suppliers;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import io.druid.java.util.common.IAE;
|
||||
import io.druid.java.util.common.logger.Logger;
|
||||
import io.druid.math.expr.antlr.ExprLexer;
|
||||
import io.druid.math.expr.antlr.ExprParser;
|
||||
|
@ -41,16 +43,20 @@ import java.util.Set;
|
|||
|
||||
public class Parser
|
||||
{
|
||||
static final Logger log = new Logger(Parser.class);
|
||||
static final Map<String, Function> func;
|
||||
private static final Logger log = new Logger(Parser.class);
|
||||
private static final Map<String, Supplier<Function>> func;
|
||||
|
||||
static {
|
||||
Map<String, Function> functionMap = Maps.newHashMap();
|
||||
Map<String, Supplier<Function>> functionMap = Maps.newHashMap();
|
||||
for (Class clazz : Function.class.getClasses()) {
|
||||
if (!Modifier.isAbstract(clazz.getModifiers()) && Function.class.isAssignableFrom(clazz)) {
|
||||
try {
|
||||
Function function = (Function)clazz.newInstance();
|
||||
functionMap.put(function.name().toLowerCase(), function);
|
||||
if (function instanceof Function.FunctionFactory) {
|
||||
functionMap.put(function.name().toLowerCase(), (Supplier<Function>) function);
|
||||
} else {
|
||||
functionMap.put(function.name().toLowerCase(), Suppliers.ofInstance(function));
|
||||
}
|
||||
}
|
||||
catch (Exception e) {
|
||||
log.info("failed to instantiate " + clazz.getName() + ".. ignoring", e);
|
||||
|
@ -60,7 +66,25 @@ public class Parser
|
|||
func = ImmutableMap.copyOf(functionMap);
|
||||
}
|
||||
|
||||
public static Function getFunction(String name) {
|
||||
Supplier<Function> supplier = func.get(name.toLowerCase());
|
||||
if (supplier == null) {
|
||||
throw new IAE("Invalid function name '%s'", name);
|
||||
}
|
||||
return supplier.get();
|
||||
}
|
||||
|
||||
public static boolean hasFunction(String name)
|
||||
{
|
||||
return func.containsKey(name.toLowerCase());
|
||||
}
|
||||
|
||||
public static Expr parse(String in)
|
||||
{
|
||||
return parse(in, true);
|
||||
}
|
||||
|
||||
public static Expr parse(String in, boolean withFlatten)
|
||||
{
|
||||
ExprLexer lexer = new ExprLexer(new ANTLRInputStream(in));
|
||||
CommonTokenStream tokens = new CommonTokenStream(lexer);
|
||||
|
@ -70,7 +94,51 @@ public class Parser
|
|||
ParseTreeWalker walker = new ParseTreeWalker();
|
||||
ExprListenerImpl listener = new ExprListenerImpl(parseTree);
|
||||
walker.walk(listener, parseTree);
|
||||
return listener.getAST();
|
||||
return withFlatten ? flatten(listener.getAST()) : listener.getAST();
|
||||
}
|
||||
|
||||
public static Expr flatten(Expr expr)
|
||||
{
|
||||
if (expr instanceof BinaryOpExprBase) {
|
||||
BinaryOpExprBase binary = (BinaryOpExprBase) expr;
|
||||
Expr left = flatten(binary.left);
|
||||
Expr right = flatten(binary.right);
|
||||
if (Evals.isAllConstants(left, right)) {
|
||||
expr = expr.eval(null).toExpr();
|
||||
} else if (left != binary.left || right != binary.right) {
|
||||
return Evals.binaryOp(binary, left, right);
|
||||
}
|
||||
} else if (expr instanceof UnaryExpr) {
|
||||
UnaryExpr unary = (UnaryExpr) expr;
|
||||
Expr eval = flatten(unary.expr);
|
||||
if (eval instanceof ConstantExpr) {
|
||||
expr = expr.eval(null).toExpr();
|
||||
} else if (eval != unary.expr) {
|
||||
if (expr instanceof UnaryMinusExpr) {
|
||||
expr = new UnaryMinusExpr(eval);
|
||||
} else if (expr instanceof UnaryNotExpr) {
|
||||
expr = new UnaryNotExpr(eval);
|
||||
} else {
|
||||
expr = unary; // unknown type..
|
||||
}
|
||||
}
|
||||
} else if (expr instanceof FunctionExpr) {
|
||||
FunctionExpr functionExpr = (FunctionExpr) expr;
|
||||
List<Expr> args = functionExpr.args;
|
||||
boolean flattened = false;
|
||||
List<Expr> flattening = Lists.newArrayListWithCapacity(args.size());
|
||||
for (int i = 0; i < args.size(); i++) {
|
||||
Expr flatten = flatten(args.get(i));
|
||||
flattened |= flatten != args.get(i);
|
||||
flattening.add(flatten);
|
||||
}
|
||||
if (Evals.isAllConstants(flattening)) {
|
||||
expr = expr.eval(null).toExpr();
|
||||
} else if (flattened) {
|
||||
expr = new FunctionExpr(functionExpr.name, flattening);
|
||||
}
|
||||
}
|
||||
return expr;
|
||||
}
|
||||
|
||||
public static List<String> findRequiredBindings(String in)
|
||||
|
|
|
@ -49,25 +49,11 @@ public class ParserTest
|
|||
@Test
|
||||
public void testSimpleUnaryOps2()
|
||||
{
|
||||
String actual = Parser.parse("-1").toString();
|
||||
String expected = "-1";
|
||||
Assert.assertEquals(expected, actual);
|
||||
|
||||
actual = Parser.parse("--1").toString();
|
||||
expected = "--1";
|
||||
Assert.assertEquals(expected, actual);
|
||||
|
||||
actual = Parser.parse("-1+2").toString();
|
||||
expected = "(+ -1 2)";
|
||||
Assert.assertEquals(expected, actual);
|
||||
|
||||
actual = Parser.parse("-1*2").toString();
|
||||
expected = "(* -1 2)";
|
||||
Assert.assertEquals(expected, actual);
|
||||
|
||||
actual = Parser.parse("-1^2").toString();
|
||||
expected = "(^ -1 2)";
|
||||
Assert.assertEquals(expected, actual);
|
||||
validate("-1", "-1", "-1");
|
||||
validate("--1", "--1", "1");
|
||||
validate("-1+2", "(+ -1 2)", "1");
|
||||
validate("-1*2", "(* -1 2)", "-2");
|
||||
validate("-1^2", "(^ -1 2)", "1");
|
||||
}
|
||||
|
||||
private void validateParser(String expression, String expected, String identifiers)
|
||||
|
@ -103,6 +89,8 @@ public class ParserTest
|
|||
validateParser("x+y-z", "(- (+ x y) z)", "[x, y, z]");
|
||||
validateParser("x-y+z", "(+ (- x y) z)", "[x, y, z]");
|
||||
validateParser("x-y-z", "(- (- x y) z)", "[x, y, z]");
|
||||
|
||||
validateParser("x-y-x", "(- (- x y) x)", "[x, y]");
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -116,99 +104,68 @@ public class ParserTest
|
|||
@Test
|
||||
public void testSimpleMultiplicativeOp2()
|
||||
{
|
||||
String actual = Parser.parse("1*2*3").toString();
|
||||
String expected = "(* (* 1 2) 3)";
|
||||
Assert.assertEquals(expected, actual);
|
||||
validate("1*2*3", "(* (* 1 2) 3)", "6");
|
||||
validate("1*2/3", "(/ (* 1 2) 3)", "0");
|
||||
validate("1/2*3", "(* (/ 1 2) 3)", "0");
|
||||
validate("1/2/3", "(/ (/ 1 2) 3)", "0");
|
||||
|
||||
actual = Parser.parse("1*2/3").toString();
|
||||
expected = "(/ (* 1 2) 3)";
|
||||
Assert.assertEquals(expected, actual);
|
||||
validate("1.0*2*3", "(* (* 1.0 2) 3)", "6.0");
|
||||
validate("1.0*2/3", "(/ (* 1.0 2) 3)", "0.6666666666666666");
|
||||
validate("1.0/2*3", "(* (/ 1.0 2) 3)", "1.5");
|
||||
validate("1.0/2/3", "(/ (/ 1.0 2) 3)", "0.16666666666666666");
|
||||
|
||||
actual = Parser.parse("1/2*3").toString();
|
||||
expected = "(* (/ 1 2) 3)";
|
||||
Assert.assertEquals(expected, actual);
|
||||
// partial
|
||||
validate("1.0*2*x", "(* (* 1.0 2) x)", "(* 2.0 x)");
|
||||
validate("1.0*2/x", "(/ (* 1.0 2) x)", "(/ 2.0 x)");
|
||||
validate("1.0/2*x", "(* (/ 1.0 2) x)", "(* 0.5 x)");
|
||||
validate("1.0/2/x", "(/ (/ 1.0 2) x)", "(/ 0.5 x)");
|
||||
|
||||
actual = Parser.parse("1/2/3").toString();
|
||||
expected = "(/ (/ 1 2) 3)";
|
||||
Assert.assertEquals(expected, actual);
|
||||
// not working yet
|
||||
validate("1.0*x*3", "(* (* 1.0 x) 3)", "(* (* 1.0 x) 3)");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSimpleCarrot1()
|
||||
{
|
||||
String actual = Parser.parse("1^2").toString();
|
||||
String expected = "(^ 1 2)";
|
||||
Assert.assertEquals(expected, actual);
|
||||
validate("1^2", "(^ 1 2)", "1");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSimpleCarrot2()
|
||||
{
|
||||
String actual = Parser.parse("1^2^3").toString();
|
||||
String expected = "(^ 1 (^ 2 3))";
|
||||
Assert.assertEquals(expected, actual);
|
||||
validate("1^2^3", "(^ 1 (^ 2 3))", "1");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMixed()
|
||||
{
|
||||
String actual = Parser.parse("1+2*3").toString();
|
||||
String expected = "(+ 1 (* 2 3))";
|
||||
Assert.assertEquals(expected, actual);
|
||||
validate("1+2*3", "(+ 1 (* 2 3))", "7");
|
||||
validate("1+(2*3)", "(+ 1 (* 2 3))", "7");
|
||||
validate("(1+2)*3", "(* (+ 1 2) 3)", "9");
|
||||
|
||||
actual = Parser.parse("1+(2*3)").toString();
|
||||
Assert.assertEquals(expected, actual);
|
||||
validate("1*2+3", "(+ (* 1 2) 3)", "5");
|
||||
validate("(1*2)+3", "(+ (* 1 2) 3)", "5");
|
||||
validate("1*(2+3)", "(* 1 (+ 2 3))", "5");
|
||||
|
||||
actual = Parser.parse("(1+2)*3").toString();
|
||||
expected = "(* (+ 1 2) 3)";
|
||||
Assert.assertEquals(expected, actual);
|
||||
validate("1+2^3", "(+ 1 (^ 2 3))", "9");
|
||||
validate("1+(2^3)", "(+ 1 (^ 2 3))", "9");
|
||||
validate("(1+2)^3", "(^ (+ 1 2) 3)", "27");
|
||||
|
||||
validate("1^2+3", "(+ (^ 1 2) 3)", "4");
|
||||
validate("(1^2)+3", "(+ (^ 1 2) 3)", "4");
|
||||
validate("1^(2+3)", "(^ 1 (+ 2 3))", "1");
|
||||
|
||||
actual = Parser.parse("1*2+3").toString();
|
||||
expected = "(+ (* 1 2) 3)";
|
||||
Assert.assertEquals(expected, actual);
|
||||
validate("1^2*3+4", "(+ (* (^ 1 2) 3) 4)", "7");
|
||||
validate("-1^2*-3+-4", "(+ (* (^ -1 2) -3) -4)", "-7");
|
||||
|
||||
actual = Parser.parse("(1*2)+3").toString();
|
||||
Assert.assertEquals(expected, actual);
|
||||
validate("max(3, 4)", "(max [3, 4])", "4");
|
||||
validate("min(1, max(3, 4))", "(min [1, (max [3, 4])])", "1");
|
||||
}
|
||||
|
||||
actual = Parser.parse("1*(2+3)").toString();
|
||||
expected = "(* 1 (+ 2 3))";
|
||||
Assert.assertEquals(expected, actual);
|
||||
|
||||
|
||||
actual = Parser.parse("1+2^3").toString();
|
||||
expected = "(+ 1 (^ 2 3))";
|
||||
Assert.assertEquals(expected, actual);
|
||||
|
||||
actual = Parser.parse("1+(2^3)").toString();
|
||||
expected = "(+ 1 (^ 2 3))";
|
||||
Assert.assertEquals(expected, actual);
|
||||
|
||||
actual = Parser.parse("(1+2)^3").toString();
|
||||
expected = "(^ (+ 1 2) 3)";
|
||||
Assert.assertEquals(expected, actual);
|
||||
|
||||
|
||||
actual = Parser.parse("1^2+3").toString();
|
||||
expected = "(+ (^ 1 2) 3)";
|
||||
Assert.assertEquals(expected, actual);
|
||||
|
||||
actual = Parser.parse("(1^2)+3").toString();
|
||||
expected = "(+ (^ 1 2) 3)";
|
||||
Assert.assertEquals(expected, actual);
|
||||
|
||||
actual = Parser.parse("1^(2+3)").toString();
|
||||
expected = "(^ 1 (+ 2 3))";
|
||||
Assert.assertEquals(expected, actual);
|
||||
|
||||
|
||||
actual = Parser.parse("1^2*3+4").toString();
|
||||
expected = "(+ (* (^ 1 2) 3) 4)";
|
||||
Assert.assertEquals(expected, actual);
|
||||
|
||||
actual = Parser.parse("-1^-2*-3+-4").toString();
|
||||
expected = "(+ (* (^ -1 -2) -3) -4)";
|
||||
Assert.assertEquals(expected, actual);
|
||||
private void validate(String expression, String withoutFlatten, String withFlatten)
|
||||
{
|
||||
Assert.assertEquals(withoutFlatten, Parser.parse(expression, false).toString());
|
||||
Assert.assertEquals(withFlatten, Parser.parse(expression, true).toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
Loading…
Reference in New Issue