Aggregations: Add standard deviation bounds to extended_stats
Extended_stats now displays the upper and lower bounds on standard deviations (e.g. avg +/- std). Default is to show 2 std above/below, but can be changed using the `sigma` parameter. Accepts non-negative doubles. Closes #9356
This commit is contained in:
parent
3e4fc2659d
commit
a4eb1d5505
|
@ -3,7 +3,7 @@
|
|||
|
||||
A `multi-value` metrics aggregation that computes stats over numeric values extracted from the aggregated documents. These values can be extracted either from specific numeric fields in the documents, or be generated by a provided script.
|
||||
|
||||
The `extended_stats` aggregations is an extended version of the <<search-aggregations-metrics-stats-aggregation,`stats`>> aggregation, where additional metrics are added such as `sum_of_squares`, `variance` and `std_deviation`.
|
||||
The `extended_stats` aggregation is an extended version of the <<search-aggregations-metrics-stats-aggregation,`stats`>> aggregation, where additional metrics are added such as `sum_of_squares`, `variance`, `std_deviation` and `std_deviation_bounds`.
|
||||
|
||||
Assuming the data consists of documents representing exams grades (between 0 and 100) of students
|
||||
|
||||
|
@ -25,21 +25,58 @@ The above aggregation computes the grades statistics over all documents. The agg
|
|||
...
|
||||
|
||||
"aggregations": {
|
||||
"grades_stats": {
|
||||
"count": 6,
|
||||
"grades_stats": {
|
||||
"count": 9,
|
||||
"min": 72,
|
||||
"max": 117.6,
|
||||
"avg": 94.2,
|
||||
"sum": 565.2,
|
||||
"sum_of_squares": 54551.51999999999,
|
||||
"variance": 218.2799999999976,
|
||||
"std_deviation": 14.774302013969987
|
||||
"max": 99,
|
||||
"avg": 86,
|
||||
"sum": 774,
|
||||
"sum_of_squares": 67028,
|
||||
"variance": 51.55555555555556,
|
||||
"std_deviation": 7.180219742846005,
|
||||
"std_deviation_bounds": {
|
||||
"upper": 100.36043948569201,
|
||||
"lower": 71.63956051430799
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
--------------------------------------------------
|
||||
|
||||
The name of the aggregation (`grades_stats` above) also serves as the key by which the aggreagtion result can be retrieved from the returned response.
|
||||
The name of the aggregation (`grades_stats` above) also serves as the key by which the aggregation result can be retrieved from the returned response.
|
||||
|
||||
==== Standard Deviation Bounds
|
||||
coming[1.4.3]
|
||||
|
||||
By default, the `extended_stats` metric will return an object called `std_deviation_bounds`, which provides an interval of plus/minus two standard
|
||||
deviations from the mean. This can be a useful way to visualize variance of your data. If you want a different boundary, for example
|
||||
three standard deviations, you can set `sigma` in the request:
|
||||
|
||||
[source,js]
|
||||
--------------------------------------------------
|
||||
{
|
||||
"aggs" : {
|
||||
"grades_stats" : {
|
||||
"extended_stats" : {
|
||||
"field" : "grade",
|
||||
"sigma" : 3 <1>
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
--------------------------------------------------
|
||||
<1> `sigma` controls how many standard deviations +/- from the mean should be displayed coming[1.4.3]
|
||||
|
||||
`sigma` can be any non-negative double, meaning you can request non-integer values such as `1.5`. A value of `0` is valid, but will simply
|
||||
return the average for both `upper` and `lower` bounds.
|
||||
|
||||
.Standard Deviation and Bounds require normality
|
||||
[NOTE]
|
||||
=====
|
||||
The standard deviation and its bounds are displayed by default, but they are not always applicable to all data-sets. Your data must
|
||||
be normally distributed for the metrics to make sense. The statistics behind standard deviations assume normally distributed data, so
|
||||
if your data is skewed heavily left or right, the value returned will be misleading.
|
||||
=====
|
||||
|
||||
==== Script
|
||||
|
||||
|
|
|
@ -40,6 +40,22 @@ public interface ExtendedStats extends Stats {
|
|||
*/
|
||||
double getStdDeviation();
|
||||
|
||||
/**
|
||||
* The requested (upper or lower) bound of the standard deviation interval around the average of the collected values.
|
||||
*/
|
||||
double getStdDeviationBound(Bounds bound);
|
||||
|
||||
/**
|
||||
* The standard deviation of the collected values as a String.
|
||||
*/
|
||||
String getStdDeviationAsString();
|
||||
|
||||
/**
|
||||
* The requested (upper or lower) standard deviation bound of the collected values as a String.
|
||||
*/
|
||||
String getStdDeviationBoundAsString(Bounds bound);
|
||||
|
||||
|
||||
/**
|
||||
* The sum of the squares of the collected values as a String.
|
||||
*/
|
||||
|
@ -50,9 +66,9 @@ public interface ExtendedStats extends Stats {
|
|||
*/
|
||||
String getVarianceAsString();
|
||||
|
||||
/**
|
||||
* The standard deviation of the collected values as a String.
|
||||
*/
|
||||
String getStdDeviationAsString();
|
||||
|
||||
/**
 * Selects which end of the standard deviation interval is requested.
 */
public enum Bounds {
    /** The upper end of the interval. */
    UPPER,
    /** The lower end of the interval. */
    LOWER
}
|
||||
|
||||
}
|
||||
|
|
|
@ -52,9 +52,11 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
|
|||
private DoubleArray maxes;
|
||||
private DoubleArray sumOfSqrs;
|
||||
private ValueFormatter formatter;
|
||||
private double sigma;
|
||||
|
||||
public ExtendedStatsAggregator(String name, ValuesSource.Numeric valuesSource,
|
||||
@Nullable ValueFormatter formatter, AggregationContext context, Aggregator parent, Map<String, Object> metaData) throws IOException {
|
||||
@Nullable ValueFormatter formatter, AggregationContext context,
|
||||
Aggregator parent, double sigma, Map<String, Object> metaData) throws IOException {
|
||||
super(name, context, parent, metaData);
|
||||
this.valuesSource = valuesSource;
|
||||
this.formatter = formatter;
|
||||
|
@ -66,6 +68,7 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
|
|||
maxes = bigArrays.newDoubleArray(1, false);
|
||||
maxes.fill(0, maxes.size(), Double.NEGATIVE_INFINITY);
|
||||
sumOfSqrs = bigArrays.newDoubleArray(1, true);
|
||||
this.sigma = sigma;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -134,6 +137,12 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
|
|||
case sum_of_squares: return valuesSource == null ? 0 : sumOfSqrs.get(owningBucketOrd);
|
||||
case variance: return valuesSource == null ? Double.NaN : variance(owningBucketOrd);
|
||||
case std_deviation: return valuesSource == null ? Double.NaN : Math.sqrt(variance(owningBucketOrd));
|
||||
case std_upper:
|
||||
if (valuesSource == null) { return Double.NaN; }
|
||||
return (sums.get(owningBucketOrd) / counts.get(owningBucketOrd)) + (Math.sqrt(variance(owningBucketOrd)) * this.sigma);
|
||||
case std_lower:
|
||||
if (valuesSource == null) { return Double.NaN; }
|
||||
return (sums.get(owningBucketOrd) / counts.get(owningBucketOrd)) - (Math.sqrt(variance(owningBucketOrd)) * this.sigma);
|
||||
default:
|
||||
throw new ElasticsearchIllegalArgumentException("Unknown value [" + name + "] in common stats aggregation");
|
||||
}
|
||||
|
@ -148,16 +157,16 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
|
|||
@Override
|
||||
public InternalAggregation buildAggregation(long owningBucketOrdinal) {
|
||||
if (valuesSource == null) {
|
||||
return new InternalExtendedStats(name, 0, 0d, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0d, formatter, metaData());
|
||||
return new InternalExtendedStats(name, 0, 0d, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0d, 0d, formatter, metaData());
|
||||
}
|
||||
assert owningBucketOrdinal < counts.size();
|
||||
return new InternalExtendedStats(name, counts.get(owningBucketOrdinal), sums.get(owningBucketOrdinal),
|
||||
mins.get(owningBucketOrdinal), maxes.get(owningBucketOrdinal), sumOfSqrs.get(owningBucketOrdinal), formatter, metaData());
|
||||
mins.get(owningBucketOrdinal), maxes.get(owningBucketOrdinal), sumOfSqrs.get(owningBucketOrdinal), sigma, formatter, metaData());
|
||||
}
|
||||
|
||||
@Override
|
||||
public InternalAggregation buildEmptyAggregation() {
|
||||
return new InternalExtendedStats(name, 0, 0d, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0d, formatter, metaData());
|
||||
return new InternalExtendedStats(name, 0, 0d, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0d, 0d, formatter, metaData());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -167,19 +176,22 @@ public class ExtendedStatsAggregator extends NumericMetricsAggregator.MultiValue
|
|||
|
||||
public static class Factory extends ValuesSourceAggregatorFactory.LeafOnly<ValuesSource.Numeric> {
|
||||
|
||||
public Factory(String name, ValuesSourceConfig<ValuesSource.Numeric> valuesSourceConfig) {
|
||||
private final double sigma;
|
||||
|
||||
public Factory(String name, ValuesSourceConfig<ValuesSource.Numeric> valuesSourceConfig, double sigma) {
|
||||
super(name, InternalExtendedStats.TYPE.name(), valuesSourceConfig);
|
||||
|
||||
this.sigma = sigma;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Aggregator createUnmapped(AggregationContext aggregationContext, Aggregator parent, Map<String, Object> metaData) throws IOException {
|
||||
return new ExtendedStatsAggregator(name, null, config.formatter(), aggregationContext, parent, metaData);
|
||||
return new ExtendedStatsAggregator(name, null, config.formatter(), aggregationContext, parent, sigma, metaData);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Aggregator doCreateInternal(ValuesSource.Numeric valuesSource, AggregationContext aggregationContext, Aggregator parent, boolean collectsFromSingleBucket, Map<String, Object> metaData) throws IOException {
|
||||
return new ExtendedStatsAggregator(name, valuesSource, config.formatter(), aggregationContext, parent,
|
||||
metaData);
|
||||
return new ExtendedStatsAggregator(name, valuesSource, config.formatter(), aggregationContext, parent, sigma, metaData);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,17 +19,37 @@
|
|||
|
||||
package org.elasticsearch.search.aggregations.metrics.stats.extended;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.ValuesSourceMetricsAggregationBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Builder for the {@link ExtendedStats} aggregation.
|
||||
*/
|
||||
public class ExtendedStatsBuilder extends ValuesSourceMetricsAggregationBuilder<ExtendedStatsBuilder> {
|
||||
|
||||
private Double sigma;
|
||||
|
||||
/**
|
||||
* Sole constructor.
|
||||
*/
|
||||
public ExtendedStatsBuilder(String name) {
|
||||
super(name, InternalExtendedStats.TYPE.name());
|
||||
}
|
||||
|
||||
public ExtendedStatsBuilder sigma(double sigma) {
|
||||
this.sigma = sigma;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void internalXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
super.internalXContent(builder, params);
|
||||
|
||||
if (sigma != null) {
|
||||
builder.field("sigma", sigma);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,22 +18,64 @@
|
|||
*/
|
||||
package org.elasticsearch.search.aggregations.metrics.stats.extended;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.SearchParseException;
|
||||
import org.elasticsearch.search.aggregations.Aggregator;
|
||||
import org.elasticsearch.search.aggregations.AggregatorFactory;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericValuesSourceMetricsAggregatorParser;
|
||||
import org.elasticsearch.search.aggregations.support.ValuesSource;
|
||||
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
|
||||
import org.elasticsearch.search.aggregations.support.ValuesSourceParser;
|
||||
import org.elasticsearch.search.internal.SearchContext;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
public class ExtendedStatsParser extends NumericValuesSourceMetricsAggregatorParser<InternalExtendedStats> {
|
||||
public class ExtendedStatsParser implements Aggregator.Parser {
|
||||
|
||||
public ExtendedStatsParser() {
|
||||
super(InternalExtendedStats.TYPE);
|
||||
static final ParseField SIGMA = new ParseField("sigma");
|
||||
|
||||
@Override
|
||||
public String type() {
|
||||
return InternalExtendedStats.TYPE.name();
|
||||
}
|
||||
|
||||
protected AggregatorFactory createFactory(String aggregationName, ValuesSourceConfig<ValuesSource.Numeric> config, double sigma) {
|
||||
return new ExtendedStatsAggregator.Factory(aggregationName, config, sigma);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AggregatorFactory createFactory(String aggregationName, ValuesSourceConfig<ValuesSource.Numeric> config) {
|
||||
return new ExtendedStatsAggregator.Factory(aggregationName, config);
|
||||
public AggregatorFactory parse(String aggregationName, XContentParser parser, SearchContext context) throws IOException {
|
||||
|
||||
ValuesSourceParser<ValuesSource.Numeric> vsParser = ValuesSourceParser.numeric(aggregationName, InternalExtendedStats.TYPE, context).formattable(true)
|
||||
.build();
|
||||
|
||||
XContentParser.Token token;
|
||||
String currentFieldName = null;
|
||||
double sigma = 2.0;
|
||||
|
||||
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
|
||||
if (token == XContentParser.Token.FIELD_NAME) {
|
||||
currentFieldName = parser.currentName();
|
||||
} else if (vsParser.token(currentFieldName, token, parser)) {
|
||||
continue;
|
||||
} else if (token == XContentParser.Token.VALUE_NUMBER) {
|
||||
if (SIGMA.match(currentFieldName)) {
|
||||
sigma = parser.doubleValue();
|
||||
} else {
|
||||
throw new SearchParseException(context, "Unknown key for a " + token + " in [" + aggregationName + "]: [" + currentFieldName + "].");
|
||||
}
|
||||
} else {
|
||||
throw new SearchParseException(context, "Unexpected token " + token + " in [" + aggregationName + "].");
|
||||
}
|
||||
}
|
||||
|
||||
if (sigma < 0) {
|
||||
throw new SearchParseException(context, "[sigma] must not be negative. Value provided was" + sigma );
|
||||
}
|
||||
|
||||
return createFactory(aggregationName, vsParser.config(), sigma);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
*/
|
||||
package org.elasticsearch.search.aggregations.metrics.stats.extended;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.inject.internal.Nullable;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
|
@ -53,7 +54,7 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
|
|||
|
||||
enum Metrics {
|
||||
|
||||
count, sum, min, max, avg, sum_of_squares, variance, std_deviation;
|
||||
count, sum, min, max, avg, sum_of_squares, variance, std_deviation, std_upper, std_lower;
|
||||
|
||||
public static Metrics resolve(String name) {
|
||||
return Metrics.valueOf(name);
|
||||
|
@ -61,13 +62,15 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
|
|||
}
|
||||
|
||||
private double sumOfSqrs;
|
||||
private double sigma;
|
||||
|
||||
InternalExtendedStats() {} // for serialization
|
||||
|
||||
public InternalExtendedStats(String name, long count, double sum, double min, double max, double sumOfSqrs,
|
||||
@Nullable ValueFormatter formatter, Map<String, Object> metaData) {
|
||||
double sigma, @Nullable ValueFormatter formatter, Map<String, Object> metaData) {
|
||||
super(name, count, sum, min, max, formatter, metaData);
|
||||
this.sumOfSqrs = sumOfSqrs;
|
||||
this.sigma = sigma;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -86,6 +89,12 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
|
|||
if ("std_deviation".equals(name)) {
|
||||
return getStdDeviation();
|
||||
}
|
||||
if ("std_upper".equals(name)) {
|
||||
return getStdDeviationBound(Bounds.UPPER);
|
||||
}
|
||||
if ("std_lower".equals(name)) {
|
||||
return getStdDeviationBound(Bounds.LOWER);
|
||||
}
|
||||
return super.value(name);
|
||||
}
|
||||
|
||||
|
@ -104,6 +113,15 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
|
|||
return Math.sqrt(getVariance());
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getStdDeviationBound(Bounds bound) {
|
||||
if (bound.equals(Bounds.UPPER)) {
|
||||
return getAvg() + (getStdDeviation() * sigma);
|
||||
} else {
|
||||
return getAvg() - (getStdDeviation() * sigma);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getSumOfSquaresAsString() {
|
||||
return valueAsString(Metrics.sum_of_squares.name());
|
||||
|
@ -119,6 +137,11 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
|
|||
return valueAsString(Metrics.std_deviation.name());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getStdDeviationBoundAsString(Bounds bound) {
|
||||
return bound == Bounds.UPPER ? valueAsString(Metrics.std_upper.name()) : valueAsString(Metrics.std_lower.name());
|
||||
}
|
||||
|
||||
@Override
|
||||
public InternalExtendedStats reduce(ReduceContext reduceContext) {
|
||||
double sumOfSqrs = 0;
|
||||
|
@ -127,19 +150,27 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
|
|||
sumOfSqrs += stats.getSumOfSquares();
|
||||
}
|
||||
final InternalStats stats = super.reduce(reduceContext);
|
||||
return new InternalExtendedStats(name, stats.getCount(), stats.getSum(), stats.getMin(), stats.getMax(), sumOfSqrs, valueFormatter,
|
||||
getMetaData());
|
||||
return new InternalExtendedStats(name, stats.getCount(), stats.getSum(), stats.getMin(), stats.getMax(), sumOfSqrs, sigma, valueFormatter, getMetaData());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readOtherStatsFrom(StreamInput in) throws IOException {
|
||||
sumOfSqrs = in.readDouble();
|
||||
if (in.getVersion().onOrAfter(Version.V_1_4_3)) {
|
||||
sigma = in.readDouble();
|
||||
} else {
|
||||
sigma = 2.0;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void writeOtherStatsTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(sumOfSqrs);
|
||||
if (out.getVersion().onOrAfter(Version.V_1_4_3)) {
|
||||
out.writeDouble(sigma);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static class Fields {
|
||||
public static final XContentBuilderString SUM_OF_SQRS = new XContentBuilderString("sum_of_squares");
|
||||
|
@ -148,6 +179,11 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
|
|||
public static final XContentBuilderString VARIANCE_AS_STRING = new XContentBuilderString("variance_as_string");
|
||||
public static final XContentBuilderString STD_DEVIATION = new XContentBuilderString("std_deviation");
|
||||
public static final XContentBuilderString STD_DEVIATION_AS_STRING = new XContentBuilderString("std_deviation_as_string");
|
||||
public static final XContentBuilderString STD_DEVIATION_BOUNDS = new XContentBuilderString("std_deviation_bounds");
|
||||
public static final XContentBuilderString STD_DEVIATION_BOUNDS_AS_STRING = new XContentBuilderString("std_deviation_bounds_as_string");
|
||||
public static final XContentBuilderString UPPER = new XContentBuilderString("upper");
|
||||
public static final XContentBuilderString LOWER = new XContentBuilderString("lower");
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -155,12 +191,22 @@ public class InternalExtendedStats extends InternalStats implements ExtendedStat
|
|||
builder.field(Fields.SUM_OF_SQRS, count != 0 ? sumOfSqrs : null);
|
||||
builder.field(Fields.VARIANCE, count != 0 ? getVariance() : null);
|
||||
builder.field(Fields.STD_DEVIATION, count != 0 ? getStdDeviation() : null);
|
||||
builder.startObject(Fields.STD_DEVIATION_BOUNDS)
|
||||
.field(Fields.UPPER, count != 0 ? getStdDeviationBound(Bounds.UPPER) : null)
|
||||
.field(Fields.LOWER, count != 0 ? getStdDeviationBound(Bounds.LOWER) : null)
|
||||
.endObject();
|
||||
|
||||
if (count != 0 && valueFormatter != null) {
|
||||
builder.field(Fields.SUM_OF_SQRS_AS_STRING, valueFormatter.format(sumOfSqrs));
|
||||
builder.field(Fields.VARIANCE_AS_STRING, valueFormatter.format(getVariance()));
|
||||
builder.field(Fields.STD_DEVIATION_AS_STRING, valueFormatter.format(getStdDeviation()));
|
||||
builder.field(Fields.STD_DEVIATION_AS_STRING, getStdDeviationAsString());
|
||||
|
||||
builder.startObject(Fields.STD_DEVIATION_BOUNDS_AS_STRING)
|
||||
.field(Fields.UPPER, getStdDeviationBoundAsString(Bounds.UPPER))
|
||||
.field(Fields.LOWER, getStdDeviationBoundAsString(Bounds.LOWER))
|
||||
.endObject();
|
||||
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -77,6 +77,8 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getMax(), equalTo(Double.NEGATIVE_INFINITY));
|
||||
assertThat(Double.isNaN(stats.getStdDeviation()), is(true));
|
||||
assertThat(Double.isNaN(stats.getAvg()), is(true));
|
||||
assertThat(Double.isNaN(stats.getStdDeviationBound(ExtendedStats.Bounds.UPPER)), is(true));
|
||||
assertThat(Double.isNaN(stats.getStdDeviationBound(ExtendedStats.Bounds.LOWER)), is(true));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -99,10 +101,39 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getSumOfSquares(), equalTo(0.0));
|
||||
assertThat(stats.getVariance(), equalTo(Double.NaN));
|
||||
assertThat(stats.getStdDeviation(), equalTo(Double.NaN));
|
||||
assertThat(Double.isNaN(stats.getStdDeviationBound(ExtendedStats.Bounds.UPPER)), is(true));
|
||||
assertThat(Double.isNaN(stats.getStdDeviationBound(ExtendedStats.Bounds.LOWER)), is(true));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSingleValuedField() throws Exception {
|
||||
double sigma = randomDouble() * randomIntBetween(1, 10);
|
||||
SearchResponse searchResponse = client().prepareSearch("idx")
|
||||
.setQuery(matchAllQuery())
|
||||
.addAggregation(extendedStats("stats").field("value").sigma(sigma))
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
|
||||
|
||||
ExtendedStats stats = searchResponse.getAggregations().get("stats");
|
||||
assertThat(stats, notNullValue());
|
||||
assertThat(stats.getName(), equalTo("stats"));
|
||||
assertThat(stats.getAvg(), equalTo((double) (1+2+3+4+5+6+7+8+9+10) / 10));
|
||||
assertThat(stats.getMin(), equalTo(1.0));
|
||||
assertThat(stats.getMax(), equalTo(10.0));
|
||||
assertThat(stats.getSum(), equalTo((double) 1+2+3+4+5+6+7+8+9+10));
|
||||
assertThat(stats.getCount(), equalTo(10l));
|
||||
assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100));
|
||||
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)));
|
||||
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)));
|
||||
checkUpperLowerBounds(stats, sigma);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSingleValuedFieldDefaultSigma() throws Exception {
|
||||
|
||||
// Same as previous test, but uses a default value for sigma
|
||||
|
||||
SearchResponse searchResponse = client().prepareSearch("idx")
|
||||
.setQuery(matchAllQuery())
|
||||
.addAggregation(extendedStats("stats").field("value"))
|
||||
|
@ -121,11 +152,13 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100));
|
||||
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)));
|
||||
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)));
|
||||
checkUpperLowerBounds(stats, 2);
|
||||
}
|
||||
|
||||
public void testSingleValuedField_WithFormatter() throws Exception {
|
||||
double sigma = randomDouble() * randomIntBetween(1, 10);
|
||||
SearchResponse searchResponse = client().prepareSearch("idx").setQuery(matchAllQuery())
|
||||
.addAggregation(extendedStats("stats").format("0000.0").field("value")).execute().actionGet();
|
||||
.addAggregation(extendedStats("stats").format("0000.0").field("value").sigma(sigma)).execute().actionGet();
|
||||
|
||||
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
|
||||
|
||||
|
@ -148,6 +181,7 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getVarianceAsString(), equalTo("0008.2"));
|
||||
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)));
|
||||
assertThat(stats.getStdDeviationAsString(), equalTo("0002.9"));
|
||||
checkUpperLowerBounds(stats, sigma);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -199,9 +233,10 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
|
||||
@Test
|
||||
public void testSingleValuedField_PartiallyUnmapped() throws Exception {
|
||||
double sigma = randomDouble() * randomIntBetween(1, 10);
|
||||
SearchResponse searchResponse = client().prepareSearch("idx", "idx_unmapped")
|
||||
.setQuery(matchAllQuery())
|
||||
.addAggregation(extendedStats("stats").field("value"))
|
||||
.addAggregation(extendedStats("stats").field("value").sigma(sigma))
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
|
||||
|
@ -217,13 +252,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100));
|
||||
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)));
|
||||
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)));
|
||||
checkUpperLowerBounds(stats, sigma);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSingleValuedField_WithValueScript() throws Exception {
|
||||
double sigma = randomDouble() * randomIntBetween(1, 10);
|
||||
SearchResponse searchResponse = client().prepareSearch("idx")
|
||||
.setQuery(matchAllQuery())
|
||||
.addAggregation(extendedStats("stats").field("value").script("_value + 1"))
|
||||
.addAggregation(extendedStats("stats").field("value").script("_value + 1").sigma(sigma))
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
|
||||
|
@ -239,13 +276,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121));
|
||||
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8, 9, 10, 11)));
|
||||
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8, 9, 10, 11)));
|
||||
checkUpperLowerBounds(stats, sigma);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSingleValuedField_WithValueScript_WithParams() throws Exception {
|
||||
double sigma = randomDouble() * randomIntBetween(1, 10);
|
||||
SearchResponse searchResponse = client().prepareSearch("idx")
|
||||
.setQuery(matchAllQuery())
|
||||
.addAggregation(extendedStats("stats").field("value").script("_value + inc").param("inc", 1))
|
||||
.addAggregation(extendedStats("stats").field("value").script("_value + inc").param("inc", 1).sigma(sigma))
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
|
||||
|
@ -261,13 +300,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121));
|
||||
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8, 9, 10, 11)));
|
||||
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8, 9, 10, 11)));
|
||||
checkUpperLowerBounds(stats, sigma);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMultiValuedField() throws Exception {
|
||||
double sigma = randomDouble() * randomIntBetween(1, 10);
|
||||
SearchResponse searchResponse = client().prepareSearch("idx")
|
||||
.setQuery(matchAllQuery())
|
||||
.addAggregation(extendedStats("stats").field("values"))
|
||||
.addAggregation(extendedStats("stats").field("values").sigma(sigma))
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
|
||||
|
@ -283,13 +324,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121+9+16+25+36+49+64+81+100+121+144));
|
||||
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)));
|
||||
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)));
|
||||
checkUpperLowerBounds(stats, sigma);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMultiValuedField_WithValueScript() throws Exception {
|
||||
double sigma = randomDouble() * randomIntBetween(1, 10);
|
||||
SearchResponse searchResponse = client().prepareSearch("idx")
|
||||
.setQuery(matchAllQuery())
|
||||
.addAggregation(extendedStats("stats").field("values").script("_value - 1"))
|
||||
.addAggregation(extendedStats("stats").field("values").script("_value - 1").sigma(sigma))
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
|
||||
|
@ -305,13 +348,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100+4+9+16+25+36+49+64+81+100+121));
|
||||
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
|
||||
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
|
||||
checkUpperLowerBounds(stats, sigma);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMultiValuedField_WithValueScript_WithParams() throws Exception {
|
||||
double sigma = randomDouble() * randomIntBetween(1, 10);
|
||||
SearchResponse searchResponse = client().prepareSearch("idx")
|
||||
.setQuery(matchAllQuery())
|
||||
.addAggregation(extendedStats("stats").field("values").script("_value - dec").param("dec", 1))
|
||||
.addAggregation(extendedStats("stats").field("values").script("_value - dec").param("dec", 1).sigma(sigma))
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
|
||||
|
@ -327,13 +372,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100+4+9+16+25+36+49+64+81+100+121));
|
||||
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
|
||||
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
|
||||
checkUpperLowerBounds(stats, sigma);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testScript_SingleValued() throws Exception {
|
||||
double sigma = randomDouble() * randomIntBetween(1, 10);
|
||||
SearchResponse searchResponse = client().prepareSearch("idx")
|
||||
.setQuery(matchAllQuery())
|
||||
.addAggregation(extendedStats("stats").script("doc['value'].value"))
|
||||
.addAggregation(extendedStats("stats").script("doc['value'].value").sigma(sigma))
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
|
||||
|
@ -349,13 +396,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100));
|
||||
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10)));
|
||||
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10)));
|
||||
checkUpperLowerBounds(stats, sigma);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testScript_SingleValued_WithParams() throws Exception {
|
||||
double sigma = randomDouble() * randomIntBetween(1, 10);
|
||||
SearchResponse searchResponse = client().prepareSearch("idx")
|
||||
.setQuery(matchAllQuery())
|
||||
.addAggregation(extendedStats("stats").script("doc['value'].value + inc").param("inc", 1))
|
||||
.addAggregation(extendedStats("stats").script("doc['value'].value + inc").param("inc", 1).sigma(sigma))
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
|
||||
|
@ -371,13 +420,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121));
|
||||
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
|
||||
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
|
||||
checkUpperLowerBounds(stats, sigma);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testScript_ExplicitSingleValued_WithParams() throws Exception {
|
||||
double sigma = randomDouble() * randomIntBetween(1, 10);
|
||||
SearchResponse searchResponse = client().prepareSearch("idx")
|
||||
.setQuery(matchAllQuery())
|
||||
.addAggregation(extendedStats("stats").script("doc['value'].value + inc").param("inc", 1))
|
||||
.addAggregation(extendedStats("stats").script("doc['value'].value + inc").param("inc", 1).sigma(sigma))
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
|
||||
|
@ -393,13 +444,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121));
|
||||
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
|
||||
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11)));
|
||||
checkUpperLowerBounds(stats, sigma);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testScript_MultiValued() throws Exception {
|
||||
double sigma = randomDouble() * randomIntBetween(1, 10);
|
||||
SearchResponse searchResponse = client().prepareSearch("idx")
|
||||
.setQuery(matchAllQuery())
|
||||
.addAggregation(extendedStats("stats").script("doc['values'].values"))
|
||||
.addAggregation(extendedStats("stats").script("doc['values'].values").sigma(sigma))
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
|
||||
|
@ -415,13 +468,15 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121+9+16+25+36+49+64+81+100+121+144));
|
||||
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12)));
|
||||
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12)));
|
||||
checkUpperLowerBounds(stats, sigma);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testScript_ExplicitMultiValued() throws Exception {
|
||||
double sigma = randomDouble() * randomIntBetween(1, 10);
|
||||
SearchResponse searchResponse = client().prepareSearch("idx")
|
||||
.setQuery(matchAllQuery())
|
||||
.addAggregation(extendedStats("stats").script("doc['values'].values"))
|
||||
.addAggregation(extendedStats("stats").script("doc['values'].values").sigma(sigma))
|
||||
.execute().actionGet();
|
||||
|
||||
assertShardExecutionState(searchResponse, 0);
|
||||
|
@ -438,14 +493,16 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getSumOfSquares(), equalTo((double) 4+9+16+25+36+49+64+81+100+121+9+16+25+36+49+64+81+100+121+144));
|
||||
assertThat(stats.getVariance(), equalTo(variance(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12)));
|
||||
assertThat(stats.getStdDeviation(), equalTo(stdDev(2, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 3, 4, 5, 6, 7, 8 ,9, 10, 11, 12)));
|
||||
checkUpperLowerBounds(stats, sigma);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testScript_MultiValued_WithParams() throws Exception {
|
||||
double sigma = randomDouble() * randomIntBetween(1, 10);
|
||||
SearchResponse searchResponse = client().prepareSearch("idx")
|
||||
.setQuery(matchAllQuery())
|
||||
.addAggregation(extendedStats("stats").script("[ doc['value'].value, doc['value'].value - dec ]").param("dec", 1))
|
||||
.addAggregation(extendedStats("stats").script("[ doc['value'].value, doc['value'].value - dec ]").param("dec", 1).sigma(sigma))
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(searchResponse.getHits().getTotalHits(), equalTo(10l));
|
||||
|
@ -461,6 +518,7 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
assertThat(stats.getSumOfSquares(), equalTo((double) 1+4+9+16+25+36+49+64+81+100+0+1+4+9+16+25+36+49+64+81));
|
||||
assertThat(stats.getVariance(), equalTo(variance(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 0, 1, 2, 3, 4, 5, 6, 7, 8 ,9)));
|
||||
assertThat(stats.getStdDeviation(), equalTo(stdDev(1, 2, 3, 4, 5, 6, 7, 8 ,9, 10, 0, 1, 2, 3, 4, 5, 6, 7, 8 ,9)));
|
||||
checkUpperLowerBounds(stats, sigma);
|
||||
}
|
||||
|
||||
|
||||
|
@ -474,4 +532,10 @@ public class ExtendedStatsTests extends AbstractNumericTests {
|
|||
}
|
||||
assertThat("Not all shards are initialized", response.getSuccessfulShards(), equalTo(response.getTotalShards()));
|
||||
}
|
||||
|
||||
private void checkUpperLowerBounds(ExtendedStats stats, double sigma) {
|
||||
assertThat(stats.getStdDeviationBound(ExtendedStats.Bounds.UPPER), equalTo(stats.getAvg() + (stats.getStdDeviation() * sigma)));
|
||||
assertThat(stats.getStdDeviationBound(ExtendedStats.Bounds.LOWER), equalTo(stats.getAvg() - (stats.getStdDeviation() * sigma)));
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue