Aggregations: Add standard deviation bounds to extended_stats

Extended_stats now displays the upper and lower bounds on standard deviations (e.g. avg +/- std).
Default is to show 2 std above/below, but can be changed using the `sigma` parameter.
Accepts non-negative doubles

Closes #9356
This commit is contained in:
Zachary Tong 2015-01-22 14:01:15 -05:00
parent 3e4fc2659d
commit a4eb1d5505
7 changed files with 295 additions and 58 deletions

View File

@ -3,7 +3,7 @@
A `multi-value` metrics aggregation that computes stats over numeric values extracted from the aggregated documents. These values can be extracted either from specific numeric fields in the documents, or be generated by a provided script.
The `extended_stats` aggregations is an extended version of the <<search-aggregations-metrics-stats-aggregation,`stats`>> aggregation, where additional metrics are added such as `sum_of_squares`, `variance` and `std_deviation`.
The `extended_stats` aggregation is an extended version of the <<search-aggregations-metrics-stats-aggregation,`stats`>> aggregation, where additional metrics are added such as `sum_of_squares`, `variance`, `std_deviation` and `std_deviation_bounds`.
Assuming the data consists of documents representing exams grades (between 0 and 100) of students
@ -25,21 +25,58 @@ The above aggregation computes the grades statistics over all documents. The agg
...
"aggregations": {
"grades_stats": {
"count": 6,
"min": 72,
"max": 117.6,
"avg": 94.2,
"sum": 565.2,
"sum_of_squares": 54551.51999999999,
"variance": 218.2799999999976,
"std_deviation": 14.774302013969987
"grades_stats": {
"count": 9,
"min": 72,
"max": 99,
"avg": 86,
"sum": 774,
"sum_of_squares": 67028,
"variance": 51.55555555555556,
"std_deviation": 7.180219742846005,
"std_deviation_bounds": {
"upper": 100.36043948569201,
"lower": 71.63956051430799
}
}
}
}
--------------------------------------------------
The name of the aggregation (`grades_stats` above) also serves as the key by which the aggreagtion result can be retrieved from the returned response.
The name of the aggregation (`grades_stats` above) also serves as the key by which the aggregation result can be retrieved from the returned response.
==== Standard Deviation Bounds
coming[1.4.3]
By default, the `extended_stats` metric will return an object called `std_deviation_bounds`, which provides an interval of plus/minus two standard
deviations from the mean. This can be a useful way to visualize variance of your data. If you want a different boundary, for example
three standard deviations, you can set `sigma` in the request:
[source,js]
--------------------------------------------------
{
"aggs" : {
"grades_stats" : {
"extended_stats" : {
"field" : "grade",
"sigma" : 3 <1>
}
}
}
}
--------------------------------------------------
<1> `sigma` controls how many standard deviations +/- from the mean should be displayed coming[1.4.3]
`sigma` can be any non-negative double, meaning you can request non-integer values such as `1.5`. A value of `0` is valid, but will simply
return the average for both `upper` and `lower` bounds.
.Standard Deviation and Bounds require normality
[NOTE]
=====
The standard deviation and its bounds are displayed by default, but they are not always applicable to all data-sets. Your data must
be normally distributed for the metrics to make sense. The statistics behind standard deviations assume normally distributed data, so
if your data is skewed heavily left or right, the value returned will be misleading.
=====
==== Script

View File

