From a4eb1d55051b4b6671938ae1213a205c106438d0 Mon Sep 17 00:00:00 2001 From: Zachary Tong Date: Thu, 22 Jan 2015 14:01:15 -0500 Subject: [PATCH] Aggregations: Add standard deviation bounds to extended_stats Extended_stats now displays the upper and lower bounds on standard deviations (e.g. avg +/- std). Default is to show 2 std above/below, but can be changed using the `sigma` parameter. Accepts non-negative doubles Closes #9356 --- .../extendedstats-aggregation.asciidoc | 59 ++++++++-- .../metrics/stats/extended/ExtendedStats.java | 24 +++- .../extended/ExtendedStatsAggregator.java | 28 +++-- .../stats/extended/ExtendedStatsBuilder.java | 20 ++++ .../stats/extended/ExtendedStatsParser.java | 54 ++++++++- .../stats/extended/InternalExtendedStats.java | 58 ++++++++- .../metrics/ExtendedStatsTests.java | 110 ++++++++++++++---- 7 files changed, 295 insertions(+), 58 deletions(-) diff --git a/docs/reference/search/aggregations/metrics/extendedstats-aggregation.asciidoc b/docs/reference/search/aggregations/metrics/extendedstats-aggregation.asciidoc index 5e001899e94..e7134eb52a6 100644 --- a/docs/reference/search/aggregations/metrics/extendedstats-aggregation.asciidoc +++ b/docs/reference/search/aggregations/metrics/extendedstats-aggregation.asciidoc @@ -3,7 +3,7 @@ A `multi-value` metrics aggregation that computes stats over numeric values extracted from the aggregated documents. These values can be extracted either from specific numeric fields in the documents, or be generated by a provided script. -The `extended_stats` aggregations is an extended version of the <> aggregation, where additional metrics are added such as `sum_of_squares`, `variance` and `std_deviation`. +The `extended_stats` aggregations is an extended version of the <> aggregation, where additional metrics are added such as `sum_of_squares`, `variance`, `std_deviation` and `std_deviation_bounds`. 
Assuming the data consists of documents representing exams grades (between 0 and 100) of students @@ -25,21 +25,58 @@ The above aggregation computes the grades statistics over all documents. The agg ... "aggregations": { - "grades_stats": { - "count": 6, - "min": 72, - "max": 117.6, - "avg": 94.2, - "sum": 565.2, - "sum_of_squares": 54551.51999999999, - "variance": 218.2799999999976, - "std_deviation": 14.774302013969987 + "grades_stats": { + "count": 9, + "min": 72, + "max": 99, + "avg": 86, + "sum": 774, + "sum_of_squares": 67028, + "variance": 51.55555555555556, + "std_deviation": 7.180219742846005, + "std_deviation_bounds": { + "upper": 100.36043948569201, + "lower": 71.63956051430799 + } } } } -------------------------------------------------- -The name of the aggregation (`grades_stats` above) also serves as the key by which the aggreagtion result can be retrieved from the returned response. +The name of the aggregation (`grades_stats` above) also serves as the key by which the aggregation result can be retrieved from the returned response. + +==== Standard Deviation Bounds +coming[1.4.3] + +By default, the `extended_stats` metric will return an object called `std_deviation_bounds`, which provides an interval of plus/minus two standard +deviations from the mean. This can be a useful way to visualize variance of your data. If you want a different boundary, for example +three standard deviations, you can set `sigma` in the request: + +[source,js] +-------------------------------------------------- +{ + "aggs" : { + "grades_stats" : { + "extended_stats" : { + "field" : "grade", + "sigma" : 3 <1> + } + } + } +} +-------------------------------------------------- +<1> `sigma` controls how many standard deviations +/- from the mean should be displayed coming[1.4.3] + +`sigma` can be any non-negative double, meaning you can request non-integer values such as `1.5`. A value of `0` is valid, but will simply +return the average for both `upper` and `lower` bounds. 
+ .Standard Deviation and Bounds require normality +[NOTE] +===== +The standard deviation and its bounds are displayed by default, but they are not always applicable to all data-sets. Your data must +be normally distributed for the metrics to make sense. The statistics behind standard deviations assume normally distributed data, so +if your data is skewed heavily left or right, the value returned will be misleading. +===== ==== Script diff --git a/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStats.java b/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStats.java index 9fe541a785d..1b235a6cfec 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStats.java +++ b/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStats.java @@ -40,6 +40,22 @@ public interface ExtendedStats extends Stats { */ double getStdDeviation(); + /** + * The upper or lower bounds of the stdDeviation + */ + double getStdDeviationBound(Bounds bound); + + /** + * The standard deviation of the collected values as a String. + */ + String getStdDeviationAsString(); + + /** + * The upper or lower bounds of stdDev of the collected values as a String. + */ + String getStdDeviationBoundAsString(Bounds bound); + + /** * The sum of the squares of the collected values as a String. */ @@ -50,9 +66,9 @@ public interface ExtendedStats extends Stats { */ String getVarianceAsString(); - /** - * The standard deviation of the collected values as a String. 
- */ - String getStdDeviationAsString(); + + public enum Bounds { + UPPER, LOWER + } } diff --git a/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStatsAggregator.java b/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStatsAggregator.java index e9bf1a8081b..d7dd227f378 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStatsAggregator.java +++ b/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStatsAggregator.java @@ -52,9 +52,11 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue private DoubleArray maxes; private DoubleArray sumOfSqrs; private ValueFormatter formatter; + private double sigma; public ExtendedStatsAggregator(String name, ValuesSource.Numeric valuesSource, - @Nullable ValueFormatter formatter, AggregationContext context, Aggregator parent, Map metaData) throws IOException { + @Nullable ValueFormatter formatter, AggregationContext context, + Aggregator parent, double sigma, Map metaData) throws IOException { super(name, context, parent, metaData); this.valuesSource = valuesSource; this.formatter = formatter; @@ -66,6 +68,7 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue maxes = bigArrays.newDoubleArray(1, false); maxes.fill(0, maxes.size(), Double.NEGATIVE_INFINITY); sumOfSqrs = bigArrays.newDoubleArray(1, true); + this.sigma = sigma; } } @@ -134,6 +137,12 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue case sum_of_squares: return valuesSource == null ? 0 : sumOfSqrs.get(owningBucketOrd); case variance: return valuesSource == null ? Double.NaN : variance(owningBucketOrd); case std_deviation: return valuesSource == null ? 
Double.NaN : Math.sqrt(variance(owningBucketOrd)); + case std_upper: + if (valuesSource == null) { return Double.NaN; } + return (sums.get(owningBucketOrd) / counts.get(owningBucketOrd)) + (Math.sqrt(variance(owningBucketOrd)) * this.sigma); + case std_lower: + if (valuesSource == null) { return Double.NaN; } + return (sums.get(owningBucketOrd) / counts.get(owningBucketOrd)) - (Math.sqrt(variance(owningBucketOrd)) * this.sigma); default: throw new ElasticsearchIllegalArgumentException("Unknown value [" + name + "] in common stats aggregation"); } @@ -148,16 +157,16 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue @Override public InternalAggregation buildAggregation(long owningBucketOrdinal) { if (valuesSource == null) { - return new InternalExtendedStats(name, 0, 0d, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0d, formatter, metaData()); + return new InternalExtendedStats(name, 0, 0d, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0d, 0d, formatter, metaData()); } assert owningBucketOrdinal < counts.size(); return new InternalExtendedStats(name, counts.get(owningBucketOrdinal), sums.get(owningBucketOrdinal), - mins.get(owningBucketOrdinal), maxes.get(owningBucketOrdinal), sumOfSqrs.get(owningBucketOrdinal), formatter, metaData()); + mins.get(owningBucketOrdinal), maxes.get(owningBucketOrdinal), sumOfSqrs.get(owningBucketOrdinal), sigma, formatter, metaData()); } @Override public InternalAggregation buildEmptyAggregation() { - return new InternalExtendedStats(name, 0, 0d, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0d, formatter, metaData()); + return new InternalExtendedStats(name, 0, 0d, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0d, 0d, formatter, metaData()); } @Override @@ -167,19 +176,22 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue public static class Factory extends ValuesSourceAggregatorFactory.LeafOnly { - public Factory(String name, ValuesSourceConfig 
valuesSourceConfig) { + private final double sigma; + + public Factory(String name, ValuesSourceConfig valuesSourceConfig, double sigma) { super(name, InternalExtendedStats.TYPE.name(), valuesSourceConfig); + + this.sigma = sigma; } @Override protected Aggregator createUnmapped(AggregationContext aggregationContext, Aggregator parent, Map metaData) throws IOException { - return new ExtendedStatsAggregator(name, null, config.formatter(), aggregationContext, parent, metaData); + return new ExtendedStatsAggregator(name, null, config.formatter(), aggregationContext, parent, sigma, metaData); } @Override protected Aggregator doCreateInternal(ValuesSource.Numeric valuesSource, AggregationContext aggregationContext, Aggregator parent, boolean collectsFromSingleBucket, Map metaData) throws IOException { - return new ExtendedStatsAggregator(name, valuesSource, config.formatter(), aggregationContext, parent, - metaData); + return new ExtendedStatsAggregator(name, valuesSource, config.formatter(), aggregationContext, parent, sigma, metaData); } } } diff --git a/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStatsBuilder.java b/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStatsBuilder.java index 6eb095f6f63..28f4d739ada 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStatsBuilder.java +++ b/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStatsBuilder.java @@ -19,17 +19,37 @@ package org.elasticsearch.search.aggregations.metrics.stats.extended; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.search.aggregations.metrics.ValuesSourceMetricsAggregationBuilder; +import java.io.IOException; + /** * Builder for the {@link ExtendedStats} aggregation. */ public class ExtendedStatsBuilder extends ValuesSourceMetricsAggregationBuilder { + private Double sigma; + /** * Sole constructor. 
*/ public ExtendedStatsBuilder(String name) { super(name, InternalExtendedStats.TYPE.name()); } + + public ExtendedStatsBuilder sigma(double sigma) { + this.sigma = sigma; + return this; + } + + @Override + protected void internalXContent(XContentBuilder builder, Params params) throws IOException { + super.internalXContent(builder, params); + + if (sigma != null) { + builder.field("sigma", sigma); + } + + } } diff --git a/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStatsParser.java b/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStatsParser.java index ee70c3d5b86..18ca93495c3 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStatsParser.java +++ b/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/ExtendedStatsParser.java @@ -18,22 +18,64 @@ */ package org.elasticsearch.search.aggregations.metrics.stats.extended; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchParseException; +import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.AggregatorFactory; -import org.elasticsearch.search.aggregations.metrics.NumericValuesSourceMetricsAggregatorParser; import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.search.aggregations.support.ValuesSourceConfig; +import org.elasticsearch.search.aggregations.support.ValuesSourceParser; +import org.elasticsearch.search.internal.SearchContext; + +import java.io.IOException; /** * */ -public class ExtendedStatsParser extends NumericValuesSourceMetricsAggregatorParser { +public class ExtendedStatsParser implements Aggregator.Parser { - public ExtendedStatsParser() { - super(InternalExtendedStats.TYPE); + static final ParseField SIGMA = new ParseField("sigma"); + + @Override + public String type() { + return 
InternalExtendedStats.TYPE.name(); + } + + protected AggregatorFactory createFactory(String aggregationName, ValuesSourceConfig config, double sigma) { + return new ExtendedStatsAggregator.Factory(aggregationName, config, sigma); } @Override - protected AggregatorFactory createFactory(String aggregationName, ValuesSourceConfig config) { - return new ExtendedStatsAggregator.Factory(aggregationName, config); + public AggregatorFactory parse(String aggregationName, XContentParser parser, SearchContext context) throws IOException { + + ValuesSourceParser vsParser = ValuesSourceParser.numeric(aggregationName, InternalExtendedStats.TYPE, context).formattable(true) + .build(); + + XContentParser.Token token; + String currentFieldName = null; + double sigma = 2.0; + + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (vsParser.token(currentFieldName, token, parser)) { + continue; + } else if (token == XContentParser.Token.VALUE_NUMBER) { + if (SIGMA.match(currentFieldName)) { + sigma = parser.doubleValue(); + } else { + throw new SearchParseException(context, "Unknown key for a " + token + " in [" + aggregationName + "]: [" + currentFieldName + "]."); + } + } else { + throw new SearchParseException(context, "Unexpected token " + token + " in [" + aggregationName + "]."); + } + } + + if (sigma < 0) { + throw new SearchParseException(context, "[sigma] must not be negative. 
Value provided was" + sigma ); + } + + return createFactory(aggregationName, vsParser.config(), sigma); } } diff --git a/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/InternalExtendedStats.java b/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/InternalExtendedStats.java index 2f02d3d21c7..9a700690530 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/InternalExtendedStats.java +++ b/src/main/java/org/elasticsearch/search/aggregations/metrics/stats/extended/InternalExtendedStats.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.search.aggregations.metrics.stats.extended; +import org.elasticsearch.Version; import org.elasticsearch.common.inject.internal.Nullable; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -53,7 +54,7 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat enum Metrics { - count, sum, min, max, avg, sum_of_squares, variance, std_deviation; + count, sum, min, max, avg, sum_of_squares, variance, std_deviation, std_upper, std_lower; public static Metrics resolve(String name) { return Metrics.valueOf(name); @@ -61,13 +62,15 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat } private double sumOfSqrs; + private double sigma; InternalExtendedStats() {} // for serialization public InternalExtendedStats(String name, long count, double sum, double min, double max, double sumOfSqrs, - @Nullable ValueFormatter formatter, Map metaData) { + double sigma, @Nullable ValueFormatter formatter, Map metaData) { super(name, count, sum, min, max, formatter, metaData); this.sumOfSqrs = sumOfSqrs; + this.sigma = sigma; } @Override @@ -86,6 +89,12 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat if ("std_deviation".equals(name)) { return getStdDeviation(); } + if ("std_upper".equals(name)) { + return 
getStdDeviationBound(Bounds.UPPER); + } + if ("std_lower".equals(name)) { + return getStdDeviationBound(Bounds.LOWER); + } return super.value(name); } @@ -104,6 +113,15 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat return Math.sqrt(getVariance()); } + @Override + public double getStdDeviationBound(Bounds bound) { + if (bound.equals(Bounds.UPPER)) { + return getAvg() + (getStdDeviation() * sigma); + } else { + return getAvg() - (getStdDeviation() * sigma); + } + } + @Override public String getSumOfSquaresAsString() { return valueAsString(Metrics.sum_of_squares.name()); @@ -119,6 +137,11 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat return valueAsString(Metrics.std_deviation.name()); } + @Override + public String getStdDeviationBoundAsString(Bounds bound) { + return bound == Bounds.UPPER ? valueAsString(Metrics.std_upper.name()) : valueAsString(Metrics.std_lower.name()); + } + @Override public InternalExtendedStats reduce(ReduceContext reduceContext) { double sumOfSqrs = 0; @@ -127,20 +150,28 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat sumOfSqrs += stats.getSumOfSquares(); } final InternalStats stats = super.reduce(reduceContext); - return new InternalExtendedStats(name, stats.getCount(), stats.getSum(), stats.getMin(), stats.getMax(), sumOfSqrs, valueFormatter, - getMetaData()); + return new InternalExtendedStats(name, stats.getCount(), stats.getSum(), stats.getMin(), stats.getMax(), sumOfSqrs, sigma, valueFormatter, getMetaData()); } @Override public void readOtherStatsFrom(StreamInput in) throws IOException { sumOfSqrs = in.readDouble(); + if (in.getVersion().onOrAfter(Version.V_1_4_3)) { + sigma = in.readDouble(); + } else { + sigma = 2.0; + } } @Override protected void writeOtherStatsTo(StreamOutput out) throws IOException { out.writeDouble(sumOfSqrs); + if (out.getVersion().onOrAfter(Version.V_1_4_3)) { + out.writeDouble(sigma); + } } + static 
class Fields { public static final XContentBuilderString SUM_OF_SQRS = new XContentBuilderString("sum_of_squares"); public static final XContentBuilderString SUM_OF_SQRS_AS_STRING = new XContentBuilderString("sum_of_squares_as_string"); @@ -148,6 +179,11 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat public static final XContentBuilderString VARIANCE_AS_STRING = new XContentBuilderString("variance_as_string"); public static final XContentBuilderString STD_DEVIATION = new XContentBuilderString("std_deviation"); public static final XContentBuilderString STD_DEVIATION_AS_STRING = new XContentBuilderString("std_deviation_as_string"); + public static final XContentBuilderString STD_DEVIATION_BOUNDS = new XContentBuilderString("std_deviation_bounds"); + public static final XContentBuilderString STD_DEVIATION_BOUNDS_AS_STRING = new XContentBuilderString("std_deviation_bounds_as_string"); + public static final XContentBuilderString UPPER = new XContentBuilderString("upper"); + public static final XContentBuilderString LOWER = new XContentBuilderString("lower"); + } @Override @@ -155,12 +191,22 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat builder.field(Fields.SUM_OF_SQRS, count != 0 ? sumOfSqrs : null); builder.field(Fields.VARIANCE, count != 0 ? getVariance() : null); builder.field(Fields.STD_DEVIATION, count != 0 ? getStdDeviation() : null); + builder.startObject(Fields.STD_DEVIATION_BOUNDS) + .field(Fields.UPPER, count != 0 ? getStdDeviationBound(Bounds.UPPER) : null) + .field(Fields.LOWER, count != 0 ? 
getStdDeviationBound(Bounds.LOWER) : null) + .endObject(); + if (count != 0 && valueFormatter != null) { builder.field(Fields.SUM_OF_SQRS_AS_STRING, valueFormatter.format(sumOfSqrs)); builder.field(Fields.VARIANCE_AS_STRING, valueFormatter.format(getVariance())); - builder.field(Fields.STD_DEVIATION_AS_STRING, valueFormatter.format(getStdDeviation())); + builder.field(Fields.STD_DEVIATION_AS_STRING, getStdDeviationAsString()); + + builder.startObject(Fields.STD_DEVIATION_BOUNDS_AS_STRING) + .field(Fields.UPPER, getStdDeviationBoundAsString(Bounds.UPPER)) + .field(Fields.LOWER, getStdDeviationBoundAsString(Bounds.LOWER)) + .endObject(); + } return builder; } - } diff --git a/src/test/java/org/elasticsearch/search/aggregations/metrics/ExtendedStatsTests.java b/src/test/java/org/elasticsearch/search/aggregations/metrics/ExtendedStatsTests.java index ccf25f159af..a768a7b941f 100644 --- a/src/test/java/org/elasticsearch/search/aggregations/metrics/ExtendedStatsTests.java +++ b/src/test/java/org/elasticsearch/search/aggregations/metrics/ExtendedStatsTests.java @@ -77,6 +77,8 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getMax(), equalTo(Double.NEGATIVE_INFINITY)); assertThat(Double.isNaN(stats.getStdDeviation()), is(true)); assertThat(Double.isNaN(stats.getAvg()), is(true)); + assertThat(Double.isNaN(stats.getStdDeviationBound(ExtendedStats.Bounds.UPPER)), is(true)); + assertThat(Double.isNaN(stats.getStdDeviationBound(ExtendedStats.Bounds.LOWER)), is(true)); } @Test @@ -99,10 +101,39 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getSumOfSquares(), equalTo(0.0)); assertThat(stats.getVariance(), equalTo(Double.NaN)); assertThat(stats.getStdDeviation(), equalTo(Double.NaN)); + assertThat(Double.isNaN(stats.getStdDeviationBound(ExtendedStats.Bounds.UPPER)), is(true)); + assertThat(Double.isNaN(stats.getStdDeviationBound(ExtendedStats.Bounds.LOWER)), is(true)); } @Test public void 
testSingleValuedField() throws Exception { + double sigma = randomDouble() * randomIntBetween(1, 10); + SearchResponse searchResponse = client().prepareSearch("idx") + .setQuery(matchAllQuery()) + .addAggregation(extendedStats("stats").field("value").sigma(sigma)) + .execute().actionGet(); + + assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l)); + + ExtendedStats stats = searchResponse.getAggregations().get("stats"); + assertThat(stats, notNullValue()); + assertThat(stats.getName(), equalTo("stats")); + assertThat(stats.getAvg(), equalTo((double) (1+2+3+4+5+6+7+8+9+10) / 10)); + assertThat(stats.getMin(), equalTo(1.0)); + assertThat(stats.getMax(), equalTo(10.0)); + assertThat(stats.getSum(), equalTo((double) 1+2+3+4+5+6+7+8+9+10)); + assertThat(stats.getCount(), equalTo(10l)); + assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100)); + assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))); + assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))); + checkUpperLowerBounds(stats, sigma); + } + + @Test + public void testSingleValuedFieldDefaultSigma() throws Exception { + + // Same as previous test, but uses a default value for sigma + SearchResponse searchResponse = client().prepareSearch("idx") .setQuery(matchAllQuery()) .addAggregation(extendedStats("stats").field("value")) @@ -119,13 +150,15 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getSum(), equalTo((double) 1+2+3+4+5+6+7+8+9+10)); assertThat(stats.getCount(), equalTo(10l)); assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100)); - assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10))); - assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10))); + assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))); + assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 
6, 7, 8, 9, 10))); + checkUpperLowerBounds(stats, 2); } public void testSingleValuedField_WithFormatter() throws Exception { + double sigma = randomDouble() * randomIntBetween(1, 10); SearchResponse searchResponse = client().prepareSearch("idx").setQuery(matchAllQuery()) - .addAggregation(extendedStats("stats").format("0000.0").field("value")).execute().actionGet(); + .addAggregation(extendedStats("stats").format("0000.0").field("value").sigma(sigma)).execute().actionGet(); assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l)); @@ -148,6 +181,7 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getVarianceAsString(), equalTo("0008.2")); assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))); assertThat(stats.getStdDeviationAsString(), equalTo("0002.9")); + checkUpperLowerBounds(stats, sigma); } @Test @@ -199,9 +233,10 @@ public class ExtendedStatsTests extends AbstractNumericTests { @Test public void testSingleValuedField_PartiallyUnmapped() throws Exception { + double sigma = randomDouble() * randomIntBetween(1, 10); SearchResponse searchResponse = client().prepareSearch("idx", "idx_unmapped") .setQuery(matchAllQuery()) - .addAggregation(extendedStats("stats").field("value")) + .addAggregation(extendedStats("stats").field("value").sigma(sigma)) .execute().actionGet(); assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l)); @@ -215,15 +250,17 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getSum(), equalTo((double) 1+2+3+4+5+6+7+8+9+10)); assertThat(stats.getCount(), equalTo(10l)); assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100)); - assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10))); - assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10))); + assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))); + 
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))); + checkUpperLowerBounds(stats, sigma); } @Test public void testSingleValuedField_WithValueScript() throws Exception { + double sigma = randomDouble() * randomIntBetween(1, 10); SearchResponse searchResponse = client().prepareSearch("idx") .setQuery(matchAllQuery()) - .addAggregation(extendedStats("stats").field("value").script("_value + 1")) + .addAggregation(extendedStats("stats").field("value").script("_value + 1").sigma(sigma)) .execute().actionGet(); assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l)); @@ -237,15 +274,17 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getSum(), equalTo((double) 2+3+4+5+6+7+8+9+10+11)); assertThat(stats.getCount(), equalTo(10l)); assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121)); - assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11))); - assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11))); + assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8, 9, 10, 11))); + assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8, 9, 10, 11))); + checkUpperLowerBounds(stats, sigma); } @Test public void testSingleValuedField_WithValueScript_WithParams() throws Exception { + double sigma = randomDouble() * randomIntBetween(1, 10); SearchResponse searchResponse = client().prepareSearch("idx") .setQuery(matchAllQuery()) - .addAggregation(extendedStats("stats").field("value").script("_value + inc").param("inc", 1)) + .addAggregation(extendedStats("stats").field("value").script("_value + inc").param("inc", 1).sigma(sigma)) .execute().actionGet(); assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l)); @@ -259,15 +298,17 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getSum(), equalTo((double) 2+3+4+5+6+7+8+9+10+11)); assertThat(stats.getCount(), 
equalTo(10l)); assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121)); - assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11))); - assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11))); + assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8, 9, 10, 11))); + assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8, 9, 10, 11))); + checkUpperLowerBounds(stats, sigma); } @Test public void testMultiValuedField() throws Exception { + double sigma = randomDouble() * randomIntBetween(1, 10); SearchResponse searchResponse = client().prepareSearch("idx") .setQuery(matchAllQuery()) - .addAggregation(extendedStats("stats").field("values")) + .addAggregation(extendedStats("stats").field("values").sigma(sigma)) .execute().actionGet(); assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l)); @@ -281,15 +322,17 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getSum(), equalTo((double) 2+3+4+5+6+7+8+9+10+11+3+4+5+6+7+8+9+10+11+12)); assertThat(stats.getCount(), equalTo(20l)); assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121+9+16+25+36+49+64+81+100+121+144)); - assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12))); - assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12))); + assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12))); + assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12))); + checkUpperLowerBounds(stats, sigma); } @Test public void testMultiValuedField_WithValueScript() throws Exception { + double sigma = randomDouble() * randomIntBetween(1, 10); SearchResponse searchResponse = client().prepareSearch("idx") 
.setQuery(matchAllQuery()) - .addAggregation(extendedStats("stats").field("values").script("_value - 1")) + .addAggregation(extendedStats("stats").field("values").script("_value - 1").sigma(sigma)) .execute().actionGet(); assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l)); @@ -305,13 +348,15 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100+4+9+16+25+36+49+64+81+100+121)); assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 11))); assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 11))); + checkUpperLowerBounds(stats, sigma); } @Test public void testMultiValuedField_WithValueScript_WithParams() throws Exception { + double sigma = randomDouble() * randomIntBetween(1, 10); SearchResponse searchResponse = client().prepareSearch("idx") .setQuery(matchAllQuery()) - .addAggregation(extendedStats("stats").field("values").script("_value - dec").param("dec", 1)) + .addAggregation(extendedStats("stats").field("values").script("_value - dec").param("dec", 1).sigma(sigma)) .execute().actionGet(); assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l)); @@ -327,13 +372,15 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100+4+9+16+25+36+49+64+81+100+121)); assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 11))); assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 11))); + checkUpperLowerBounds(stats, sigma); } @Test public void testScript_SingleValued() throws Exception { + double sigma = randomDouble() * randomIntBetween(1, 10); SearchResponse searchResponse = client().prepareSearch("idx") .setQuery(matchAllQuery()) - 
.addAggregation(extendedStats("stats").script("doc['value'].value")) + .addAggregation(extendedStats("stats").script("doc['value'].value").sigma(sigma)) .execute().actionGet(); assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l)); @@ -349,13 +396,15 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100)); assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10))); assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10))); + checkUpperLowerBounds(stats, sigma); } @Test public void testScript_SingleValued_WithParams() throws Exception { + double sigma = randomDouble() * randomIntBetween(1, 10); SearchResponse searchResponse = client().prepareSearch("idx") .setQuery(matchAllQuery()) - .addAggregation(extendedStats("stats").script("doc['value'].value + inc").param("inc", 1)) + .addAggregation(extendedStats("stats").script("doc['value'].value + inc").param("inc", 1).sigma(sigma)) .execute().actionGet(); assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l)); @@ -371,13 +420,15 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121)); assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11))); assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11))); + checkUpperLowerBounds(stats, sigma); } @Test public void testScript_ExplicitSingleValued_WithParams() throws Exception { + double sigma = randomDouble() * randomIntBetween(1, 10); SearchResponse searchResponse = client().prepareSearch("idx") .setQuery(matchAllQuery()) - .addAggregation(extendedStats("stats").script("doc['value'].value + inc").param("inc", 1)) + .addAggregation(extendedStats("stats").script("doc['value'].value + inc").param("inc", 1).sigma(sigma)) .execute().actionGet(); 
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l)); @@ -393,13 +444,15 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121)); assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11))); assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11))); + checkUpperLowerBounds(stats, sigma); } @Test public void testScript_MultiValued() throws Exception { + double sigma = randomDouble() * randomIntBetween(1, 10); SearchResponse searchResponse = client().prepareSearch("idx") .setQuery(matchAllQuery()) - .addAggregation(extendedStats("stats").script("doc['values'].values")) + .addAggregation(extendedStats("stats").script("doc['values'].values").sigma(sigma)) .execute().actionGet(); assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l)); @@ -415,13 +468,15 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121+9+16+25+36+49+64+81+100+121+144)); assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12))); assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12))); + checkUpperLowerBounds(stats, sigma); } @Test public void testScript_ExplicitMultiValued() throws Exception { + double sigma = randomDouble() * randomIntBetween(1, 10); SearchResponse searchResponse = client().prepareSearch("idx") .setQuery(matchAllQuery()) - .addAggregation(extendedStats("stats").script("doc['values'].values")) + .addAggregation(extendedStats("stats").script("doc['values'].values").sigma(sigma)) .execute().actionGet(); assertShardExecutionState(searchResponse, 0); @@ -438,14 +493,16 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getSumOfSquares(), equalTo((double) 
4+9+16+25+36+49+64+81+100+121+9+16+25+36+49+64+81+100+121+144)); assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12))); assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12))); + checkUpperLowerBounds(stats, sigma); } @Test public void testScript_MultiValued_WithParams() throws Exception { + double sigma = randomDouble() * randomIntBetween(1, 10); SearchResponse searchResponse = client().prepareSearch("idx") .setQuery(matchAllQuery()) - .addAggregation(extendedStats("stats").script("[ doc['value'].value, doc['value'].value - dec ]").param("dec", 1)) + .addAggregation(extendedStats("stats").script("[ doc['value'].value, doc['value'].value - dec ]").param("dec", 1).sigma(sigma)) .execute().actionGet(); assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l)); @@ -461,6 +518,7 @@ public class ExtendedStatsTests extends AbstractNumericTests { assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100+0+1+4+9+16+25+36+49+64+81)); assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 0, 1, 2, 3, 4, 5, 6, 7, 8 ,9))); assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 0, 1, 2, 3, 4, 5, 6, 7, 8 ,9))); + checkUpperLowerBounds(stats, sigma); } @@ -474,4 +532,10 @@ public class ExtendedStatsTests extends AbstractNumericTests { } assertThat("Not all shards are initialized", response.getSuccessfulShards(), equalTo(response.getTotalShards())); } + + private void checkUpperLowerBounds(ExtendedStats stats, double sigma) { + assertThat(stats.getStdDeviationBound(ExtendedStats.Bounds.UPPER), equalTo(stats.getAvg() + (stats.getStdDeviation() * sigma))); + assertThat(stats.getStdDeviationBound(ExtendedStats.Bounds.LOWER), equalTo(stats.getAvg() - (stats.getStdDeviation() * sigma))); + } + } \ No newline at end of file