Add prediction capability to MovAvgReducer

This commit adds the ability for moving average models to output a "prediction" based on the current
moving average model.  For the simple, linear and single-exp models, this prediction simply converges on the
moving average's mean at the last point, leading to a straight line.  For double-exp, this will
predict in the direction of the linear trend (either globally or locally, depending on beta).

Also adds some more tests.

Closes #10545
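
For illustration, the renamed builder and the new option fit together like this (a sketch based on the builders and tests in this diff; the aggregation names and field values are hypothetical):

import org.elasticsearch.search.aggregations.reducers.movavg.models.DoubleExpModel;

import static org.elasticsearch.search.aggregations.AggregationBuilders.histogram;
import static org.elasticsearch.search.aggregations.AggregationBuilders.sum;
import static org.elasticsearch.search.aggregations.reducers.ReducerBuilders.movingAvg;

// Sketch: a double-exp moving average over a 30-bucket window that also
// appends 10 forecast buckets to the end of the histogram
histogram("histo").field("timestamp").interval(5)
        .subAggregation(sum("the_sum").field("price"))
        .subAggregation(movingAvg("the_movavg")
                .window(30)
                .modelBuilder(new DoubleExpModel.DoubleExpModelBuilder().alpha(0.5).beta(0.5))
                .predict(10)
                .setBucketsPaths("the_sum"));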
Zachary Tong 2015-04-09 15:02:01 -04:00
parent dcf91ff02f
commit 30177887b1
9 changed files with 1477 additions and 519 deletions

View File

@ -36,7 +36,7 @@ public final class ReducerBuilders {
return new MaxBucketBuilder(name);
}
public static final MovAvgBuilder smooth(String name) {
public static final MovAvgBuilder movingAvg(String name) {
return new MovAvgBuilder(name);
}
}

View File

@ -36,6 +36,7 @@ public class MovAvgBuilder extends ReducerBuilder<MovAvgBuilder> {
private GapPolicy gapPolicy;
private MovAvgModelBuilder modelBuilder;
private Integer window;
private Integer predict;
public MovAvgBuilder(String name) {
super(name, MovAvgReducer.TYPE.name());
@ -81,6 +82,19 @@ public class MovAvgBuilder extends ReducerBuilder<MovAvgBuilder> {
return this;
}
/**
* Sets the number of predictions that should be returned. Each prediction will be spaced at
* the interval specified by the histogram. E.g. "predict: 2" will return two new buckets at the
* end of the histogram with the predicted values.
*
* @param numPredictions Number of predictions to make
* @return Returns the builder to continue chaining
*/
public MovAvgBuilder predict(int numPredictions) {
this.predict = numPredictions;
return this;
}
@Override
protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
@ -96,6 +110,9 @@ public class MovAvgBuilder extends ReducerBuilder<MovAvgBuilder> {
if (window != null) {
builder.field(MovAvgParser.WINDOW.getPreferredName(), window);
}
if (predict != null) {
builder.field(MovAvgParser.PREDICT.getPreferredName(), predict);
}
return builder;
}

View File

@ -46,6 +46,7 @@ public class MovAvgParser implements Reducer.Parser {
public static final ParseField MODEL = new ParseField("model");
public static final ParseField WINDOW = new ParseField("window");
public static final ParseField SETTINGS = new ParseField("settings");
public static final ParseField PREDICT = new ParseField("predict");
private final MovAvgModelParserMapper movAvgModelParserMapper;
@ -65,10 +66,12 @@ public class MovAvgParser implements Reducer.Parser {
String currentFieldName = null;
String[] bucketsPaths = null;
String format = null;
GapPolicy gapPolicy = GapPolicy.IGNORE;
int window = 5;
Map<String, Object> settings = null;
String model = "simple";
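// Number of prediction buckets to append to the end of the histogram; 0 disables prediction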
int predict = 0;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
@ -76,6 +79,16 @@ public class MovAvgParser implements Reducer.Parser {
} else if (token == XContentParser.Token.VALUE_NUMBER) {
if (WINDOW.match(currentFieldName)) {
window = parser.intValue();
if (window <= 0) {
throw new SearchParseException(context, "[" + currentFieldName + "] value must be a positive, " +
"non-zero integer. Value supplied was [" + predict + "] in [" + reducerName + "].");
}
} else if (PREDICT.match(currentFieldName)) {
predict = parser.intValue();
if (predict <= 0) {
throw new SearchParseException(context, "[" + currentFieldName + "] value must be a positive, " +
"non-zero integer. Value supplied was [" + predict + "] in [" + reducerName + "].");
}
} else {
throw new SearchParseException(context, "Unknown key for a " + token + " in [" + reducerName + "]: ["
+ currentFieldName + "].");
@ -119,7 +132,7 @@ public class MovAvgParser implements Reducer.Parser {
if (bucketsPaths == null) {
throw new SearchParseException(context, "Missing required field [" + BUCKETS_PATH.getPreferredName()
+ "] for smooth aggregation [" + reducerName + "]");
+ "] for movingAvg aggregation [" + reducerName + "]");
}
ValueFormatter formatter = null;
@ -135,7 +148,7 @@ public class MovAvgParser implements Reducer.Parser {
MovAvgModel movAvgModel = modelParser.parse(settings);
return new MovAvgReducer.Factory(reducerName, bucketsPaths, formatter, gapPolicy, window, movAvgModel);
return new MovAvgReducer.Factory(reducerName, bucketsPaths, formatter, gapPolicy, window, predict, movAvgModel);
}

View File

@ -27,12 +27,9 @@ import org.elasticsearch.ElasticsearchIllegalStateException;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.*;
import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContext;
import org.elasticsearch.search.aggregations.InternalAggregation.Type;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.bucket.histogram.HistogramAggregator;
import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram;
import org.elasticsearch.search.aggregations.reducers.BucketHelpers.GapPolicy;
@ -44,6 +41,7 @@ import org.elasticsearch.search.aggregations.reducers.movavg.models.MovAvgModel;
import org.elasticsearch.search.aggregations.reducers.movavg.models.MovAvgModelStreams;
import org.elasticsearch.search.aggregations.support.format.ValueFormatter;
import org.elasticsearch.search.aggregations.support.format.ValueFormatterStreams;
import org.joda.time.DateTime;
import java.io.IOException;
import java.util.ArrayList;
@ -80,17 +78,19 @@ public class MovAvgReducer extends Reducer {
private GapPolicy gapPolicy;
private int window;
private MovAvgModel model;
private int predict;
public MovAvgReducer() {
}
public MovAvgReducer(String name, String[] bucketsPaths, @Nullable ValueFormatter formatter, GapPolicy gapPolicy,
int window, MovAvgModel model, Map<String, Object> metadata) {
int window, int predict, MovAvgModel model, Map<String, Object> metadata) {
super(name, bucketsPaths, metadata);
this.formatter = formatter;
this.gapPolicy = gapPolicy;
this.window = window;
this.model = model;
this.predict = predict;
}
@Override
@ -107,8 +107,14 @@ public class MovAvgReducer extends Reducer {
List newBuckets = new ArrayList<>();
EvictingQueue<Double> values = EvictingQueue.create(this.window);
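// Track the last bucket key and the smallest spacing between keys, so that
// prediction buckets can be laid out at the same interval past the end of the histogram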
long lastKey = 0;
long interval = Long.MAX_VALUE;
Object currentKey;
for (InternalHistogram.Bucket bucket : buckets) {
Double thisBucketValue = resolveBucketValue(histo, bucket, bucketsPaths()[0], gapPolicy);
currentKey = bucket.getKey();
if (thisBucketValue != null) {
values.offer(thisBucketValue);
@ -117,14 +123,46 @@ public class MovAvgReducer extends Reducer {
List<InternalAggregation> aggs = new ArrayList<>(Lists.transform(bucket.getAggregations().asList(), FUNCTION));
aggs.add(new InternalSimpleValue(name(), movavg, formatter, new ArrayList<Reducer>(), metaData()));
InternalHistogram.Bucket newBucket = factory.createBucket(bucket.getKey(), bucket.getDocCount(), new InternalAggregations(
InternalHistogram.Bucket newBucket = factory.createBucket(currentKey, bucket.getDocCount(), new InternalAggregations(
aggs), bucket.getKeyed(), bucket.getFormatter());
newBuckets.add(newBucket);
} else {
newBuckets.add(bucket);
}
if (predict > 0) {
long thisKey;
if (currentKey instanceof Number) {
thisKey = ((Number) currentKey).longValue();
} else if (currentKey instanceof DateTime) {
thisKey = ((DateTime) currentKey).getMillis();
} else {
throw new AggregationExecutionException("Expected key of type Number or DateTime but got [" + currentKey + "]");
}
interval = Math.min(interval, thisKey - lastKey);
lastKey = thisKey;
}
}
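// All real buckets have been processed; if predictions were requested, append
// `predict` synthetic buckets (doc_count 0) spaced at the smallest observed
// interval, each holding a single forecasted value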
if (buckets.size() > 0 && predict > 0) {
boolean keyed = buckets.get(0).getKeyed();
ValueFormatter formatter = buckets.get(0).getFormatter();
double[] predictions = model.predict(values, predict);
for (int i = 0; i < predictions.length; i++) {
List<InternalAggregation> aggs = new ArrayList<>();
aggs.add(new InternalSimpleValue(name(), predictions[i], formatter, new ArrayList<Reducer>(), metaData()));
InternalHistogram.Bucket newBucket = factory.createBucket(lastKey + (interval * (i + 1)), 0, new InternalAggregations(
aggs), keyed, formatter);
newBuckets.add(newBucket);
}
}
return factory.create(newBuckets, histo);
}
@ -133,7 +171,9 @@ public class MovAvgReducer extends Reducer {
formatter = ValueFormatterStreams.readOptional(in);
gapPolicy = GapPolicy.readFrom(in);
window = in.readVInt();
predict = in.readVInt();
model = MovAvgModelStreams.read(in);
}
@Override
@ -141,7 +181,9 @@ public class MovAvgReducer extends Reducer {
ValueFormatterStreams.writeOptional(formatter, out);
gapPolicy.writeTo(out);
out.writeVInt(window);
out.writeVInt(predict);
model.writeTo(out);
}
public static class Factory extends ReducerFactory {
@ -150,19 +192,21 @@ public class MovAvgReducer extends Reducer {
private GapPolicy gapPolicy;
private int window;
private MovAvgModel model;
private int predict;
public Factory(String name, String[] bucketsPaths, @Nullable ValueFormatter formatter, GapPolicy gapPolicy,
int window, MovAvgModel model) {
int window, int predict, MovAvgModel model) {
super(name, TYPE.name(), bucketsPaths);
this.formatter = formatter;
this.gapPolicy = gapPolicy;
this.window = window;
this.model = model;
this.predict = predict;
}
@Override
protected Reducer createInternal(Map<String, Object> metaData) throws IOException {
return new MovAvgReducer(name, bucketsPaths, formatter, gapPolicy, window, model, metaData);
return new MovAvgReducer(name, bucketsPaths, formatter, gapPolicy, window, predict, model, metaData);
}
@Override

View File

@ -53,10 +53,25 @@ public class DoubleExpModel extends MovAvgModel {
this.beta = beta;
}
/**
* Predicts the next `n` values in the series, using the smoothing model to generate new values.
* Unlike the other moving averages, double-exp has forecasting/prediction built into the algorithm.
* Rather than simply adding each new prediction to the window and re-running the model, double-exp
* extrapolates into the future by applying the trend information to the smoothed data.
*
* @param values Collection of numerics to average, usually windowed
* @param numPredictions Number of newly generated predictions to return
* @param <T> Type of numeric
* @return Returns an array of doubles, since most smoothing methods operate on floating points
*/
@Override
public <T extends Number> double[] predict(Collection<T> values, int numPredictions) {
return next(values, numPredictions);
}
@Override
public <T extends Number> double next(Collection<T> values) {
return next(values, 1).get(0);
return next(values, 1)[0];
}
/**
@ -68,7 +83,12 @@ public class DoubleExpModel extends MovAvgModel {
* @param <T> Type T extending Number
* @return Returns an array containing the forecasted values
*/
public <T extends Number> List<Double> next(Collection<T> values, int numForecasts) {
public <T extends Number> double[] next(Collection<T> values, int numForecasts) {
if (values.size() == 0) {
return emptyPredictions(numForecasts);
}
// Smoothed value
double s = 0;
double last_s = 0;
@ -97,9 +117,9 @@ public class DoubleExpModel extends MovAvgModel {
last_b = b;
}
List<Double> forecastValues = new ArrayList<>(numForecasts);
double[] forecastValues = new double[numForecasts];
for (int i = 0; i < numForecasts; i++) {
forecastValues.add(s + (i * b));
forecastValues[i] = s + (i * b);
}
return forecastValues;
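
For reference, this is Holt's linear (double exponential) smoothing: with level s_t and trend b_t updated as

s_t = \alpha x_t + (1 - \alpha)(s_{t-1} + b_{t-1})
b_t = \beta (s_t - s_{t-1}) + (1 - \beta) b_{t-1}

the k-step-ahead forecast is \hat{x}_{t+k} = s_t + k b_t. Note that the loop above indexes forecasts from i = 0, so the first returned value is the current level s with no trend step applied.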

View File

@ -19,6 +19,8 @@
package org.elasticsearch.search.aggregations.reducers.movavg.models;
import com.google.common.collect.EvictingQueue;
import org.elasticsearch.ElasticsearchIllegalArgumentException;
import org.elasticsearch.common.io.stream.StreamOutput;
import java.io.IOException;
@ -29,12 +31,61 @@ public abstract class MovAvgModel {
/**
* Returns the next value in the series, according to the underlying smoothing model
*
* @param values Collection of numerics to smooth, usually windowed
* @param values Collection of numerics to average, usually windowed
* @param <T> Type of numeric
* @return Returns a double, since most smoothing methods operate on floating points
*/
public abstract <T extends Number> double next(Collection<T> values);
/**
* Predicts the next `n` values in the series, using the smoothing model to generate new values.
* The default prediction mode is to simply continue calling <code>next()</code> and adding the
* predicted value back into the windowed buffer.
*
* @param values Collection of numerics to average, usually windowed
* @param numPredictions Number of newly generated predictions to return
* @param <T> Type of numeric
* @return Returns an array of doubles, since most smoothing methods operate on floating points
*/
public <T extends Number> double[] predict(Collection<T> values, int numPredictions) {
double[] predictions = new double[numPredictions];
// If there are no values, we can't do anything. Return an array of NaNs.
if (values.size() == 0) {
return emptyPredictions(numPredictions);
}
if (numPredictions < 1) {
throw new ElasticsearchIllegalArgumentException("numPredictions may not be less than 1.");
} else if (numPredictions == 1) {
// Special case for a single prediction: avoids allocating the prediction buffer
predictions[0] = next(values);
return predictions;
}
// nocommit
// I don't like that it creates a new queue here
// The alternative to this is to just use `values` directly, but that would "consume" values
// and potentially change state elsewhere. Maybe ok?
Collection<Number> predictionBuffer = EvictingQueue.create(values.size());
predictionBuffer.addAll(values);
for (int i = 0; i < numPredictions; i++) {
predictions[i] = next(predictionBuffer);
// Add the predicted value back into the buffer, so we can keep predicting
predictionBuffer.add(predictions[i]);
}
return predictions;
}
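
To make the convergence behavior concrete, here is a sketch of this default loop driving SimpleModel (which averages the window); the arithmetic is worked by hand and the input values are made up:

import com.google.common.collect.EvictingQueue;
import org.elasticsearch.search.aggregations.reducers.movavg.models.SimpleModel;

// Feeding each prediction back into the buffer drives successive
// simple-model predictions toward the window's running mean.
EvictingQueue<Double> window = EvictingQueue.create(3);
window.offer(2.0);
window.offer(4.0);
window.offer(6.0);
double[] p = new SimpleModel().predict(window, 3);
// p[0] = (2 + 4 + 6) / 3        = 4.0
// p[1] = (4 + 6 + 4.0) / 3      ≈ 4.667  (2.0 evicted, 4.0 fed back in)
// p[2] = (6 + 4.0 + 4.667) / 3  ≈ 4.889  -> flattening toward a straight line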
protected double[] emptyPredictions(int numPredictions) {
double[] predictions = new double[numPredictions];
Arrays.fill(predictions, Double.NaN);
return predictions;
}
/**
* Write the model to the output stream
*

View File

@ -1,502 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.reducers;
import com.google.common.collect.EvictingQueue;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.search.aggregations.bucket.histogram.Histogram;
import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram;
import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram.Bucket;
import org.elasticsearch.search.aggregations.reducers.movavg.models.DoubleExpModel;
import org.elasticsearch.search.aggregations.reducers.movavg.models.LinearModel;
import org.elasticsearch.search.aggregations.reducers.movavg.models.SimpleModel;
import org.elasticsearch.search.aggregations.reducers.movavg.models.SingleExpModel;
import org.elasticsearch.test.ElasticsearchIntegrationTest;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.search.aggregations.AggregationBuilders.histogram;
import static org.elasticsearch.search.aggregations.AggregationBuilders.sum;
import static org.elasticsearch.search.aggregations.reducers.ReducerBuilders.smooth;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.core.IsNull.notNullValue;
@ElasticsearchIntegrationTest.SuiteScopeTest
public class MovAvgTests extends ElasticsearchIntegrationTest {
private static final String SINGLE_VALUED_FIELD_NAME = "l_value";
private static final String SINGLE_VALUED_VALUE_FIELD_NAME = "v_value";
static int interval;
static int numValueBuckets;
static int numFilledValueBuckets;
static int windowSize;
static BucketHelpers.GapPolicy gapPolicy;
static long[] docCounts;
static long[] valueCounts;
static Double[] simpleMovAvgCounts;
static Double[] linearMovAvgCounts;
static Double[] singleExpMovAvgCounts;
static Double[] doubleExpMovAvgCounts;
static Double[] simpleMovAvgValueCounts;
static Double[] linearMovAvgValueCounts;
static Double[] singleExpMovAvgValueCounts;
static Double[] doubleExpMovAvgValueCounts;
@Override
public void setupSuiteScopeCluster() throws Exception {
createIndex("idx");
createIndex("idx_unmapped");
interval = 5;
numValueBuckets = randomIntBetween(6, 80);
numFilledValueBuckets = numValueBuckets;
windowSize = randomIntBetween(3,10);
gapPolicy = BucketHelpers.GapPolicy.INSERT_ZEROS; // TODO randomBoolean() ? BucketHelpers.GapPolicy.IGNORE : BucketHelpers.GapPolicy.INSERT_ZEROS;
docCounts = new long[numValueBuckets];
valueCounts = new long[numValueBuckets];
for (int i = 0; i < numValueBuckets; i++) {
docCounts[i] = randomIntBetween(0, 20);
valueCounts[i] = randomIntBetween(1,20); //this will be used as a constant for all values within a bucket
}
this.setupSimple();
this.setupLinear();
this.setupSingle();
this.setupDouble();
List<IndexRequestBuilder> builders = new ArrayList<>();
for (int i = 0; i < numValueBuckets; i++) {
for (int docs = 0; docs < docCounts[i]; docs++) {
builders.add(client().prepareIndex("idx", "type").setSource(jsonBuilder().startObject()
.field(SINGLE_VALUED_FIELD_NAME, i * interval)
.field(SINGLE_VALUED_VALUE_FIELD_NAME, 1).endObject()));
}
}
indexRandom(true, builders);
ensureSearchable();
}
private void setupSimple() {
simpleMovAvgCounts = new Double[numValueBuckets];
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
for (int i = 0; i < numValueBuckets; i++) {
double thisValue = docCounts[i];
window.offer(thisValue);
double movAvg = 0;
for (double value : window) {
movAvg += value;
}
movAvg /= window.size();
simpleMovAvgCounts[i] = movAvg;
}
window.clear();
simpleMovAvgValueCounts = new Double[numValueBuckets];
for (int i = 0; i < numValueBuckets; i++) {
window.offer((double)docCounts[i]);
double movAvg = 0;
for (double value : window) {
movAvg += value;
}
movAvg /= window.size();
simpleMovAvgValueCounts[i] = movAvg;
}
}
private void setupLinear() {
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
linearMovAvgCounts = new Double[numValueBuckets];
window.clear();
for (int i = 0; i < numValueBuckets; i++) {
double thisValue = docCounts[i];
if (thisValue == -1) {
thisValue = 0;
}
window.offer(thisValue);
double avg = 0;
long totalWeight = 1;
long current = 1;
for (double value : window) {
avg += value * current;
totalWeight += current;
current += 1;
}
linearMovAvgCounts[i] = avg / totalWeight;
}
window.clear();
linearMovAvgValueCounts = new Double[numValueBuckets];
for (int i = 0; i < numValueBuckets; i++) {
double thisValue = docCounts[i];
window.offer(thisValue);
double avg = 0;
long totalWeight = 1;
long current = 1;
for (double value : window) {
avg += value * current;
totalWeight += current;
current += 1;
}
linearMovAvgValueCounts[i] = avg / totalWeight;
}
}
private void setupSingle() {
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
singleExpMovAvgCounts = new Double[numValueBuckets];
for (int i = 0; i < numValueBuckets; i++) {
double thisValue = docCounts[i];
if (thisValue == -1) {
thisValue = 0;
}
window.offer(thisValue);
double avg = 0;
double alpha = 0.5;
boolean first = true;
for (double value : window) {
if (first) {
avg = value;
first = false;
} else {
avg = (value * alpha) + (avg * (1 - alpha));
}
}
singleExpMovAvgCounts[i] = avg ;
}
singleExpMovAvgValueCounts = new Double[numValueBuckets];
window.clear();
for (int i = 0; i < numValueBuckets; i++) {
window.offer((double)docCounts[i]);
double avg = 0;
double alpha = 0.5;
boolean first = true;
for (double value : window) {
if (first) {
avg = value;
first = false;
} else {
avg = (value * alpha) + (avg * (1 - alpha));
}
}
singleExpMovAvgValueCounts[i] = avg;
}
}
private void setupDouble() {
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
doubleExpMovAvgCounts = new Double[numValueBuckets];
for (int i = 0; i < numValueBuckets; i++) {
double thisValue = docCounts[i];
if (thisValue == -1) {
thisValue = 0;
}
window.offer(thisValue);
double s = 0;
double last_s = 0;
// Trend value
double b = 0;
double last_b = 0;
double alpha = 0.5;
double beta = 0.5;
int counter = 0;
double last;
for (double value : window) {
last = value;
if (counter == 1) {
s = value;
b = value - last;
} else {
s = alpha * value + (1.0d - alpha) * (last_s + last_b);
b = beta * (s - last_s) + (1 - beta) * last_b;
}
counter += 1;
last_s = s;
last_b = b;
}
doubleExpMovAvgCounts[i] = s + (0 * b) ;
}
doubleExpMovAvgValueCounts = new Double[numValueBuckets];
window.clear();
for (int i = 0; i < numValueBuckets; i++) {
window.offer((double)docCounts[i]);
double s = 0;
double last_s = 0;
// Trend value
double b = 0;
double last_b = 0;
double alpha = 0.5;
double beta = 0.5;
int counter = 0;
double last;
for (double value : window) {
last = value;
if (counter == 1) {
s = value;
b = value - last;
} else {
s = alpha * value + (1.0d - alpha) * (last_s + last_b);
b = beta * (s - last_s) + (1 - beta) * last_b;
}
counter += 1;
last_s = s;
last_b = b;
}
doubleExpMovAvgValueCounts[i] = s + (0 * b) ;
}
}
/**
* test simple moving average on single value field
*/
@Test
public void simpleSingleValuedField() {
SearchResponse response = client()
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0)
.extendedBounds(0L, (long) (interval * (numValueBuckets - 1)))
.subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME))
.subAggregation(smooth("smooth")
.window(windowSize)
.modelBuilder(new SimpleModel.SimpleModelBuilder())
.gapPolicy(gapPolicy)
.setBucketsPaths("_count"))
.subAggregation(smooth("movavg_values")
.window(windowSize)
.modelBuilder(new SimpleModel.SimpleModelBuilder())
.gapPolicy(gapPolicy)
.setBucketsPaths("the_sum"))
).execute().actionGet();
assertSearchResponse(response);
InternalHistogram<Bucket> histo = response.getAggregations().get("histo");
assertThat(histo, notNullValue());
assertThat(histo.getName(), equalTo("histo"));
List<? extends Bucket> buckets = histo.getBuckets();
assertThat(buckets.size(), equalTo(numValueBuckets));
for (int i = 0; i < numValueBuckets; ++i) {
Histogram.Bucket bucket = buckets.get(i);
checkBucketKeyAndDocCount("Bucket " + i, bucket, i * interval, docCounts[i]);
SimpleValue docCountMovAvg = bucket.getAggregations().get("smooth");
assertThat(docCountMovAvg, notNullValue());
assertThat(docCountMovAvg.value(), equalTo(simpleMovAvgCounts[i]));
SimpleValue valuesMovAvg = bucket.getAggregations().get("movavg_values");
assertThat(valuesMovAvg, notNullValue());
assertThat(valuesMovAvg.value(), equalTo(simpleMovAvgCounts[i]));
}
}
/**
* test linear moving average on single value field
*/
@Test
public void linearSingleValuedField() {
SearchResponse response = client()
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0)
.extendedBounds(0L, (long) (interval * (numValueBuckets - 1)))
.subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME))
.subAggregation(smooth("smooth")
.window(windowSize)
.modelBuilder(new LinearModel.LinearModelBuilder())
.gapPolicy(gapPolicy)
.setBucketsPaths("_count"))
.subAggregation(smooth("movavg_values")
.window(windowSize)
.modelBuilder(new LinearModel.LinearModelBuilder())
.gapPolicy(gapPolicy)
.setBucketsPaths("the_sum"))
).execute().actionGet();
assertSearchResponse(response);
InternalHistogram<Bucket> histo = response.getAggregations().get("histo");
assertThat(histo, notNullValue());
assertThat(histo.getName(), equalTo("histo"));
List<? extends Bucket> buckets = histo.getBuckets();
assertThat(buckets.size(), equalTo(numValueBuckets));
for (int i = 0; i < numValueBuckets; ++i) {
Histogram.Bucket bucket = buckets.get(i);
checkBucketKeyAndDocCount("Bucket " + i, bucket, i * interval, docCounts[i]);
SimpleValue docCountMovAvg = bucket.getAggregations().get("smooth");
assertThat(docCountMovAvg, notNullValue());
assertThat(docCountMovAvg.value(), equalTo(linearMovAvgCounts[i]));
SimpleValue valuesMovAvg = bucket.getAggregations().get("movavg_values");
assertThat(valuesMovAvg, notNullValue());
assertThat(valuesMovAvg.value(), equalTo(linearMovAvgCounts[i]));
}
}
/**
* test single exponential moving average on single value field
*/
@Test
public void singleExpSingleValuedField() {
SearchResponse response = client()
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0)
.extendedBounds(0L, (long) (interval * (numValueBuckets - 1)))
.subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME))
.subAggregation(smooth("smooth")
.window(windowSize)
.modelBuilder(new SingleExpModel.SingleExpModelBuilder().alpha(0.5))
.gapPolicy(gapPolicy)
.setBucketsPaths("_count"))
.subAggregation(smooth("movavg_values")
.window(windowSize)
.modelBuilder(new SingleExpModel.SingleExpModelBuilder().alpha(0.5))
.gapPolicy(gapPolicy)
.setBucketsPaths("the_sum"))
).execute().actionGet();
assertSearchResponse(response);
InternalHistogram<Bucket> histo = response.getAggregations().get("histo");
assertThat(histo, notNullValue());
assertThat(histo.getName(), equalTo("histo"));
List<? extends Bucket> buckets = histo.getBuckets();
assertThat(buckets.size(), equalTo(numValueBuckets));
for (int i = 0; i < numValueBuckets; ++i) {
Histogram.Bucket bucket = buckets.get(i);
checkBucketKeyAndDocCount("Bucket " + i, bucket, i * interval, docCounts[i]);
SimpleValue docCountMovAvg = bucket.getAggregations().get("smooth");
assertThat(docCountMovAvg, notNullValue());
assertThat(docCountMovAvg.value(), equalTo(singleExpMovAvgCounts[i]));
SimpleValue valuesMovAvg = bucket.getAggregations().get("movavg_values");
assertThat(valuesMovAvg, notNullValue());
assertThat(valuesMovAvg.value(), equalTo(singleExpMovAvgCounts[i]));
}
}
/**
* test double exponential moving average on single value field
*/
@Test
public void doubleExpSingleValuedField() {
SearchResponse response = client()
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0)
.extendedBounds(0L, (long) (interval * (numValueBuckets - 1)))
.subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME))
.subAggregation(smooth("smooth")
.window(windowSize)
.modelBuilder(new DoubleExpModel.DoubleExpModelBuilder().alpha(0.5).beta(0.5))
.gapPolicy(gapPolicy)
.setBucketsPaths("_count"))
.subAggregation(smooth("movavg_values")
.window(windowSize)
.modelBuilder(new DoubleExpModel.DoubleExpModelBuilder().alpha(0.5).beta(0.5))
.gapPolicy(gapPolicy)
.setBucketsPaths("the_sum"))
).execute().actionGet();
assertSearchResponse(response);
InternalHistogram<Bucket> histo = response.getAggregations().get("histo");
assertThat(histo, notNullValue());
assertThat(histo.getName(), equalTo("histo"));
List<? extends Bucket> buckets = histo.getBuckets();
assertThat(buckets.size(), equalTo(numValueBuckets));
for (int i = 0; i < numValueBuckets; ++i) {
Histogram.Bucket bucket = buckets.get(i);
checkBucketKeyAndDocCount("Bucket " + i, bucket, i * interval, docCounts[i]);
SimpleValue docCountMovAvg = bucket.getAggregations().get("smooth");
assertThat(docCountMovAvg, notNullValue());
assertThat(docCountMovAvg.value(), equalTo(doubleExpMovAvgCounts[i]));
SimpleValue valuesMovAvg = bucket.getAggregations().get("movavg_values");
assertThat(valuesMovAvg, notNullValue());
assertThat(valuesMovAvg.value(), equalTo(doubleExpMovAvgCounts[i]));
}
}
private void checkBucketKeyAndDocCount(final String msg, final Histogram.Bucket bucket, final long expectedKey,
long expectedDocCount) {
if (expectedDocCount == -1) {
expectedDocCount = 0;
}
assertThat(msg, bucket, notNullValue());
assertThat(msg + " key", ((Number) bucket.getKey()).longValue(), equalTo(expectedKey));
assertThat(msg + " docCount", bucket.getDocCount(), equalTo(expectedDocCount));
}
}

View File

@ -0,0 +1,297 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.reducers.moving.avg;
import com.google.common.collect.EvictingQueue;
import org.elasticsearch.search.aggregations.reducers.movavg.models.*;
import org.elasticsearch.test.ElasticsearchTestCase;
import static org.hamcrest.Matchers.equalTo;
import org.junit.Test;
public class MovAvgUnitTests extends ElasticsearchTestCase {
@Test
public void testSimpleMovAvgModel() {
MovAvgModel model = new SimpleModel();
int numValues = randomIntBetween(1, 100);
int windowSize = randomIntBetween(1, 50);
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
for (int i = 0; i < numValues; i++) {
double randValue = randomDouble();
double expected = 0;
window.offer(randValue);
for (double value : window) {
expected += value;
}
expected /= window.size();
double actual = model.next(window);
assertThat(Double.compare(expected, actual), equalTo(0));
}
}
@Test
public void testSimplePredictionModel() {
MovAvgModel model = new SimpleModel();
int windowSize = randomIntBetween(1, 50);
int numPredictions = randomIntBetween(1,50);
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
for (int i = 0; i < windowSize; i++) {
window.offer(randomDouble());
}
double[] actual = model.predict(window, numPredictions);
double[] expected = new double[numPredictions];
for (int i = 0; i < numPredictions; i++) {
for (double value : window) {
expected[i] += value;
}
expected[i] /= window.size();
window.offer(expected[i]);
}
for (int i = 0; i < numPredictions; i++) {
assertThat(Double.compare(expected[i], actual[i]), equalTo(0));
}
}
@Test
public void testLinearMovAvgModel() {
MovAvgModel model = new LinearModel();
int numValues = randomIntBetween(1, 100);
int windowSize = randomIntBetween(1, 50);
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
for (int i = 0; i < numValues; i++) {
double randValue = randomDouble();
window.offer(randValue);
double avg = 0;
long totalWeight = 1;
long current = 1;
for (double value : window) {
avg += value * current;
totalWeight += current;
current += 1;
}
double expected = avg / totalWeight;
double actual = model.next(window);
assertThat(Double.compare(expected, actual), equalTo(0));
}
}
@Test
public void testLinearPredictionModel() {
MovAvgModel model = new LinearModel();
int windowSize = randomIntBetween(1, 50);
int numPredictions = randomIntBetween(1,50);
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
for (int i = 0; i < windowSize; i++) {
window.offer(randomDouble());
}
double[] actual = model.predict(window, numPredictions);
double[] expected = new double[numPredictions];
for (int i = 0; i < numPredictions; i++) {
double avg = 0;
long totalWeight = 1;
long current = 1;
for (double value : window) {
avg += value * current;
totalWeight += current;
current += 1;
}
expected[i] = avg / totalWeight;
window.offer(expected[i]);
}
for (int i = 0; i < numPredictions; i++) {
assertThat(Double.compare(expected[i], actual[i]), equalTo(0));
}
}
@Test
public void testSingleExpMovAvgModel() {
double alpha = randomDouble();
MovAvgModel model = new SingleExpModel(alpha);
int numValues = randomIntBetween(1, 100);
int windowSize = randomIntBetween(1, 50);
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
for (int i = 0; i < numValues; i++) {
double randValue = randomDouble();
window.offer(randValue);
double avg = 0;
boolean first = true;
for (double value : window) {
if (first) {
avg = value;
first = false;
} else {
avg = (value * alpha) + (avg * (1 - alpha));
}
}
double expected = avg;
double actual = model.next(window);
assertThat(Double.compare(expected, actual), equalTo(0));
}
}
@Test
public void testSinglePredictionModel() {
double alpha = randomDouble();
MovAvgModel model = new SingleExpModel(alpha);
int windowSize = randomIntBetween(1, 50);
int numPredictions = randomIntBetween(1,50);
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
for (int i = 0; i < windowSize; i++) {
window.offer(randomDouble());
}
double[] actual = model.predict(window, numPredictions);
double[] expected = new double[numPredictions];
for (int i = 0; i < numPredictions; i++) {
double avg = 0;
boolean first = true;
for (double value : window) {
if (first) {
avg = value;
first = false;
} else {
avg = (value * alpha) + (avg * (1 - alpha));
}
}
expected[i] = avg;
window.offer(expected[i]);
}
for (int i = 0; i < numPredictions; i++) {
assertThat(Double.compare(expected[i], actual[i]), equalTo(0));
}
}
@Test
public void testDoubleExpMovAvgModel() {
double alpha = randomDouble();
double beta = randomDouble();
MovAvgModel model = new DoubleExpModel(alpha, beta);
int numValues = randomIntBetween(1, 100);
int windowSize = randomIntBetween(1, 50);
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
for (int i = 0; i < numValues; i++) {
double randValue = randomDouble();
window.offer(randValue);
double s = 0;
double last_s = 0;
// Trend value
double b = 0;
double last_b = 0;
int counter = 0;
double last;
for (double value : window) {
last = value;
if (counter == 1) {
s = value;
b = value - last;
} else {
s = alpha * value + (1.0d - alpha) * (last_s + last_b);
b = beta * (s - last_s) + (1 - beta) * last_b;
}
counter += 1;
last_s = s;
last_b = b;
}
double expected = s + (0 * b);
double actual = model.next(window);
assertThat(Double.compare(expected, actual), equalTo(0));
}
}
@Test
public void testDoublePredictionModel() {
double alpha = randomDouble();
double beta = randomDouble();
MovAvgModel model = new DoubleExpModel(alpha, beta);
int windowSize = randomIntBetween(1, 50);
int numPredictions = randomIntBetween(1,50);
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
for (int i = 0; i < windowSize; i++) {
window.offer(randomDouble());
}
double[] actual = model.predict(window, numPredictions);
double[] expected = new double[numPredictions];
double s = 0;
double last_s = 0;
// Trend value
double b = 0;
double last_b = 0;
int counter = 0;
double last;
for (double value : window) {
last = value;
if (counter == 1) {
s = value;
b = value - last;
} else {
s = alpha * value + (1.0d - alpha) * (last_s + last_b);
b = beta * (s - last_s) + (1 - beta) * last_b;
}
counter += 1;
last_s = s;
last_b = b;
}
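// The forecast index starts at zero, so the first prediction equals the
// current smoothed level s with no trend applied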
for (int i = 0; i < numPredictions; i++) {
expected[i] = s + (i * b);
assertThat(Double.compare(expected[i], actual[i]), equalTo(0));
}
}
}