@ -40,6 +40,22 @@ public interface ExtendedStats extends Stats {
*/
double getStdDeviation();
/**
* The upper or lower bound of the standard deviation interval around the average.
* The implementation in {@code InternalExtendedStats} computes {@code avg + (stdDeviation * sigma)}
* for {@link Bounds#UPPER} and {@code avg - (stdDeviation * sigma)} for {@link Bounds#LOWER}.
*
* @param bound which edge of the interval to return
* @return the requested bound
*/
double getStdDeviationBound(Bounds bound);
/**
* The standard deviation of the collected values as a String.
*/
String getStdDeviationAsString();
/**
* The upper or lower bound of the standard deviation interval of the collected values as a String.
*
* @param bound which edge of the interval to return
* @return the requested bound rendered as a String
*/
String getStdDeviationBoundAsString(Bounds bound);
/**
* The sum of the squares of the collected values as a String.
*/
@ -50,9 +66,9 @@ public interface ExtendedStats extends Stats {
*/
String getVarianceAsString();
/**
* The standard deviation of the collected values as a String.
*/
String getStdDeviationAsString();
/**
* Selects which edge of the standard deviation interval is requested from
* {@link #getStdDeviationBound(Bounds)} and {@link #getStdDeviationBoundAsString(Bounds)}.
*/
public enum Bounds {
UPPER, LOWER
}
}

View File

@ -52,9 +52,11 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
private DoubleArray maxes;
private DoubleArray sumOfSqrs;
private ValueFormatter formatter;
private double sigma;
public ExtendedStatsAggregator(String name, ValuesSource.Numeric valuesSource,
@Nullable ValueFormatter formatter, AggregationContext context, Aggregator parent, Map<String, Object> metaData) throws IOException {
@Nullable ValueFormatter formatter, AggregationContext context,
Aggregator parent, double sigma, Map<String, Object> metaData) throws IOException {
super(name, context, parent, metaData);
this.valuesSource = valuesSource;
this.formatter = formatter;
@ -66,6 +68,7 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
maxes = bigArrays.newDoubleArray(1, false);
maxes.fill(0, maxes.size(), Double.NEGATIVE_INFINITY);
sumOfSqrs = bigArrays.newDoubleArray(1, true);
this.sigma = sigma;
}
}
@ -134,6 +137,12 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
case sum_of_squares: return valuesSource == null ? 0 : sumOfSqrs.get(owningBucketOrd);
case variance: return valuesSource == null ? Double.NaN : variance(owningBucketOrd);
case std_deviation: return valuesSource == null ? Double.NaN : Math.sqrt(variance(owningBucketOrd));
case std_upper:
if (valuesSource == null) { return Double.NaN; }
return (sums.get(owningBucketOrd) / counts.get(owningBucketOrd)) + (Math.sqrt(variance(owningBucketOrd)) * this.sigma);
case std_lower:
if (valuesSource == null) { return Double.NaN; }
return (sums.get(owningBucketOrd) / counts.get(owningBucketOrd)) - (Math.sqrt(variance(owningBucketOrd)) * this.sigma);
default:
throw new ElasticsearchIllegalArgumentException("Unknown value [" + name + "] in common stats aggregation");
}
@ -148,16 +157,16 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
@Override
public InternalAggregation buildAggregation(long owningBucketOrdinal) {
if (valuesSource == null) {
return new InternalExtendedStats(name, 0, 0d, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0d, formatter, metaData());
return new InternalExtendedStats(name, 0, 0d, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0d, 0d, formatter, metaData());
}
assert owningBucketOrdinal < counts.size();
return new InternalExtendedStats(name, counts.get(owningBucketOrdinal), sums.get(owningBucketOrdinal),
mins.get(owningBucketOrdinal), maxes.get(owningBucketOrdinal), sumOfSqrs.get(owningBucketOrdinal), formatter, metaData());
mins.get(owningBucketOrdinal), maxes.get(owningBucketOrdinal), sumOfSqrs.get(owningBucketOrdinal), sigma, formatter, metaData());
}
/**
 * Builds the result used when no values were collected: zero count and sums, min at positive
 * infinity and max at negative infinity.
 *
 * @return an empty {@link InternalExtendedStats} that still carries the configured sigma
 */
