diff --git a/docs/content/Post-aggregations.md b/docs/content/Post-aggregations.md index f2eeea76429..a89103c799e 100644 --- a/docs/content/Post-aggregations.md +++ b/docs/content/Post-aggregations.md @@ -8,9 +8,21 @@ There are several post-aggregators available. ### Arithmetic post-aggregator -The arithmetic post-aggregator applies the provided function to the given fields from left to right. The fields can be aggregators or other post aggregators. +The arithmetic post-aggregator applies the provided function to the given +fields from left to right. The fields can be aggregators or other post aggregators. -Supported functions are `+`, `-`, `*`, and `/` +Supported functions are `+`, `-`, `*`, `/`, and `quotient`. + +**Note**: + +* `/` division always returns `0` if dividing by`0`, regardless of the numerator. +* `quotient` division behaves like regular floating point division + +Arithmetic post-aggregators may also specify an `ordering`, which defines the order +of resulting values when sorting results (this can be useful for topN queries for instance): + +- If no ordering (or `null`) is specified, the default floating point ordering is used. +- `numericFirst` ordering always returns finite values first, followed by `NaN`, and infinite values last. The grammar for an arithmetic post aggregation is: @@ -19,13 +31,11 @@ postAggregation : { "type" : "arithmetic", "name" : , "fn" : , - "fields": [, , ...] + "fields": [, , ...], + "ordering" : } ``` -In the case of a division (`/`), if the denominator is `0` then `0` is returned regardless of the numerator. - - ### Field accessor post-aggregator This returns the value produced by the specified [aggregator](Aggregations.html). diff --git a/processing/src/main/java/io/druid/query/aggregation/post/ArithmeticPostAggregator.java b/processing/src/main/java/io/druid/query/aggregation/post/ArithmeticPostAggregator.java index 9d06241ddf8..cd89f1fdbd6 100644 --- a/processing/src/main/java/io/druid/query/aggregation/post/ArithmeticPostAggregator.java +++ b/processing/src/main/java/io/druid/query/aggregation/post/ArithmeticPostAggregator.java @@ -19,6 +19,7 @@ package io.druid.query.aggregation.post; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Preconditions; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.metamx.common.IAE; @@ -34,7 +35,7 @@ import java.util.Set; */ public class ArithmeticPostAggregator implements PostAggregator { - private static final Comparator COMPARATOR = new Comparator() + private static final Comparator DEFAULT_COMPARATOR = new Comparator() { @Override public int compare(Object o, Object o1) @@ -47,25 +48,40 @@ public class ArithmeticPostAggregator implements PostAggregator private final String fnName; private final List fields; private final Ops op; + private final Comparator comparator; + private final String ordering; + + public ArithmeticPostAggregator( + String name, + String fnName, + List fields + ) + { + this(name, fnName, fields, null); + } @JsonCreator public ArithmeticPostAggregator( @JsonProperty("name") String name, @JsonProperty("fn") String fnName, - @JsonProperty("fields") List fields + @JsonProperty("fields") List fields, + @JsonProperty("ordering") String ordering ) { + Preconditions.checkArgument(fnName != null, "fn cannot not be null"); + Preconditions.checkArgument(fields != null && fields.size() > 1, "Illegal number of fields[%s], must be > 1"); + this.name = name; this.fnName = fnName; this.fields = fields; - if (fields.size() <= 1) { - throw new IAE("Illegal number of fields[%s], must be > 1", fields.size()); - } this.op = Ops.lookup(fnName); if (op == null) { throw new IAE("Unknown operation[%s], known operations[%s]", fnName, Ops.getFns()); } + + this.ordering = ordering; + this.comparator = ordering == null ? DEFAULT_COMPARATOR : Ordering.valueOf(ordering); } @Override @@ -81,7 +97,7 @@ public class ArithmeticPostAggregator implements PostAggregator @Override public Comparator getComparator() { - return COMPARATOR; + return comparator; } @Override @@ -111,6 +127,12 @@ public class ArithmeticPostAggregator implements PostAggregator return fnName; } + @JsonProperty("ordering") + public String getOrdering() + { + return ordering; + } + @JsonProperty public List getFields() { @@ -132,31 +154,38 @@ public class ArithmeticPostAggregator implements PostAggregator { PLUS("+") { - double compute(double lhs, double rhs) + public double compute(double lhs, double rhs) { return lhs + rhs; } }, MINUS("-") { - double compute(double lhs, double rhs) + public double compute(double lhs, double rhs) { return lhs - rhs; } }, MULT("*") { - double compute(double lhs, double rhs) + public double compute(double lhs, double rhs) { return lhs * rhs; } }, DIV("/") { - double compute(double lhs, double rhs) + public double compute(double lhs, double rhs) { return (rhs == 0.0) ? 0 : (lhs / rhs); } + }, + QUOTIENT("quotient") + { + public double compute(double lhs, double rhs) + { + return lhs / rhs; + } }; private static final Map lookupMap = Maps.newHashMap(); @@ -179,7 +208,7 @@ public class ArithmeticPostAggregator implements PostAggregator return fn; } - abstract double compute(double lhs, double rhs); + public abstract double compute(double lhs, double rhs); static Ops lookup(String fn) { @@ -192,18 +221,56 @@ public class ArithmeticPostAggregator implements PostAggregator } } + public static enum Ordering implements Comparator { + // ensures the following order: numeric > NaN > Infinite + numericFirst { + public int compare(Double lhs, Double rhs) { + if(isFinite(lhs) && !isFinite(rhs)) { + return 1; + } + if(!isFinite(lhs) && isFinite(rhs)) { + return -1; + } + return Double.compare(lhs, rhs); + } + + // Double.isFinite only exist in JDK8 + private boolean isFinite(double value) { + return !Double.isInfinite(value) && !Double.isNaN(value); + } + } + } + @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } ArithmeticPostAggregator that = (ArithmeticPostAggregator) o; - if (fields != null ? !fields.equals(that.fields) : that.fields != null) return false; - if (fnName != null ? !fnName.equals(that.fnName) : that.fnName != null) return false; - if (name != null ? !name.equals(that.name) : that.name != null) return false; - if (op != that.op) return false; + if (!comparator.equals(that.comparator)) { + return false; + } + if (!fields.equals(that.fields)) { + return false; + } + if (!fnName.equals(that.fnName)) { + return false; + } + if (name != null ? !name.equals(that.name) : that.name != null) { + return false; + } + if (op != that.op) { + return false; + } + if (ordering != null ? !ordering.equals(that.ordering) : that.ordering != null) { + return false; + } return true; } @@ -212,9 +279,11 @@ public class ArithmeticPostAggregator implements PostAggregator public int hashCode() { int result = name != null ? name.hashCode() : 0; - result = 31 * result + (fnName != null ? fnName.hashCode() : 0); - result = 31 * result + (fields != null ? fields.hashCode() : 0); - result = 31 * result + (op != null ? op.hashCode() : 0); + result = 31 * result + fnName.hashCode(); + result = 31 * result + fields.hashCode(); + result = 31 * result + op.hashCode(); + result = 31 * result + comparator.hashCode(); + result = 31 * result + (ordering != null ? ordering.hashCode() : 0); return result; } } diff --git a/processing/src/test/java/io/druid/query/aggregation/post/ArithmeticPostAggregatorTest.java b/processing/src/test/java/io/druid/query/aggregation/post/ArithmeticPostAggregatorTest.java index 26b917af41f..c87aa944a96 100644 --- a/processing/src/test/java/io/druid/query/aggregation/post/ArithmeticPostAggregatorTest.java +++ b/processing/src/test/java/io/druid/query/aggregation/post/ArithmeticPostAggregatorTest.java @@ -17,12 +17,16 @@ package io.druid.query.aggregation.post; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import io.druid.query.aggregation.CountAggregator; +import io.druid.query.aggregation.DoubleSumAggregator; import io.druid.query.aggregation.PostAggregator; import org.junit.Assert; import org.junit.Test; +import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; import java.util.List; @@ -98,4 +102,70 @@ public class ArithmeticPostAggregatorTest Assert.assertEquals(0, comp.compare(after, after)); Assert.assertEquals(1, comp.compare(after, before)); } + + @Test + public void testQuotient() throws Exception + { + ArithmeticPostAggregator agg = new ArithmeticPostAggregator( + null, + "quotient", + ImmutableList.of( + new FieldAccessPostAggregator("numerator", "value"), + new ConstantPostAggregator("zero", 0) + ), + "numericFirst" + ); + + + Assert.assertEquals(Double.NaN, agg.compute(ImmutableMap.of("value", 0))); + Assert.assertEquals(Double.NaN, agg.compute(ImmutableMap.of("value", Double.NaN))); + Assert.assertEquals(Double.POSITIVE_INFINITY, agg.compute(ImmutableMap.of("value", 1))); + Assert.assertEquals(Double.NEGATIVE_INFINITY, agg.compute(ImmutableMap.of("value", -1))); + } + + @Test + public void testDiv() throws Exception + { + ArithmeticPostAggregator agg = new ArithmeticPostAggregator( + null, + "/", + ImmutableList.of( + new FieldAccessPostAggregator("numerator", "value"), + new ConstantPostAggregator("denomiator", 0) + ) + ); + + Assert.assertEquals(0.0, agg.compute(ImmutableMap.of("value", 0))); + Assert.assertEquals(0.0, agg.compute(ImmutableMap.of("value", Double.NaN))); + Assert.assertEquals(0.0, agg.compute(ImmutableMap.of("value", 1))); + Assert.assertEquals(0.0, agg.compute(ImmutableMap.of("value", -1))); + } + + @Test + public void testNumericFirstOrdering() throws Exception + { + ArithmeticPostAggregator agg = new ArithmeticPostAggregator( + null, + "quotient", + ImmutableList.of( + new ConstantPostAggregator("zero", 0), + new ConstantPostAggregator("zero", 0) + ), + "numericFirst" + ); + final Comparator numericFirst = agg.getComparator(); + Assert.assertTrue(numericFirst.compare(Double.NaN, 0.0) < 0); + Assert.assertTrue(numericFirst.compare(Double.POSITIVE_INFINITY, 0.0) < 0); + Assert.assertTrue(numericFirst.compare(Double.NEGATIVE_INFINITY, 0.0) < 0); + Assert.assertTrue(numericFirst.compare(0.0, Double.NaN) > 0); + Assert.assertTrue(numericFirst.compare(0.0, Double.POSITIVE_INFINITY) > 0); + Assert.assertTrue(numericFirst.compare(0.0, Double.NEGATIVE_INFINITY) > 0); + + Assert.assertTrue(numericFirst.compare(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY) < 0); + Assert.assertTrue(numericFirst.compare(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY) > 0); + Assert.assertTrue(numericFirst.compare(Double.NaN, Double.POSITIVE_INFINITY) > 0); + Assert.assertTrue(numericFirst.compare(Double.NaN, Double.NEGATIVE_INFINITY) > 0); + Assert.assertTrue(numericFirst.compare(Double.POSITIVE_INFINITY, Double.NaN) < 0); + Assert.assertTrue(numericFirst.compare(Double.NEGATIVE_INFINITY, Double.NaN) < 0); + } }