diff --git a/processing/src/main/java/io/druid/query/aggregation/post/ArithmeticPostAggregator.java b/processing/src/main/java/io/druid/query/aggregation/post/ArithmeticPostAggregator.java index 87138dadef8..6754d8c0079 100644 --- a/processing/src/main/java/io/druid/query/aggregation/post/ArithmeticPostAggregator.java +++ b/processing/src/main/java/io/druid/query/aggregation/post/ArithmeticPostAggregator.java @@ -36,7 +36,7 @@ import java.util.Set; */ public class ArithmeticPostAggregator implements PostAggregator { - private static final Comparator COMPARATOR = new Comparator() + private static final Comparator DEFAULT_COMPARATOR = new Comparator() { @Override public int compare(Object o, Object o1) @@ -48,13 +48,27 @@ public class ArithmeticPostAggregator implements PostAggregator private final String name; private final String fnName; private final List fields; - private final Ops op; + private final Op op; + private final Comparator comparator; + private final String ordering; + private final String opStrategy; + + public ArithmeticPostAggregator( + String name, + String fnName, + List fields + ) + { + this(name, fnName, fields, null, null); + } @JsonCreator public ArithmeticPostAggregator( @JsonProperty("name") String name, @JsonProperty("fn") String fnName, - @JsonProperty("fields") List fields + @JsonProperty("fields") List fields, + @JsonProperty("ordering") String ordering, + @JsonProperty("opStrategy") String opStrategy ) { this.name = name; @@ -64,10 +78,21 @@ public class ArithmeticPostAggregator implements PostAggregator throw new IAE("Illegal number of fields[%s], must be > 1", fields.size()); } - this.op = Ops.lookup(fnName); - if (op == null) { + Ops baseOp = Ops.lookup(fnName); + if (baseOp == null) { throw new IAE("Unknown operation[%s], known operations[%s]", fnName, Ops.getFns()); } + + this.ordering = ordering; + this.comparator = ordering == null ? DEFAULT_COMPARATOR : Ordering.valueOf(ordering); + + this.opStrategy = opStrategy == null && baseOp.equals(Ops.DIV) ? OpStrategy.zeroDivisionByZero.name() : opStrategy; + + if(this.opStrategy != null) { + this.op = Ops.withStrategy(baseOp, OpStrategy.valueOf(this.opStrategy)); + } else { + this.op = baseOp; + } } @Override @@ -83,7 +108,7 @@ public class ArithmeticPostAggregator implements PostAggregator @Override public Comparator getComparator() { - return COMPARATOR; + return comparator; } @Override @@ -113,6 +138,18 @@ public class ArithmeticPostAggregator implements PostAggregator return fnName; } + @JsonProperty("ordering") + public String getOrdering() + { + return ordering; + } + + @JsonProperty("opStrategy") + public String getOpStrategy() + { + return opStrategy; + } + @JsonProperty public List getFields() { @@ -130,34 +167,38 @@ public class ArithmeticPostAggregator implements PostAggregator '}'; } - private static enum Ops + static interface Op { + double compute(double lhs, double rhs); + } + + private static enum Ops implements Op { PLUS("+") { - double compute(double lhs, double rhs) + public double compute(double lhs, double rhs) { return lhs + rhs; } }, MINUS("-") { - double compute(double lhs, double rhs) + public double compute(double lhs, double rhs) { return lhs - rhs; } }, MULT("*") { - double compute(double lhs, double rhs) + public double compute(double lhs, double rhs) { return lhs * rhs; } }, DIV("/") { - double compute(double lhs, double rhs) + public double compute(double lhs, double rhs) { - return (rhs == 0.0) ? 0 : (lhs / rhs); + return lhs / rhs; } }; @@ -181,8 +222,6 @@ public class ArithmeticPostAggregator implements PostAggregator return fn; } - abstract double compute(double lhs, double rhs); - static Ops lookup(String fn) { return lookupMap.get(fn); @@ -192,6 +231,60 @@ public class ArithmeticPostAggregator implements PostAggregator { return lookupMap.keySet(); } + + public static Op withStrategy(final Op baseOp, final OpStrategy strategy) { + if(strategy.equals(OpStrategy.none)) { + return baseOp; + } + return new Op() + { + @Override + public double compute(double lhs, double rhs) + { + return strategy.compute(baseOp, lhs, rhs); + } + }; + } + } + + public static enum Ordering implements Comparator { + numericFirst { + public int compare(Double lhs, Double rhs) { + if(Double.isInfinite(lhs) || Double.isNaN(lhs)) { + return -1; + } + if(Double.isInfinite(rhs) || Double.isNaN(rhs)) { + return 1; + } + return Double.compare(lhs, rhs); + } + } + } + + public static enum OpStrategy + { + none { + public double compute(Op op, double lhs, double rhs) { + return op.compute(lhs, rhs); + } + }, + + zeroDivisionByZero { + public double compute(Op op, double lhs, double rhs) { + if(rhs == 0) { return 0; } + else return op.compute(lhs, rhs); + } + }, + + nanDivisionByZero { + public double compute(Op op, double lhs, double rhs) + { + if(rhs == 0) { return Double.NaN; } + else return op.compute(lhs, rhs); + } + }; + + public abstract double compute(Op op, double lhs, double rhs); } @Override diff --git a/processing/src/test/java/io/druid/query/aggregation/post/ArithmeticPostAggregatorTest.java b/processing/src/test/java/io/druid/query/aggregation/post/ArithmeticPostAggregatorTest.java index e56ccb94ad9..41c9a4fa453 100644 --- a/processing/src/test/java/io/druid/query/aggregation/post/ArithmeticPostAggregatorTest.java +++ b/processing/src/test/java/io/druid/query/aggregation/post/ArithmeticPostAggregatorTest.java @@ -19,12 +19,14 @@ package io.druid.query.aggregation.post; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import io.druid.query.aggregation.CountAggregator; import io.druid.query.aggregation.PostAggregator; import org.junit.Assert; import org.junit.Test; +import javax.annotation.concurrent.Immutable; import java.util.Comparator; import java.util.HashMap; import java.util.List; @@ -100,4 +102,92 @@ public class ArithmeticPostAggregatorTest Assert.assertEquals(0, comp.compare(after, after)); Assert.assertEquals(1, comp.compare(after, before)); } -} \ No newline at end of file + + @Test + public void testOrdering() throws Exception + { + List postAggregatorList = + Lists.newArrayList( + (PostAggregator) + new ConstantPostAggregator( + "size", 6, null + ), + new ConstantPostAggregator( + "zero", 0, null + ) + ); + + ArithmeticPostAggregator divideNumericFirst = new ArithmeticPostAggregator( + "divide", + "/", + postAggregatorList, + ArithmeticPostAggregator.Ordering.numericFirst.name(), + null + ); + + Assert.assertTrue( + divideNumericFirst.getComparator().compare(Double.POSITIVE_INFINITY, 0.0) < 0 + ); + Assert.assertTrue( + divideNumericFirst.getComparator().compare(Double.NEGATIVE_INFINITY, 0.0) < 0 + ); + Assert.assertTrue( + divideNumericFirst.getComparator().compare(Double.NaN, 0.0) < 0 + ); + } + + @Test + public void testOpStrategy() throws Exception + { + List postAggregatorList = + Lists.newArrayList( + (PostAggregator) + new ConstantPostAggregator( + "size", 6, null + ), + new ConstantPostAggregator( + "zero", 0, null + ) + ); + + ArithmeticPostAggregator divByZeroNaN = new ArithmeticPostAggregator( + "divide", + "/", + postAggregatorList, + null, + ArithmeticPostAggregator.OpStrategy.nanDivisionByZero.name() + ); + + Assert.assertEquals( + divByZeroNaN.compute(ImmutableMap.of("dummy", (Object) 0.0)), + Double.NaN + ); + + ArithmeticPostAggregator divByZeroZero = new ArithmeticPostAggregator( + "divide", + "/", + postAggregatorList, + null, + ArithmeticPostAggregator.OpStrategy.zeroDivisionByZero.name() + ); + + Assert.assertEquals( + divByZeroZero.compute(ImmutableMap.of("dummy", (Object) 0.0)), + 0.0 + ); + + // default behavior is zeroDivByZero + ArithmeticPostAggregator _default = new ArithmeticPostAggregator( + "divide", + "/", + postAggregatorList, + null, + null + ); + + Assert.assertEquals( + _default.compute(ImmutableMap.of("dummy", (Object) 0.0)), + 0.0 + ); + } +}