@Override
public InternalAggregation buildEmptyAggregation() {
    // Pass the configured sigma rather than 0 so an empty result never overrides the requested
    // multiplier if it ends up being the instance that drives the reduce.
    return new InternalExtendedStats(name, 0, 0d, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0d, sigma, formatter, metaData());
}
@Override
@ -167,19 +176,22 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
public static class Factory extends ValuesSourceAggregatorFactory.LeafOnly<ValuesSource.Numeric> {
public Factory(String name, ValuesSourceConfig<ValuesSource.Numeric> valuesSourceConfig) {
private final double sigma;
public Factory(String name, ValuesSourceConfig<ValuesSource.Numeric> valuesSourceConfig, double sigma) {
super(name, InternalExtendedStats.TYPE.name(), valuesSourceConfig);
this.sigma = sigma;
}
@Override
protected Aggregator createUnmapped(AggregationContext aggregationContext, Aggregator parent, Map<String, Object> metaData) throws IOException {
return new ExtendedStatsAggregator(name, null, config.formatter(), aggregationContext, parent, metaData);
return new ExtendedStatsAggregator(name, null, config.formatter(), aggregationContext, parent, sigma, metaData);
}
@Override
protected Aggregator doCreateInternal(ValuesSource.Numeric valuesSource, AggregationContext aggregationContext, Aggregator parent, boolean collectsFromSingleBucket, Map<String, Object> metaData) throws IOException {
return new ExtendedStatsAggregator(name, valuesSource, config.formatter(), aggregationContext, parent,
metaData);
return new ExtendedStatsAggregator(name, valuesSource, config.formatter(), aggregationContext, parent, sigma, metaData);
}
}
}

View File

@ -19,17 +19,37 @@
package org.elasticsearch.search.aggregations.metrics.stats.extended;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.search.aggregations.metrics.ValuesSourceMetricsAggregationBuilder;
import java.io.IOException;
/**
 * Builder for the {@link ExtendedStats} aggregation.
 * <p>
 * In addition to what the parent builder writes, it can carry an optional {@code sigma} value,
 * which is emitted as the {@code "sigma"} field of the request only when it has been set.
 */
public class ExtendedStatsBuilder extends ValuesSourceMetricsAggregationBuilder<ExtendedStatsBuilder> {

    /** Sigma requested by the caller; {@code null} until {@link #sigma(double)} is invoked. */
    private Double requestedSigma;

    /**
     * Sole constructor.
     *
     * @param name name of the aggregation
     */
    public ExtendedStatsBuilder(String name) {
        super(name, InternalExtendedStats.TYPE.name());
    }

    /**
     * Sets the value written as {@code "sigma"} in the request.
     *
     * @param sigma the sigma value to send
     * @return this builder, for chaining
     */
    public ExtendedStatsBuilder sigma(double sigma) {
        requestedSigma = sigma;
        return this;
    }

    @Override
    protected void internalXContent(XContentBuilder builder, Params params) throws IOException {
        super.internalXContent(builder, params);
        if (requestedSigma == null) {
            // Nothing was requested explicitly: leave the field out of the request.
            return;
        }
        builder.field("sigma", requestedSigma);
    }
}

View File

