opStrategy and ordering for ArithmeticPostAgg

Fixes #510
This commit is contained in:
Xavier Léauté 2014-04-25 16:24:37 -07:00
parent f88cb13ccb
commit ded08b22ee
2 changed files with 198 additions and 15 deletions

View File

@ -36,7 +36,7 @@ import java.util.Set;
*/ */
public class ArithmeticPostAggregator implements PostAggregator public class ArithmeticPostAggregator implements PostAggregator
{ {
private static final Comparator COMPARATOR = new Comparator() private static final Comparator DEFAULT_COMPARATOR = new Comparator()
{ {
@Override @Override
public int compare(Object o, Object o1) public int compare(Object o, Object o1)
@ -48,13 +48,27 @@ public class ArithmeticPostAggregator implements PostAggregator
private final String name; private final String name;
private final String fnName; private final String fnName;
private final List<PostAggregator> fields; private final List<PostAggregator> fields;
private final Ops op; private final Op op;
private final Comparator comparator;
private final String ordering;
private final String opStrategy;
public ArithmeticPostAggregator(
String name,
String fnName,
List<PostAggregator> fields
)
{
this(name, fnName, fields, null, null);
}
@JsonCreator @JsonCreator
public ArithmeticPostAggregator( public ArithmeticPostAggregator(
@JsonProperty("name") String name, @JsonProperty("name") String name,
@JsonProperty("fn") String fnName, @JsonProperty("fn") String fnName,
@JsonProperty("fields") List<PostAggregator> fields @JsonProperty("fields") List<PostAggregator> fields,
@JsonProperty("ordering") String ordering,
@JsonProperty("opStrategy") String opStrategy
) )
{ {
this.name = name; this.name = name;
@ -64,10 +78,21 @@ public class ArithmeticPostAggregator implements PostAggregator
throw new IAE("Illegal number of fields[%s], must be > 1", fields.size()); throw new IAE("Illegal number of fields[%s], must be > 1", fields.size());
} }
this.op = Ops.lookup(fnName); Ops baseOp = Ops.lookup(fnName);
if (op == null) { if (baseOp == null) {
throw new IAE("Unknown operation[%s], known operations[%s]", fnName, Ops.getFns()); throw new IAE("Unknown operation[%s], known operations[%s]", fnName, Ops.getFns());
} }
this.ordering = ordering;
this.comparator = ordering == null ? DEFAULT_COMPARATOR : Ordering.valueOf(ordering);
this.opStrategy = opStrategy == null && baseOp.equals(Ops.DIV) ? OpStrategy.zeroDivisionByZero.name() : opStrategy;
if(this.opStrategy != null) {
this.op = Ops.withStrategy(baseOp, OpStrategy.valueOf(this.opStrategy));
} else {
this.op = baseOp;
}
} }
@Override @Override
@ -83,7 +108,7 @@ public class ArithmeticPostAggregator implements PostAggregator
@Override @Override
public Comparator getComparator() public Comparator getComparator()
{ {
return COMPARATOR; return comparator;
} }
@Override @Override
@ -113,6 +138,18 @@ public class ArithmeticPostAggregator implements PostAggregator
return fnName; return fnName;
} }
@JsonProperty("ordering")
public String getOrdering()
{
return ordering;
}
@JsonProperty("opStrategy")
public String getOpStrategy()
{
return opStrategy;
}
@JsonProperty @JsonProperty
public List<PostAggregator> getFields() public List<PostAggregator> getFields()
{ {
@ -130,34 +167,38 @@ public class ArithmeticPostAggregator implements PostAggregator
'}'; '}';
} }
private static enum Ops static interface Op {
double compute(double lhs, double rhs);
}
private static enum Ops implements Op
{ {
PLUS("+") PLUS("+")
{ {
double compute(double lhs, double rhs) public double compute(double lhs, double rhs)
{ {
return lhs + rhs; return lhs + rhs;
} }
}, },
MINUS("-") MINUS("-")
{ {
double compute(double lhs, double rhs) public double compute(double lhs, double rhs)
{ {
return lhs - rhs; return lhs - rhs;
} }
}, },
MULT("*") MULT("*")
{ {
double compute(double lhs, double rhs) public double compute(double lhs, double rhs)
{ {
return lhs * rhs; return lhs * rhs;
} }
}, },
DIV("/") DIV("/")
{ {
double compute(double lhs, double rhs) public double compute(double lhs, double rhs)
{ {
return (rhs == 0.0) ? 0 : (lhs / rhs); return lhs / rhs;
} }
}; };
@ -181,8 +222,6 @@ public class ArithmeticPostAggregator implements PostAggregator
return fn; return fn;
} }
abstract double compute(double lhs, double rhs);
static Ops lookup(String fn) static Ops lookup(String fn)
{ {
return lookupMap.get(fn); return lookupMap.get(fn);
@ -192,6 +231,60 @@ public class ArithmeticPostAggregator implements PostAggregator
{ {
return lookupMap.keySet(); return lookupMap.keySet();
} }
public static Op withStrategy(final Op baseOp, final OpStrategy strategy) {
if(strategy.equals(OpStrategy.none)) {
return baseOp;
}
return new Op()
{
@Override
public double compute(double lhs, double rhs)
{
return strategy.compute(baseOp, lhs, rhs);
}
};
}
}
public static enum Ordering implements Comparator<Double> {
numericFirst {
public int compare(Double lhs, Double rhs) {
if(Double.isInfinite(lhs) || Double.isNaN(lhs)) {
return -1;
}
if(Double.isInfinite(rhs) || Double.isNaN(rhs)) {
return 1;
}
return Double.compare(lhs, rhs);
}
}
}
public static enum OpStrategy
{
none {
public double compute(Op op, double lhs, double rhs) {
return op.compute(lhs, rhs);
}
},
zeroDivisionByZero {
public double compute(Op op, double lhs, double rhs) {
if(rhs == 0) { return 0; }
else return op.compute(lhs, rhs);
}
},
nanDivisionByZero {
public double compute(Op op, double lhs, double rhs)
{
if(rhs == 0) { return Double.NaN; }
else return op.compute(lhs, rhs);
}
};
public abstract double compute(Op op, double lhs, double rhs);
} }
@Override @Override

View File

@ -19,12 +19,14 @@
package io.druid.query.aggregation.post; package io.druid.query.aggregation.post;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import io.druid.query.aggregation.CountAggregator; import io.druid.query.aggregation.CountAggregator;
import io.druid.query.aggregation.PostAggregator; import io.druid.query.aggregation.PostAggregator;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import javax.annotation.concurrent.Immutable;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
@ -100,4 +102,92 @@ public class ArithmeticPostAggregatorTest
Assert.assertEquals(0, comp.compare(after, after)); Assert.assertEquals(0, comp.compare(after, after));
Assert.assertEquals(1, comp.compare(after, before)); Assert.assertEquals(1, comp.compare(after, before));
} }
@Test
public void testOrdering() throws Exception
{
List<PostAggregator> postAggregatorList =
Lists.newArrayList(
(PostAggregator)
new ConstantPostAggregator(
"size", 6, null
),
new ConstantPostAggregator(
"zero", 0, null
)
);
ArithmeticPostAggregator divideNumericFirst = new ArithmeticPostAggregator(
"divide",
"/",
postAggregatorList,
ArithmeticPostAggregator.Ordering.numericFirst.name(),
null
);
Assert.assertTrue(
divideNumericFirst.getComparator().compare(Double.POSITIVE_INFINITY, 0.0) < 0
);
Assert.assertTrue(
divideNumericFirst.getComparator().compare(Double.NEGATIVE_INFINITY, 0.0) < 0
);
Assert.assertTrue(
divideNumericFirst.getComparator().compare(Double.NaN, 0.0) < 0
);
}
@Test
public void testOpStrategy() throws Exception
{
List<PostAggregator> postAggregatorList =
Lists.newArrayList(
(PostAggregator)
new ConstantPostAggregator(
"size", 6, null
),
new ConstantPostAggregator(
"zero", 0, null
)
);
ArithmeticPostAggregator divByZeroNaN = new ArithmeticPostAggregator(
"divide",
"/",
postAggregatorList,
null,
ArithmeticPostAggregator.OpStrategy.nanDivisionByZero.name()
);
Assert.assertEquals(
divByZeroNaN.compute(ImmutableMap.of("dummy", (Object) 0.0)),
Double.NaN
);
ArithmeticPostAggregator divByZeroZero = new ArithmeticPostAggregator(
"divide",
"/",
postAggregatorList,
null,
ArithmeticPostAggregator.OpStrategy.zeroDivisionByZero.name()
);
Assert.assertEquals(
divByZeroZero.compute(ImmutableMap.of("dummy", (Object) 0.0)),
0.0
);
// default behavior is zeroDivByZero
ArithmeticPostAggregator _default = new ArithmeticPostAggregator(
"divide",
"/",
postAggregatorList,
null,
null
);
Assert.assertEquals(
_default.compute(ImmutableMap.of("dummy", (Object) 0.0)),
0.0
);
}
} }