@ -18,22 +18,64 @@
*/
package org.elasticsearch.search.aggregations.metrics.stats.extended;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchParseException;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.metrics.NumericValuesSourceMetricsAggregatorParser;
import org.elasticsearch.search.aggregations.support.ValuesSource;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
import org.elasticsearch.search.aggregations.support.ValuesSourceParser;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
/**
*
*/
public class ExtendedStatsParser extends NumericValuesSourceMetricsAggregatorParser<InternalExtendedStats> {
public class ExtendedStatsParser implements Aggregator.Parser {
public ExtendedStatsParser() {
super(InternalExtendedStats.TYPE);
static final ParseField SIGMA = new ParseField("sigma");
@Override
public String type() {
return InternalExtendedStats.TYPE.name();
}
/**
* Creates the aggregator factory for a parsed {@code extended_stats} request.
*
* @param aggregationName name of the aggregation
* @param config values source configuration produced by the values source parser
* @param sigma sigma value handed to the {@link ExtendedStatsAggregator.Factory}
* @return a new {@link ExtendedStatsAggregator.Factory}
*/
protected AggregatorFactory createFactory(String aggregationName, ValuesSourceConfig<ValuesSource.Numeric> config, double sigma) {
return new ExtendedStatsAggregator.Factory(aggregationName, config, sigma);
}
@Override
protected AggregatorFactory createFactory(String aggregationName, ValuesSourceConfig<ValuesSource.Numeric> config) {
return new ExtendedStatsAggregator.Factory(aggregationName, config);
public AggregatorFactory parse(String aggregationName, XContentParser parser, SearchContext context) throws IOException {
    // Parses the body of an extended_stats aggregation. Tokens understood by the numeric values
    // source parser are consumed by it; the only extra key handled here is the numeric "sigma",
    // which defaults to 2.0 and must not be negative. Anything else is rejected with a
    // SearchParseException.
    ValuesSourceParser<ValuesSource.Numeric> vsParser = ValuesSourceParser.numeric(aggregationName, InternalExtendedStats.TYPE, context).formattable(true)
            .build();

    XContentParser.Token token;
    String currentFieldName = null;
    double sigma = 2.0;

    while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
        if (token == XContentParser.Token.FIELD_NAME) {
            currentFieldName = parser.currentName();
        } else if (vsParser.token(currentFieldName, token, parser)) {
            continue;
        } else if (token == XContentParser.Token.VALUE_NUMBER) {
            if (SIGMA.match(currentFieldName)) {
                sigma = parser.doubleValue();
            } else {
                throw new SearchParseException(context, "Unknown key for a " + token + " in [" + aggregationName + "]: [" + currentFieldName + "].");
            }
        } else {
            throw new SearchParseException(context, "Unexpected token " + token + " in [" + aggregationName + "].");
        }
    }

    if (sigma < 0) {
        // The separating space before the value was missing, which produced "...provided was-1.0".
        throw new SearchParseException(context, "[sigma] must not be negative. Value provided was " + sigma);
    }

    return createFactory(aggregationName, vsParser.config(), sigma);
}
}

View File

@ -18,6 +18,7 @@
*/
package org.elasticsearch.search.aggregations.metrics.stats.extended;
import org.elasticsearch.Version;
import org.elasticsearch.common.inject.internal.Nullable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@ -53,7 +54,7 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
enum Metrics {
count, sum, min, max, avg, sum_of_squares, variance, std_deviation;
count, sum, min, max, avg, sum_of_squares, variance, std_deviation, std_upper, std_lower;
public static Metrics resolve(String name) {
return Metrics.valueOf(name);
@ -61,13 +62,15 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
}
private double sumOfSqrs;
private double sigma;
InternalExtendedStats() {} // for serialization
public InternalExtendedStats(String name, long count, double sum, double min, double max, double sumOfSqrs,
@Nullable ValueFormatter formatter, Map<String, Object> metaData) {
double sigma, @Nullable ValueFormatter formatter, Map<String, Object> metaData) {
super(name, count, sum, min, max, formatter, metaData);
this.sumOfSqrs = sumOfSqrs;
this.sigma = sigma;
}
@Override
@ -86,6 +89,12 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
if ("std_deviation".equals(name)) {
return getStdDeviation();
}
if ("std_upper".equals(name)) {
return getStdDeviationBound(Bounds.UPPER);
}
if ("std_lower".equals(name)) {
return getStdDeviationBound(Bounds.LOWER);
}
return super.value(name);
}
@ -104,6 +113,15 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
return Math.sqrt(getVariance());
}
/**
 * Returns one edge of the interval {@code avg +/- (stdDeviation * sigma)}.
 *
 * @param bound {@link Bounds#UPPER} for the upper edge; any other value yields the lower edge
 * @return the requested bound
 */
@Override
public double getStdDeviationBound(Bounds bound) {
    final boolean upperRequested = bound.equals(Bounds.UPPER);
    final double mean = getAvg();
    final double spread = getStdDeviation() * sigma;
    return upperRequested ? mean + spread : mean - spread;
}
@Override
public String getSumOfSquaresAsString() {
return valueAsString(Metrics.sum_of_squares.name());
@ -119,6 +137,11 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
return valueAsString(Metrics.std_deviation.name());
}
/**
 * Returns the requested standard deviation bound as a String, obtained through
 * {@code valueAsString} for the {@code std_upper} or {@code std_lower} metric.
 *
 * @param bound {@link Bounds#UPPER} for the upper edge; any other value yields the lower edge
 * @return the String form of the requested bound
 */
@Override
public String getStdDeviationBoundAsString(Bounds bound) {
    final Metrics metric;
    if (bound == Bounds.UPPER) {
        metric = Metrics.std_upper;
    } else {
        metric = Metrics.std_lower;
    }
    return valueAsString(metric.name());
}
@Override
public InternalExtendedStats reduce(ReduceContext reduceContext) {
double sumOfSqrs = 0;
@ -127,20 +150,28 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
sumOfSqrs += stats.getSumOfSquares();
}
final InternalStats stats = super.reduce(reduceContext);
return new InternalExtendedStats(name, stats.getCount(), stats.getSum(), stats.getMin(), stats.getMax(), sumOfSqrs, valueFormatter,
getMetaData());
return new InternalExtendedStats(name, stats.getCount(), stats.getSum(), stats.getMin(), stats.getMax(), sumOfSqrs, sigma, valueFormatter, getMetaData());
}
/**
 * Reads the fields specific to extended stats from the stream: the sum of squares, followed by
 * sigma when the stream version is on or after 1.4.3. For older stream versions sigma is not
 * read and is set to 2.0 instead.
 *
 * @param in stream to read from
 * @throws IOException if reading from the stream fails
 */
@Override
public void readOtherStatsFrom(StreamInput in) throws IOException {
    sumOfSqrs = in.readDouble();
    final boolean streamCarriesSigma = in.getVersion().onOrAfter(Version.V_1_4_3);
    sigma = streamCarriesSigma ? in.readDouble() : 2.0;
}
/**
 * Writes the fields specific to extended stats to the stream: the sum of squares, followed by
 * sigma only when the stream version is on or after 1.4.3.
 *
 * @param out stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
protected void writeOtherStatsTo(StreamOutput out) throws IOException {
    out.writeDouble(sumOfSqrs);
    final boolean streamAcceptsSigma = out.getVersion().onOrAfter(Version.V_1_4_3);
    if (streamAcceptsSigma) {
        out.writeDouble(sigma);
    }
}
static class Fields {
public static final XContentBuilderString SUM_OF_SQRS = new XContentBuilderString("sum_of_squares");
public static final XContentBuilderString SUM_OF_SQRS_AS_STRING = new XContentBuilderString("sum_of_squares_as_string");
@ -148,6 +179,11 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
public static final XContentBuilderString VARIANCE_AS_STRING = new XContentBuilderString("variance_as_string");
public static final XContentBuilderString STD_DEVIATION = new XContentBuilderString("std_deviation");
public static final XContentBuilderString STD_DEVIATION_AS_STRING = new XContentBuilderString("std_deviation_as_string");
public static final XContentBuilderString STD_DEVIATION_BOUNDS = new XContentBuilderString("std_deviation_bounds");
public static final XContentBuilderString STD_DEVIATION_BOUNDS_AS_STRING = new XContentBuilderString("std_deviation_bounds_as_string");
public static final XContentBuilderString UPPER = new XContentBuilderString("upper");
public static final XContentBuilderString LOWER = new XContentBuilderString("lower");
}
@Override
@ -155,12 +191,22 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
builder.field(Fields.SUM_OF_SQRS, count != 0 ? sumOfSqrs : null);
builder.field(Fields.VARIANCE, count != 0 ? getVariance() : null);
builder.field(Fields.STD_DEVIATION, count != 0 ? getStdDeviation() : null);
builder.startObject(Fields.STD_DEVIATION_BOUNDS)
.field(Fields.UPPER, count != 0 ? getStdDeviationBound(Bounds.UPPER) : null)
.field(Fields.LOWER, count != 0 ? getStdDeviationBound(Bounds.LOWER) : null)
.endObject();
if (count != 0 && valueFormatter != null) {
builder.field(Fields.SUM_OF_SQRS_AS_STRING, valueFormatter.format(sumOfSqrs));
builder.field(Fields.VARIANCE_AS_STRING, valueFormatter.format(getVariance()));
builder.field(Fields.STD_DEVIATION_AS_STRING, valueFormatter.format(getStdDeviation()));
builder.field(Fields.STD_DEVIATION_AS_STRING, getStdDeviationAsString());
builder.startObject(Fields.STD_DEVIATION_BOUNDS_AS_STRING)
.field(Fields.UPPER, getStdDeviationBoundAsString(Bounds.UPPER))
.field(Fields.LOWER, getStdDeviationBoundAsString(Bounds.LOWER))
.endObject();
}
return builder;
}
}

View File

@ -77,6 +77,8 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getMax(), equalTo(Double.NEGATIVE_INFINITY));
assertThat(Double.isNaN(stats.getStdDeviation()), is(true));
assertThat(Double.isNaN(stats.getAvg()), is(true));
assertThat(Double.isNaN(stats.getStdDeviationBound(ExtendedStats.Bounds.UPPER)), is(true));
assertThat(Double.isNaN(stats.getStdDeviationBound(ExtendedStats.Bounds.LOWER)), is(true));
}
@Test
@ -99,10 +101,39 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getSumOfSquares(), equalTo(0.0));
assertThat(stats.getVariance(), equalTo(Double.NaN));
assertThat(stats.getStdDeviation(), equalTo(Double.NaN));
assertThat(Double.isNaN(stats.getStdDeviationBound(ExtendedStats.Bounds.UPPER)), is(true));
assertThat(Double.isNaN(stats.getStdDeviationBound(ExtendedStats.Bounds.LOWER)), is(true));
}
// Runs extended_stats with an explicit, randomly chosen sigma over the "value" field of index "idx"
// and checks every reported metric, including the sigma-scaled standard deviation bounds.
// The assertions expect ten hits whose values are 1 through 10.
@Test
public void testSingleValuedField() throws Exception {
// presumably non-negative (random double scaled by an int in [1, 10]) -- confirm against the test base class
double sigma = randomDouble() * randomIntBetween(1, 10);
SearchResponse searchResponse = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(extendedStats("stats").field("value").sigma(sigma))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
ExtendedStats stats = searchResponse.getAggregations().get("stats");
assertThat(stats, notNullValue());
assertThat(stats.getName(), equalTo("stats"));
assertThat(stats.getAvg(), equalTo((double) (1+2+3+4+5+6+7+8+9+10) / 10));
assertThat(stats.getMin(), equalTo(1.0));
assertThat(stats.getMax(), equalTo(10.0));
assertThat(stats.getSum(), equalTo((double) 1+2+3+4+5+6+7+8+9+10));
assertThat(stats.getCount(), equalTo(10l));
assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100));
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)));
// bounds must equal avg +/- (std_deviation * sigma) for the sigma sent in the request
checkUpperLowerBounds(stats, sigma);
}
@Test
public void testSingleValuedFieldDefaultSigma() throws Exception {
// Same as previous test, but uses a default value for sigma
SearchResponse searchResponse = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(extendedStats("stats").field("value"))
@ -119,13 +150,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getSum(), equalTo((double) 1+2+3+4+5+6+7+8+9+10));
assertThat(stats.getCount(), equalTo(10l));
assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100));
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10)));
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)));
checkUpperLowerBounds(stats, 2);
}
public void testSingleValuedField_WithFormatter() throws Exception {
double sigma = randomDouble() * randomIntBetween(1, 10);
SearchResponse searchResponse = client().prepareSearch("idx").setQuery(matchAllQuery())
.addAggregation(extendedStats("stats").format("0000.0").field("value")).execute().actionGet();
.addAggregation(extendedStats("stats").format("0000.0").field("value").sigma(sigma)).execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
@ -148,6 +181,7 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getVarianceAsString(), equalTo("0008.2"));
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)));
assertThat(stats.getStdDeviationAsString(), equalTo("0002.9"));
checkUpperLowerBounds(stats, sigma);
}
@Test
@ -199,9 +233,10 @@ public class ExtendedStatsTests extends AbstractNumericTests {
@Test
public void testSingleValuedField_PartiallyUnmapped() throws Exception {
double sigma = randomDouble() * randomIntBetween(1, 10);
SearchResponse searchResponse = client().prepareSearch("idx", "idx_unmapped")
.setQuery(matchAllQuery())
.addAggregation(extendedStats("stats").field("value"))
.addAggregation(extendedStats("stats").field("value").sigma(sigma))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
@ -215,15 +250,17 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getSum(), equalTo((double) 1+2+3+4+5+6+7+8+9+10));
assertThat(stats.getCount(), equalTo(10l));
assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100));
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10)));
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)));
checkUpperLowerBounds(stats, sigma);
}
@Test
public void testSingleValuedField_WithValueScript() throws Exception {
double sigma = randomDouble() * randomIntBetween(1, 10);
SearchResponse searchResponse = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(extendedStats("stats").field("value").script("_value + 1"))
.addAggregation(extendedStats("stats").field("value").script("_value + 1").sigma(sigma))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
@ -237,15 +274,17 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getSum(), equalTo((double) 2+3+4+5+6+7+8+9+10+11));
assertThat(stats.getCount(), equalTo(10l));
assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121));
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8, 9, 10, 11)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8, 9, 10, 11)));
checkUpperLowerBounds(stats, sigma);
}
@Test
public void testSingleValuedField_WithValueScript_WithParams() throws Exception {
double sigma = randomDouble() * randomIntBetween(1, 10);
SearchResponse searchResponse = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(extendedStats("stats").field("value").script("_value + inc").param("inc", 1))
.addAggregation(extendedStats("stats").field("value").script("_value + inc").param("inc", 1).sigma(sigma))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
@ -259,15 +298,17 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getSum(), equalTo((double) 2+3+4+5+6+7+8+9+10+11));
assertThat(stats.getCount(), equalTo(10l));
assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121));
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8, 9, 10, 11)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8, 9, 10, 11)));
checkUpperLowerBounds(stats, sigma);
}
@Test
public void testMultiValuedField() throws Exception {
double sigma = randomDouble() * randomIntBetween(1, 10);
SearchResponse searchResponse = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(extendedStats("stats").field("values"))
.addAggregation(extendedStats("stats").field("values").sigma(sigma))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
@ -281,15 +322,17 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getSum(), equalTo((double) 2+3+4+5+6+7+8+9+10+11+3+4+5+6+7+8+9+10+11+12));
assertThat(stats.getCount(), equalTo(20l));
assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121+9+16+25+36+49+64+81+100+121+144));
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12)));
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)));
checkUpperLowerBounds(stats, sigma);
}
@Test
public void testMultiValuedField_WithValueScript() throws Exception {
double sigma = randomDouble() * randomIntBetween(1, 10);
SearchResponse searchResponse = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(extendedStats("stats").field("values").script("_value - 1"))
.addAggregation(extendedStats("stats").field("values").script("_value - 1").sigma(sigma))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
@ -305,13 +348,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100+4+9+16+25+36+49+64+81+100+121));
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
checkUpperLowerBounds(stats, sigma);
}
@Test
public void testMultiValuedField_WithValueScript_WithParams() throws Exception {
double sigma = randomDouble() * randomIntBetween(1, 10);
SearchResponse searchResponse = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(extendedStats("stats").field("values").script("_value - dec").param("dec", 1))
.addAggregation(extendedStats("stats").field("values").script("_value - dec").param("dec", 1).sigma(sigma))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
@ -327,13 +372,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100+4+9+16+25+36+49+64+81+100+121));
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
checkUpperLowerBounds(stats, sigma);
}
@Test
public void testScript_SingleValued() throws Exception {
double sigma = randomDouble() * randomIntBetween(1, 10);
SearchResponse searchResponse = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(extendedStats("stats").script("doc['value'].value"))
.addAggregation(extendedStats("stats").script("doc['value'].value").sigma(sigma))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
@ -349,13 +396,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100));
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10)));
checkUpperLowerBounds(stats, sigma);
}
@Test
public void testScript_SingleValued_WithParams() throws Exception {
double sigma = randomDouble() * randomIntBetween(1, 10);
SearchResponse searchResponse = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(extendedStats("stats").script("doc['value'].value + inc").param("inc", 1))
.addAggregation(extendedStats("stats").script("doc['value'].value + inc").param("inc", 1).sigma(sigma))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
@ -371,13 +420,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121));
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
checkUpperLowerBounds(stats, sigma);
}
@Test
public void testScript_ExplicitSingleValued_WithParams() throws Exception {
double sigma = randomDouble() * randomIntBetween(1, 10);
SearchResponse searchResponse = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(extendedStats("stats").script("doc['value'].value + inc").param("inc", 1))
.addAggregation(extendedStats("stats").script("doc['value'].value + inc").param("inc", 1).sigma(sigma))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
@ -393,13 +444,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121));
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
checkUpperLowerBounds(stats, sigma);
}
@Test
public void testScript_MultiValued() throws Exception {
double sigma = randomDouble() * randomIntBetween(1, 10);
SearchResponse searchResponse = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(extendedStats("stats").script("doc['values'].values"))
.addAggregation(extendedStats("stats").script("doc['values'].values").sigma(sigma))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
@ -415,13 +468,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121+9+16+25+36+49+64+81+100+121+144));
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12)));
checkUpperLowerBounds(stats, sigma);
}
@Test
public void testScript_ExplicitMultiValued() throws Exception {
double sigma = randomDouble() * randomIntBetween(1, 10);
SearchResponse searchResponse = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(extendedStats("stats").script("doc['values'].values"))
.addAggregation(extendedStats("stats").script("doc['values'].values").sigma(sigma))
.execute().actionGet();
assertShardExecutionState(searchResponse, 0);
@ -438,14 +493,16 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121+9+16+25+36+49+64+81+100+121+144));
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12)));
checkUpperLowerBounds(stats, sigma);
}
@Test
public void testScript_MultiValued_WithParams() throws Exception {
double sigma = randomDouble() * randomIntBetween(1, 10);
SearchResponse searchResponse = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(extendedStats("stats").script("[ doc['value'].value, doc['value'].value - dec ]").param("dec", 1))
.addAggregation(extendedStats("stats").script("[ doc['value'].value, doc['value'].value - dec ]").param("dec", 1).sigma(sigma))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
@ -461,6 +518,7 @@ public class ExtendedStatsTests extends AbstractNumericTests {
assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100+0+1+4+9+16+25+36+49+64+81));
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 0, 1, 2, 3, 4, 5, 6, 7, 8 ,9)));
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 0, 1, 2, 3, 4, 5, 6, 7, 8 ,9)));
checkUpperLowerBounds(stats, sigma);
}
@ -474,4 +532,10 @@ public class ExtendedStatsTests extends AbstractNumericTests {
}
assertThat("Not all shards are initialized", response.getSuccessfulShards(), equalTo(response.getTotalShards()));
}
/**
 * Asserts that both standard deviation bounds reported by {@code stats} equal
 * {@code avg +/- (stdDeviation * sigma)}, recomputed from the same stats object.
 *
 * @param stats aggregation result under test
 * @param sigma multiplier that was sent with the request
 */
private void checkUpperLowerBounds(ExtendedStats stats, double sigma) {
    final double expectedUpper = stats.getAvg() + (stats.getStdDeviation() * sigma);
    final double expectedLower = stats.getAvg() - (stats.getStdDeviation() * sigma);
    assertThat(stats.getStdDeviationBound(ExtendedStats.Bounds.UPPER), equalTo(expectedUpper));
    assertThat(stats.getStdDeviationBound(ExtendedStats.Bounds.LOWER), equalTo(expectedLower));
}
}