diff --git a/src/main/java/org/elasticsearch/search/aggregations/reducers/ReducerBuilders.java b/src/main/java/org/elasticsearch/search/aggregations/reducers/ReducerBuilders.java index 3f45964153b..ba6d3ebe7c2 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/reducers/ReducerBuilders.java +++ b/src/main/java/org/elasticsearch/search/aggregations/reducers/ReducerBuilders.java @@ -36,7 +36,7 @@ public final class ReducerBuilders { return new MaxBucketBuilder(name); } - public static final MovAvgBuilder smooth(String name) { + public static final MovAvgBuilder movingAvg(String name) { return new MovAvgBuilder(name); } } diff --git a/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/MovAvgBuilder.java b/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/MovAvgBuilder.java index 9790604197d..5fba23957e9 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/MovAvgBuilder.java +++ b/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/MovAvgBuilder.java @@ -36,6 +36,7 @@ public class MovAvgBuilder extends ReducerBuilder { private GapPolicy gapPolicy; private MovAvgModelBuilder modelBuilder; private Integer window; + private Integer predict; public MovAvgBuilder(String name) { super(name, MovAvgReducer.TYPE.name()); @@ -81,6 +82,19 @@ public class MovAvgBuilder extends ReducerBuilder { return this; } + /** + * Sets the number of predictions that should be returned. Each prediction will be spaced at + * the interval specified by the histogram. E.g. "predict: 2" will return two new buckets at the + * end of the histogram with the predicted values. + * + * @param numPredictions Number of predictions to make + * @return Returns the builder to continue chaining + */ + public MovAvgBuilder predict(int numPredictions) { + this.predict = numPredictions; + return this; + } + @Override protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException { @@ -96,6 +110,9 @@ public class MovAvgBuilder extends ReducerBuilder { if (window != null) { builder.field(MovAvgParser.WINDOW.getPreferredName(), window); } + if (predict != null) { + builder.field(MovAvgParser.PREDICT.getPreferredName(), predict); + } return builder; } diff --git a/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/MovAvgParser.java b/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/MovAvgParser.java index 3f241a67b3a..c1cdadf91ea 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/MovAvgParser.java +++ b/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/MovAvgParser.java @@ -46,6 +46,7 @@ public class MovAvgParser implements Reducer.Parser { public static final ParseField MODEL = new ParseField("model"); public static final ParseField WINDOW = new ParseField("window"); public static final ParseField SETTINGS = new ParseField("settings"); + public static final ParseField PREDICT = new ParseField("predict"); private final MovAvgModelParserMapper movAvgModelParserMapper; @@ -65,10 +66,12 @@ public class MovAvgParser implements Reducer.Parser { String currentFieldName = null; String[] bucketsPaths = null; String format = null; + GapPolicy gapPolicy = GapPolicy.IGNORE; int window = 5; Map settings = null; String model = "simple"; + int predict = 0; while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -76,6 +79,16 @@ public class MovAvgParser implements Reducer.Parser
{ } else if (token == XContentParser.Token.VALUE_NUMBER) { if (WINDOW.match(currentFieldName)) { window = parser.intValue(); + if (window <= 0) { + throw new SearchParseException(context, "[" + currentFieldName + "] value must be a positive, " + + "non-zero integer. Value supplied was [" + window + "] in [" + reducerName + "]."); + } + } else if (PREDICT.match(currentFieldName)) { + predict = parser.intValue(); + if (predict <= 0) { + throw new SearchParseException(context, "[" + currentFieldName + "] value must be a positive, " + + "non-zero integer. Value supplied was [" + predict + "] in [" + reducerName + "]."); + } } else { throw new SearchParseException(context, "Unknown key for a " + token + " in [" + reducerName + "]: [" + currentFieldName + "]."); @@ -119,7 +132,7 @@ public class MovAvgParser implements Reducer.Parser { if (bucketsPaths == null) { throw new SearchParseException(context, "Missing required field [" + BUCKETS_PATH.getPreferredName() - + "] for smooth aggregation [" + reducerName + "]"); + + "] for movingAvg aggregation [" + reducerName + "]"); } ValueFormatter formatter = null; @@ -135,7 +148,7 @@ public class MovAvgParser implements Reducer.Parser { MovAvgModel movAvgModel = modelParser.parse(settings); - return new MovAvgReducer.Factory(reducerName, bucketsPaths, formatter, gapPolicy, window, movAvgModel); + return new MovAvgReducer.Factory(reducerName, bucketsPaths, formatter, gapPolicy, window, predict, movAvgModel); } diff --git a/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/MovAvgReducer.java b/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/MovAvgReducer.java index 20baa1706f1..4bd2ff4c50a 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/MovAvgReducer.java +++ b/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/MovAvgReducer.java @@ -27,12 +27,9 @@ import org.elasticsearch.ElasticsearchIllegalStateException; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.search.aggregations.Aggregation; -import org.elasticsearch.search.aggregations.AggregatorFactory; -import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.*; import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContext; import org.elasticsearch.search.aggregations.InternalAggregation.Type; -import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.aggregations.bucket.histogram.HistogramAggregator; import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram; import org.elasticsearch.search.aggregations.reducers.BucketHelpers.GapPolicy; @@ -44,6 +41,7 @@ import org.elasticsearch.search.aggregations.reducers.movavg.models.MovAvgModel; import org.elasticsearch.search.aggregations.reducers.movavg.models.MovAvgModelStreams; import org.elasticsearch.search.aggregations.support.format.ValueFormatter; import org.elasticsearch.search.aggregations.support.format.ValueFormatterStreams; +import org.joda.time.DateTime; import java.io.IOException; import java.util.ArrayList; @@ -80,17 +78,19 @@ public class MovAvgReducer extends Reducer { private GapPolicy gapPolicy; private int window; private MovAvgModel model; + private int predict; public MovAvgReducer() { } public MovAvgReducer(String name, String[] bucketsPaths, @Nullable ValueFormatter formatter, GapPolicy
gapPolicy, - int window, MovAvgModel model, Map metadata) { + int window, int predict, MovAvgModel model, Map metadata) { super(name, bucketsPaths, metadata); this.formatter = formatter; this.gapPolicy = gapPolicy; this.window = window; this.model = model; + this.predict = predict; } @Override @@ -107,8 +107,14 @@ public class MovAvgReducer extends Reducer { List newBuckets = new ArrayList<>(); EvictingQueue values = EvictingQueue.create(this.window); + long lastKey = 0; + long interval = Long.MAX_VALUE; + Object currentKey; + for (InternalHistogram.Bucket bucket : buckets) { Double thisBucketValue = resolveBucketValue(histo, bucket, bucketsPaths()[0], gapPolicy); + currentKey = bucket.getKey(); + if (thisBucketValue != null) { values.offer(thisBucketValue); @@ -117,14 +123,46 @@ public class MovAvgReducer extends Reducer { List aggs = new ArrayList<>(Lists.transform(bucket.getAggregations().asList(), FUNCTION)); aggs.add(new InternalSimpleValue(name(), movavg, formatter, new ArrayList(), metaData())); - InternalHistogram.Bucket newBucket = factory.createBucket(bucket.getKey(), bucket.getDocCount(), new InternalAggregations( + InternalHistogram.Bucket newBucket = factory.createBucket(currentKey, bucket.getDocCount(), new InternalAggregations( aggs), bucket.getKeyed(), bucket.getFormatter()); newBuckets.add(newBucket); + } else { newBuckets.add(bucket); } + + if (predict > 0) { + if (currentKey instanceof Number) { + interval = Math.min(interval, ((Number) bucket.getKey()).longValue() - lastKey); + lastKey = ((Number) bucket.getKey()).longValue(); + } else if (currentKey instanceof DateTime) { + interval = Math.min(interval, ((DateTime) bucket.getKey()).getMillis() - lastKey); + lastKey = ((DateTime) bucket.getKey()).getMillis(); + } else { + throw new AggregationExecutionException("Expected key of type Number or DateTime but got [" + currentKey + "]"); + } + } + } - //return factory.create(histo.getName(), newBuckets, histo); + + + if (buckets.size() > 0 && predict > 0) { + + boolean keyed; + ValueFormatter formatter; + keyed = buckets.get(0).getKeyed(); + formatter = buckets.get(0).getFormatter(); + + double[] predictions = model.predict(values, predict); + for (int i = 0; i < predictions.length; i++) { + List aggs = new ArrayList<>(); + aggs.add(new InternalSimpleValue(name(), predictions[i], formatter, new ArrayList(), metaData())); + InternalHistogram.Bucket newBucket = factory.createBucket(lastKey + (interval * (i + 1)), 0, new InternalAggregations( + aggs), keyed, formatter); + newBuckets.add(newBucket); + } + } + return factory.create(newBuckets, histo); } @@ -133,7 +171,9 @@ public class MovAvgReducer extends Reducer { formatter = ValueFormatterStreams.readOptional(in); gapPolicy = GapPolicy.readFrom(in); window = in.readVInt(); + predict = in.readVInt(); model = MovAvgModelStreams.read(in); + } @Override @@ -141,7 +181,9 @@ public class MovAvgReducer extends Reducer { ValueFormatterStreams.writeOptional(formatter, out); gapPolicy.writeTo(out); out.writeVInt(window); + out.writeVInt(predict); model.writeTo(out); + } public static class Factory extends ReducerFactory { @@ -150,19 +192,21 @@ public class MovAvgReducer extends Reducer { private GapPolicy gapPolicy; private int window; private MovAvgModel model; + private int predict; public Factory(String name, String[] bucketsPaths, @Nullable ValueFormatter formatter, GapPolicy gapPolicy, - int window, MovAvgModel model) { + int window, int predict, MovAvgModel model) { super(name, TYPE.name(), bucketsPaths); this.formatter = 
formatter; this.gapPolicy = gapPolicy; this.window = window; this.model = model; + this.predict = predict; } @Override protected Reducer createInternal(Map metaData) throws IOException { - return new MovAvgReducer(name, bucketsPaths, formatter, gapPolicy, window, model, metaData); + return new MovAvgReducer(name, bucketsPaths, formatter, gapPolicy, window, predict, model, metaData); } @Override diff --git a/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/models/DoubleExpModel.java b/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/models/DoubleExpModel.java index 907c23fd213..7d32989cda1 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/models/DoubleExpModel.java +++ b/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/models/DoubleExpModel.java @@ -53,10 +53,25 @@ public class DoubleExpModel extends MovAvgModel { this.beta = beta; } + /** + * Predicts the next `n` values in the series, using the smoothing model to generate new values. + * Unlike the other moving averages, double-exp has forecasting/prediction built into the algorithm. + * Prediction is more than simply feeding the next prediction back into the window and repeating: double-exp + * extrapolates into the future by applying the trend information to the smoothed data. + * + * @param values Collection of numerics to average over, usually windowed + * @param numPredictions Number of newly generated predictions to return + * @param Type of numeric + * @return Returns an array of doubles, since most smoothing methods operate on floating points + */ + @Override + public double[] predict(Collection values, int numPredictions) { + return next(values, numPredictions); + } @Override public double next(Collection values) { - return next(values, 1).get(0); + return next(values, 1)[0]; } /** @@ -68,7 +83,12 @@ public class DoubleExpModel extends MovAvgModel { * @param Type T extending Number - * @return Returns a Double containing the moving avg for the window + * @return Returns an array of doubles containing the forecasted values */ - public List next(Collection values, int numForecasts) { + public double[] next(Collection values, int numForecasts) { + + if (values.size() == 0) { + return emptyPredictions(numForecasts); + } + // Smoothed value double s = 0; double last_s = 0; @@ -97,9 +117,9 @@ public class DoubleExpModel extends MovAvgModel { last_b = b; } - List forecastValues = new ArrayList<>(numForecasts); + double[] forecastValues = new double[numForecasts]; for (int i = 0; i < numForecasts; i++) { - forecastValues.add(s + (i * b)); + forecastValues[i] = s + (i * b); } return forecastValues; diff --git a/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/models/MovAvgModel.java b/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/models/MovAvgModel.java index 84f7832f893..d798887c836 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/models/MovAvgModel.java +++ b/src/main/java/org/elasticsearch/search/aggregations/reducers/movavg/models/MovAvgModel.java @@ -19,6 +19,9 @@ package org.elasticsearch.search.aggregations.reducers.movavg.models; +import com.google.common.collect.EvictingQueue; +import org.elasticsearch.ElasticsearchIllegalArgumentException; import org.elasticsearch.common.io.stream.StreamOutput; import java.io.IOException; +import java.util.Arrays; @@ -29,12 +32,61 @@ public abstract class MovAvgModel { /** * Returns the next value in the series, according to the underlying smoothing model * - * @param values
Collection of numerics to smooth, usually windowed + * @param values Collection of numerics to average over, usually windowed * @param Type of numeric * @return Returns a double, since most smoothing methods operate on floating points */ public abstract double next(Collection values); + /** + * Predicts the next `n` values in the series, using the smoothing model to generate new values. + * The default prediction mode is to simply continue calling next() and adding the + * predicted value back into the windowed buffer. + * + * @param values Collection of numerics to average over, usually windowed + * @param numPredictions Number of newly generated predictions to return + * @param Type of numeric + * @return Returns an array of doubles, since most smoothing methods operate on floating points + */ + public double[] predict(Collection values, int numPredictions) { + if (numPredictions < 1) { + throw new ElasticsearchIllegalArgumentException("numPredictions may not be less than 1."); + } + + double[] predictions = new double[numPredictions]; + + // If there are no values, we can't do anything. Return an array of NaNs. + if (values.size() == 0) { + return emptyPredictions(numPredictions); + } + + // Special case a single prediction, which avoids copying the window into a new buffer + if (numPredictions == 1) { + predictions[0] = next(values); + return predictions; + } + + // nocommit + // I don't like that it creates a new queue here + // The alternative to this is to just use `values` directly, but that would "consume" values + // and potentially change state elsewhere. Maybe ok? + Collection predictionBuffer = EvictingQueue.create(values.size()); + predictionBuffer.addAll(values); + + for (int i = 0; i < numPredictions; i++) { + predictions[i] = next(predictionBuffer); + + // Add the predicted value to the buffer, so we can keep predicting + predictionBuffer.add(predictions[i]); + } + + return predictions; + } + + protected double[] emptyPredictions(int numPredictions) { + double[] predictions = new double[numPredictions]; + Arrays.fill(predictions, Double.NaN); + return predictions; + } + /** * Write the model to the output stream * diff --git a/src/test/java/org/elasticsearch/search/aggregations/reducers/MovAvgTests.java b/src/test/java/org/elasticsearch/search/aggregations/reducers/MovAvgTests.java deleted file mode 100644 index 4f0e3c0d1cf..00000000000 --- a/src/test/java/org/elasticsearch/search/aggregations/reducers/MovAvgTests.java +++ /dev/null @@ -1,502 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License.
- */ - -package org.elasticsearch.search.aggregations.reducers; - - -import com.google.common.collect.EvictingQueue; - -import org.elasticsearch.action.index.IndexRequestBuilder; -import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.search.aggregations.bucket.histogram.Histogram; -import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram; -import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram.Bucket; -import org.elasticsearch.search.aggregations.reducers.movavg.models.DoubleExpModel; -import org.elasticsearch.search.aggregations.reducers.movavg.models.LinearModel; -import org.elasticsearch.search.aggregations.reducers.movavg.models.SimpleModel; -import org.elasticsearch.search.aggregations.reducers.movavg.models.SingleExpModel; -import org.elasticsearch.test.ElasticsearchIntegrationTest; -import org.junit.Test; - -import java.util.ArrayList; -import java.util.List; - -import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; -import static org.elasticsearch.search.aggregations.AggregationBuilders.histogram; -import static org.elasticsearch.search.aggregations.AggregationBuilders.sum; -import static org.elasticsearch.search.aggregations.reducers.ReducerBuilders.smooth; -import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.core.IsNull.notNullValue; - -@ElasticsearchIntegrationTest.SuiteScopeTest -public class MovAvgTests extends ElasticsearchIntegrationTest { - - private static final String SINGLE_VALUED_FIELD_NAME = "l_value"; - private static final String SINGLE_VALUED_VALUE_FIELD_NAME = "v_value"; - - static int interval; - static int numValueBuckets; - static int numFilledValueBuckets; - static int windowSize; - static BucketHelpers.GapPolicy gapPolicy; - - static long[] docCounts; - static long[] valueCounts; - static Double[] simpleMovAvgCounts; - static Double[] linearMovAvgCounts; - static Double[] singleExpMovAvgCounts; - static Double[] doubleExpMovAvgCounts; - - static Double[] simpleMovAvgValueCounts; - static Double[] linearMovAvgValueCounts; - static Double[] singleExpMovAvgValueCounts; - static Double[] doubleExpMovAvgValueCounts; - - @Override - public void setupSuiteScopeCluster() throws Exception { - createIndex("idx"); - createIndex("idx_unmapped"); - - interval = 5; - numValueBuckets = randomIntBetween(6, 80); - numFilledValueBuckets = numValueBuckets; - windowSize = randomIntBetween(3,10); - gapPolicy = BucketHelpers.GapPolicy.INSERT_ZEROS; // TODO randomBoolean() ? 
BucketHelpers.GapPolicy.IGNORE : BucketHelpers.GapPolicy.INSERT_ZEROS; - - docCounts = new long[numValueBuckets]; - valueCounts = new long[numValueBuckets]; - for (int i = 0; i < numValueBuckets; i++) { - docCounts[i] = randomIntBetween(0, 20); - valueCounts[i] = randomIntBetween(1,20); //this will be used as a constant for all values within a bucket - } - - this.setupSimple(); - this.setupLinear(); - this.setupSingle(); - this.setupDouble(); - - - List builders = new ArrayList<>(); - for (int i = 0; i < numValueBuckets; i++) { - for (int docs = 0; docs < docCounts[i]; docs++) { - builders.add(client().prepareIndex("idx", "type").setSource(jsonBuilder().startObject() - .field(SINGLE_VALUED_FIELD_NAME, i * interval) - .field(SINGLE_VALUED_VALUE_FIELD_NAME, 1).endObject())); - } - } - - indexRandom(true, builders); - ensureSearchable(); - } - - private void setupSimple() { - simpleMovAvgCounts = new Double[numValueBuckets]; - EvictingQueue window = EvictingQueue.create(windowSize); - for (int i = 0; i < numValueBuckets; i++) { - double thisValue = docCounts[i]; - window.offer(thisValue); - - double movAvg = 0; - for (double value : window) { - movAvg += value; - } - movAvg /= window.size(); - - simpleMovAvgCounts[i] = movAvg; - } - - window.clear(); - simpleMovAvgValueCounts = new Double[numValueBuckets]; - for (int i = 0; i < numValueBuckets; i++) { - window.offer((double)docCounts[i]); - - double movAvg = 0; - for (double value : window) { - movAvg += value; - } - movAvg /= window.size(); - - simpleMovAvgValueCounts[i] = movAvg; - - } - - } - - private void setupLinear() { - EvictingQueue window = EvictingQueue.create(windowSize); - linearMovAvgCounts = new Double[numValueBuckets]; - window.clear(); - for (int i = 0; i < numValueBuckets; i++) { - double thisValue = docCounts[i]; - if (thisValue == -1) { - thisValue = 0; - } - window.offer(thisValue); - - double avg = 0; - long totalWeight = 1; - long current = 1; - - for (double value : window) { - avg += value * current; - totalWeight += current; - current += 1; - } - linearMovAvgCounts[i] = avg / totalWeight; - } - - window.clear(); - linearMovAvgValueCounts = new Double[numValueBuckets]; - - for (int i = 0; i < numValueBuckets; i++) { - double thisValue = docCounts[i]; - window.offer(thisValue); - - double avg = 0; - long totalWeight = 1; - long current = 1; - - for (double value : window) { - avg += value * current; - totalWeight += current; - current += 1; - } - linearMovAvgValueCounts[i] = avg / totalWeight; - } - } - - private void setupSingle() { - EvictingQueue window = EvictingQueue.create(windowSize); - singleExpMovAvgCounts = new Double[numValueBuckets]; - for (int i = 0; i < numValueBuckets; i++) { - double thisValue = docCounts[i]; - if (thisValue == -1) { - thisValue = 0; - } - window.offer(thisValue); - - double avg = 0; - double alpha = 0.5; - boolean first = true; - - for (double value : window) { - if (first) { - avg = value; - first = false; - } else { - avg = (value * alpha) + (avg * (1 - alpha)); - } - } - singleExpMovAvgCounts[i] = avg ; - } - - singleExpMovAvgValueCounts = new Double[numValueBuckets]; - window.clear(); - - for (int i = 0; i < numValueBuckets; i++) { - window.offer((double)docCounts[i]); - - double avg = 0; - double alpha = 0.5; - boolean first = true; - - for (double value : window) { - if (first) { - avg = value; - first = false; - } else { - avg = (value * alpha) + (avg * (1 - alpha)); - } - } - singleExpMovAvgCounts[i] = avg ; - } - - } - - private void setupDouble() { - EvictingQueue window = 
EvictingQueue.create(windowSize); - doubleExpMovAvgCounts = new Double[numValueBuckets]; - - for (int i = 0; i < numValueBuckets; i++) { - double thisValue = docCounts[i]; - if (thisValue == -1) { - thisValue = 0; - } - window.offer(thisValue); - - double s = 0; - double last_s = 0; - - // Trend value - double b = 0; - double last_b = 0; - - double alpha = 0.5; - double beta = 0.5; - int counter = 0; - - double last; - for (double value : window) { - last = value; - if (counter == 1) { - s = value; - b = value - last; - } else { - s = alpha * value + (1.0d - alpha) * (last_s + last_b); - b = beta * (s - last_s) + (1 - beta) * last_b; - } - - counter += 1; - last_s = s; - last_b = b; - } - - doubleExpMovAvgCounts[i] = s + (0 * b) ; - } - - doubleExpMovAvgValueCounts = new Double[numValueBuckets]; - window.clear(); - - for (int i = 0; i < numValueBuckets; i++) { - window.offer((double)docCounts[i]); - - double s = 0; - double last_s = 0; - - // Trend value - double b = 0; - double last_b = 0; - - double alpha = 0.5; - double beta = 0.5; - int counter = 0; - - double last; - for (double value : window) { - last = value; - if (counter == 1) { - s = value; - b = value - last; - } else { - s = alpha * value + (1.0d - alpha) * (last_s + last_b); - b = beta * (s - last_s) + (1 - beta) * last_b; - } - - counter += 1; - last_s = s; - last_b = b; - } - - doubleExpMovAvgValueCounts[i] = s + (0 * b) ; - } - } - - /** - * test simple moving average on single value field - */ - @Test - public void simpleSingleValuedField() { - - SearchResponse response = client() - .prepareSearch("idx") - .addAggregation( - histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0) - .extendedBounds(0L, (long) (interval * (numValueBuckets - 1))) - .subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME)) - .subAggregation(smooth("smooth") - .window(windowSize) - .modelBuilder(new SimpleModel.SimpleModelBuilder()) - .gapPolicy(gapPolicy) - .setBucketsPaths("_count")) - .subAggregation(smooth("movavg_values") - .window(windowSize) - .modelBuilder(new SimpleModel.SimpleModelBuilder()) - .gapPolicy(gapPolicy) - .setBucketsPaths("the_sum")) - ).execute().actionGet(); - - assertSearchResponse(response); - - InternalHistogram histo = response.getAggregations().get("histo"); - assertThat(histo, notNullValue()); - assertThat(histo.getName(), equalTo("histo")); - List buckets = histo.getBuckets(); - assertThat(buckets.size(), equalTo(numValueBuckets)); - - for (int i = 0; i < numValueBuckets; ++i) { - Histogram.Bucket bucket = buckets.get(i); - checkBucketKeyAndDocCount("Bucket " + i, bucket, i * interval, docCounts[i]); - SimpleValue docCountMovAvg = bucket.getAggregations().get("smooth"); - assertThat(docCountMovAvg, notNullValue()); - assertThat(docCountMovAvg.value(), equalTo(simpleMovAvgCounts[i])); - - SimpleValue valuesMovAvg = bucket.getAggregations().get("movavg_values"); - assertThat(valuesMovAvg, notNullValue()); - assertThat(valuesMovAvg.value(), equalTo(simpleMovAvgCounts[i])); - } - } - - /** - * test linear moving average on single value field - */ - @Test - public void linearSingleValuedField() { - - SearchResponse response = client() - .prepareSearch("idx") - .addAggregation( - histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0) - .extendedBounds(0L, (long) (interval * (numValueBuckets - 1))) - .subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME)) - .subAggregation(smooth("smooth") - .window(windowSize) - .modelBuilder(new 
LinearModel.LinearModelBuilder()) - .gapPolicy(gapPolicy) - .setBucketsPaths("_count")) - .subAggregation(smooth("movavg_values") - .window(windowSize) - .modelBuilder(new LinearModel.LinearModelBuilder()) - .gapPolicy(gapPolicy) - .setBucketsPaths("the_sum")) - ).execute().actionGet(); - - assertSearchResponse(response); - - InternalHistogram histo = response.getAggregations().get("histo"); - assertThat(histo, notNullValue()); - assertThat(histo.getName(), equalTo("histo")); - List buckets = histo.getBuckets(); - assertThat(buckets.size(), equalTo(numValueBuckets)); - - for (int i = 0; i < numValueBuckets; ++i) { - Histogram.Bucket bucket = buckets.get(i); - checkBucketKeyAndDocCount("Bucket " + i, bucket, i * interval, docCounts[i]); - SimpleValue docCountMovAvg = bucket.getAggregations().get("smooth"); - assertThat(docCountMovAvg, notNullValue()); - assertThat(docCountMovAvg.value(), equalTo(linearMovAvgCounts[i])); - - SimpleValue valuesMovAvg = bucket.getAggregations().get("movavg_values"); - assertThat(valuesMovAvg, notNullValue()); - assertThat(valuesMovAvg.value(), equalTo(linearMovAvgCounts[i])); - } - } - - /** - * test single exponential moving average on single value field - */ - @Test - public void singleExpSingleValuedField() { - - SearchResponse response = client() - .prepareSearch("idx") - .addAggregation( - histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0) - .extendedBounds(0L, (long) (interval * (numValueBuckets - 1))) - .subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME)) - .subAggregation(smooth("smooth") - .window(windowSize) - .modelBuilder(new SingleExpModel.SingleExpModelBuilder().alpha(0.5)) - .gapPolicy(gapPolicy) - .setBucketsPaths("_count")) - .subAggregation(smooth("movavg_values") - .window(windowSize) - .modelBuilder(new SingleExpModel.SingleExpModelBuilder().alpha(0.5)) - .gapPolicy(gapPolicy) - .setBucketsPaths("the_sum")) - ).execute().actionGet(); - - assertSearchResponse(response); - - InternalHistogram histo = response.getAggregations().get("histo"); - assertThat(histo, notNullValue()); - assertThat(histo.getName(), equalTo("histo")); - List buckets = histo.getBuckets(); - assertThat(buckets.size(), equalTo(numValueBuckets)); - - for (int i = 0; i < numValueBuckets; ++i) { - Histogram.Bucket bucket = buckets.get(i); - checkBucketKeyAndDocCount("Bucket " + i, bucket, i * interval, docCounts[i]); - SimpleValue docCountMovAvg = bucket.getAggregations().get("smooth"); - assertThat(docCountMovAvg, notNullValue()); - assertThat(docCountMovAvg.value(), equalTo(singleExpMovAvgCounts[i])); - - SimpleValue valuesMovAvg = bucket.getAggregations().get("movavg_values"); - assertThat(valuesMovAvg, notNullValue()); - assertThat(valuesMovAvg.value(), equalTo(singleExpMovAvgCounts[i])); - } - } - - /** - * test double exponential moving average on single value field - */ - @Test - public void doubleExpSingleValuedField() { - - SearchResponse response = client() - .prepareSearch("idx") - .addAggregation( - histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0) - .extendedBounds(0L, (long) (interval * (numValueBuckets - 1))) - .subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME)) - .subAggregation(smooth("smooth") - .window(windowSize) - .modelBuilder(new DoubleExpModel.DoubleExpModelBuilder().alpha(0.5).beta(0.5)) - .gapPolicy(gapPolicy) - .setBucketsPaths("_count")) - .subAggregation(smooth("movavg_values") - .window(windowSize) - .modelBuilder(new 
DoubleExpModel.DoubleExpModelBuilder().alpha(0.5).beta(0.5)) - .gapPolicy(gapPolicy) - .setBucketsPaths("the_sum")) - ).execute().actionGet(); - - assertSearchResponse(response); - - InternalHistogram histo = response.getAggregations().get("histo"); - assertThat(histo, notNullValue()); - assertThat(histo.getName(), equalTo("histo")); - List buckets = histo.getBuckets(); - assertThat(buckets.size(), equalTo(numValueBuckets)); - - for (int i = 0; i < numValueBuckets; ++i) { - Histogram.Bucket bucket = buckets.get(i); - checkBucketKeyAndDocCount("Bucket " + i, bucket, i * interval, docCounts[i]); - SimpleValue docCountMovAvg = bucket.getAggregations().get("smooth"); - assertThat(docCountMovAvg, notNullValue()); - assertThat(docCountMovAvg.value(), equalTo(doubleExpMovAvgCounts[i])); - - SimpleValue valuesMovAvg = bucket.getAggregations().get("movavg_values"); - assertThat(valuesMovAvg, notNullValue()); - assertThat(valuesMovAvg.value(), equalTo(doubleExpMovAvgCounts[i])); - } - } - - - private void checkBucketKeyAndDocCount(final String msg, final Histogram.Bucket bucket, final long expectedKey, - long expectedDocCount) { - if (expectedDocCount == -1) { - expectedDocCount = 0; - } - assertThat(msg, bucket, notNullValue()); - assertThat(msg + " key", ((Number) bucket.getKey()).longValue(), equalTo(expectedKey)); - assertThat(msg + " docCount", bucket.getDocCount(), equalTo(expectedDocCount)); - } - -} diff --git a/src/test/java/org/elasticsearch/search/aggregations/reducers/moving/avg/MovAvgTests.java b/src/test/java/org/elasticsearch/search/aggregations/reducers/moving/avg/MovAvgTests.java new file mode 100644 index 00000000000..9c3a6f23419 --- /dev/null +++ b/src/test/java/org/elasticsearch/search/aggregations/reducers/moving/avg/MovAvgTests.java @@ -0,0 +1,1018 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.search.aggregations.reducers.moving.avg; + + +import com.google.common.collect.EvictingQueue; + +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.search.SearchPhaseExecutionException; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.index.query.RangeFilterBuilder; +import org.elasticsearch.search.aggregations.bucket.filter.InternalFilter; +import org.elasticsearch.search.aggregations.bucket.histogram.Histogram; +import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram; +import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram.Bucket; +import org.elasticsearch.search.aggregations.reducers.BucketHelpers; +import org.elasticsearch.search.aggregations.reducers.SimpleValue; +import org.elasticsearch.search.aggregations.reducers.movavg.models.*; +import org.elasticsearch.test.ElasticsearchIntegrationTest; +import org.hamcrest.Matchers; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.elasticsearch.search.aggregations.AggregationBuilders.filter; +import static org.elasticsearch.search.aggregations.AggregationBuilders.histogram; +import static org.elasticsearch.search.aggregations.AggregationBuilders.range; +import static org.elasticsearch.search.aggregations.AggregationBuilders.sum; +import static org.elasticsearch.search.aggregations.reducers.ReducerBuilders.movingAvg; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.hamcrest.core.IsNull.notNullValue; +import static org.hamcrest.core.IsNull.nullValue; + +@ElasticsearchIntegrationTest.SuiteScopeTest +public class MovAvgTests extends ElasticsearchIntegrationTest { + + private static final String SINGLE_VALUED_FIELD_NAME = "l_value"; + private static final String SINGLE_VALUED_VALUE_FIELD_NAME = "v_value"; + private static final String GAP_FIELD = "g_value"; + + static int interval; + static int numValueBuckets; + static int numFilledValueBuckets; + static int windowSize; + static BucketHelpers.GapPolicy gapPolicy; + + static long[] docCounts; + static long[] valueCounts; + static Double[] simpleMovAvgCounts; + static Double[] linearMovAvgCounts; + static Double[] singleExpMovAvgCounts; + static Double[] doubleExpMovAvgCounts; + + static Double[] simpleMovAvgValueCounts; + static Double[] linearMovAvgValueCounts; + static Double[] singleExpMovAvgValueCounts; + static Double[] doubleExpMovAvgValueCounts; + + @Override + public void setupSuiteScopeCluster() throws Exception { + createIndex("idx"); + createIndex("idx_unmapped"); + List builders = new ArrayList<>(); + + interval = 5; + numValueBuckets = randomIntBetween(6, 80); + numFilledValueBuckets = numValueBuckets; + windowSize = randomIntBetween(3,10); + gapPolicy = BucketHelpers.GapPolicy.INSERT_ZEROS; // TODO randomBoolean() ?
BucketHelpers.GapPolicy.IGNORE : BucketHelpers.GapPolicy.INSERT_ZEROS; + + docCounts = new long[numValueBuckets]; + valueCounts = new long[numValueBuckets]; + for (int i = 0; i < numValueBuckets; i++) { + docCounts[i] = randomIntBetween(0, 20); + valueCounts[i] = randomIntBetween(1,20); //this will be used as a constant for all values within a bucket + } + + // Used for the gap tests + builders.add(client().prepareIndex("idx", "type").setSource(jsonBuilder().startObject() + .field("gap_test", 0) + .field(GAP_FIELD, 1).endObject())); + builders.add(client().prepareIndex("idx", "type").setSource(jsonBuilder().startObject() + .field("gap_test", (numValueBuckets - 1) * interval) + .field(GAP_FIELD, 1).endObject())); + + this.setupSimple(); + this.setupLinear(); + this.setupSingle(); + this.setupDouble(); + + + + for (int i = 0; i < numValueBuckets; i++) { + for (int docs = 0; docs < docCounts[i]; docs++) { + builders.add(client().prepareIndex("idx", "type").setSource(jsonBuilder().startObject() + .field(SINGLE_VALUED_FIELD_NAME, i * interval) + .field(SINGLE_VALUED_VALUE_FIELD_NAME, 1).endObject())); + } + } + + indexRandom(true, builders); + ensureSearchable(); + } + + private void setupSimple() { + simpleMovAvgCounts = new Double[numValueBuckets]; + EvictingQueue window = EvictingQueue.create(windowSize); + for (int i = 0; i < numValueBuckets; i++) { + double thisValue = docCounts[i]; + window.offer(thisValue); + + double movAvg = 0; + for (double value : window) { + movAvg += value; + } + movAvg /= window.size(); + + simpleMovAvgCounts[i] = movAvg; + } + + window.clear(); + simpleMovAvgValueCounts = new Double[numValueBuckets]; + for (int i = 0; i < numValueBuckets; i++) { + window.offer((double)docCounts[i]); + + double movAvg = 0; + for (double value : window) { + movAvg += value; + } + movAvg /= window.size(); + + simpleMovAvgValueCounts[i] = movAvg; + + } + + } + + private void setupLinear() { + EvictingQueue window = EvictingQueue.create(windowSize); + linearMovAvgCounts = new Double[numValueBuckets]; + window.clear(); + for (int i = 0; i < numValueBuckets; i++) { + double thisValue = docCounts[i]; + if (thisValue == -1) { + thisValue = 0; + } + window.offer(thisValue); + + double avg = 0; + long totalWeight = 1; + long current = 1; + + for (double value : window) { + avg += value * current; + totalWeight += current; + current += 1; + } + linearMovAvgCounts[i] = avg / totalWeight; + } + + window.clear(); + linearMovAvgValueCounts = new Double[numValueBuckets]; + + for (int i = 0; i < numValueBuckets; i++) { + double thisValue = docCounts[i]; + window.offer(thisValue); + + double avg = 0; + long totalWeight = 1; + long current = 1; + + for (double value : window) { + avg += value * current; + totalWeight += current; + current += 1; + } + linearMovAvgValueCounts[i] = avg / totalWeight; + } + } + + private void setupSingle() { + EvictingQueue window = EvictingQueue.create(windowSize); + singleExpMovAvgCounts = new Double[numValueBuckets]; + for (int i = 0; i < numValueBuckets; i++) { + double thisValue = docCounts[i]; + if (thisValue == -1) { + thisValue = 0; + } + window.offer(thisValue); + + double avg = 0; + double alpha = 0.5; + boolean first = true; + + for (double value : window) { + if (first) { + avg = value; + first = false; + } else { + avg = (value * alpha) + (avg * (1 - alpha)); + } + } + singleExpMovAvgCounts[i] = avg ; + } + + singleExpMovAvgValueCounts = new Double[numValueBuckets]; + window.clear(); + + for (int i = 0; i < numValueBuckets; i++) { + 
window.offer((double)docCounts[i]); + + double avg = 0; + double alpha = 0.5; + boolean first = true; + + for (double value : window) { + if (first) { + avg = value; + first = false; + } else { + avg = (value * alpha) + (avg * (1 - alpha)); + } + } + singleExpMovAvgValueCounts[i] = avg; + } + + } + + private void setupDouble() { + EvictingQueue window = EvictingQueue.create(windowSize); + doubleExpMovAvgCounts = new Double[numValueBuckets]; + + for (int i = 0; i < numValueBuckets; i++) { + double thisValue = docCounts[i]; + if (thisValue == -1) { + thisValue = 0; + } + window.offer(thisValue); + + double s = 0; + double last_s = 0; + + // Trend value + double b = 0; + double last_b = 0; + + double alpha = 0.5; + double beta = 0.5; + int counter = 0; + + double last; + for (double value : window) { + last = value; + if (counter == 1) { + s = value; + b = value - last; + } else { + s = alpha * value + (1.0d - alpha) * (last_s + last_b); + b = beta * (s - last_s) + (1 - beta) * last_b; + } + + counter += 1; + last_s = s; + last_b = b; + } + + doubleExpMovAvgCounts[i] = s + (0 * b); + } + + doubleExpMovAvgValueCounts = new Double[numValueBuckets]; + window.clear(); + + for (int i = 0; i < numValueBuckets; i++) { + window.offer((double)docCounts[i]); + + double s = 0; + double last_s = 0; + + // Trend value + double b = 0; + double last_b = 0; + + double alpha = 0.5; + double beta = 0.5; + int counter = 0; + + double last; + for (double value : window) { + last = value; + if (counter == 1) { + s = value; + b = value - last; + } else { + s = alpha * value + (1.0d - alpha) * (last_s + last_b); + b = beta * (s - last_s) + (1 - beta) * last_b; + } + + counter += 1; + last_s = s; + last_b = b; + } + + doubleExpMovAvgValueCounts[i] = s + (0 * b); + } + } + + /** + * test simple moving average on single value field + */ + @Test + public void simpleSingleValuedField() { + + SearchResponse response = client() + .prepareSearch("idx") + .addAggregation( + histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0) + .extendedBounds(0L, (long) (interval * (numValueBuckets - 1))) + .subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME)) + .subAggregation(movingAvg("movingAvg") + .window(windowSize) + .modelBuilder(new SimpleModel.SimpleModelBuilder()) + .gapPolicy(gapPolicy) + .setBucketsPaths("_count")) + .subAggregation(movingAvg("movavg_values") + .window(windowSize) + .modelBuilder(new SimpleModel.SimpleModelBuilder()) + .gapPolicy(gapPolicy) + .setBucketsPaths("the_sum")) + ).execute().actionGet(); + + assertSearchResponse(response); + + InternalHistogram histo = response.getAggregations().get("histo"); + assertThat(histo, notNullValue()); + assertThat(histo.getName(), equalTo("histo")); + List buckets = histo.getBuckets(); + assertThat(buckets.size(), equalTo(numValueBuckets)); + + for (int i = 0; i < numValueBuckets; ++i) { + Histogram.Bucket bucket = buckets.get(i); + checkBucketKeyAndDocCount("Bucket " + i, bucket, i * interval, docCounts[i]); + SimpleValue docCountMovAvg = bucket.getAggregations().get("movingAvg"); + assertThat(docCountMovAvg, notNullValue()); + assertThat(docCountMovAvg.value(), equalTo(simpleMovAvgCounts[i])); + + SimpleValue valuesMovAvg = bucket.getAggregations().get("movavg_values"); + assertThat(valuesMovAvg, notNullValue()); + assertThat(valuesMovAvg.value(), equalTo(simpleMovAvgCounts[i])); + } + } + + /** + * test linear moving average on single value field + */ + @Test + public void linearSingleValuedField() { + + SearchResponse
response = client() + .prepareSearch("idx") + .addAggregation( + histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0) + .extendedBounds(0L, (long) (interval * (numValueBuckets - 1))) + .subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME)) + .subAggregation(movingAvg("movingAvg") + .window(windowSize) + .modelBuilder(new LinearModel.LinearModelBuilder()) + .gapPolicy(gapPolicy) + .setBucketsPaths("_count")) + .subAggregation(movingAvg("movavg_values") + .window(windowSize) + .modelBuilder(new LinearModel.LinearModelBuilder()) + .gapPolicy(gapPolicy) + .setBucketsPaths("the_sum")) + ).execute().actionGet(); + + assertSearchResponse(response); + + InternalHistogram histo = response.getAggregations().get("histo"); + assertThat(histo, notNullValue()); + assertThat(histo.getName(), equalTo("histo")); + List buckets = histo.getBuckets(); + assertThat(buckets.size(), equalTo(numValueBuckets)); + + for (int i = 0; i < numValueBuckets; ++i) { + Histogram.Bucket bucket = buckets.get(i); + checkBucketKeyAndDocCount("Bucket " + i, bucket, i * interval, docCounts[i]); + SimpleValue docCountMovAvg = bucket.getAggregations().get("movingAvg"); + assertThat(docCountMovAvg, notNullValue()); + assertThat(docCountMovAvg.value(), equalTo(linearMovAvgCounts[i])); + + SimpleValue valuesMovAvg = bucket.getAggregations().get("movavg_values"); + assertThat(valuesMovAvg, notNullValue()); + assertThat(valuesMovAvg.value(), equalTo(linearMovAvgCounts[i])); + } + } + + /** + * test single exponential moving average on single value field + */ + @Test + public void singleExpSingleValuedField() { + + SearchResponse response = client() + .prepareSearch("idx") + .addAggregation( + histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0) + .extendedBounds(0L, (long) (interval * (numValueBuckets - 1))) + .subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME)) + .subAggregation(movingAvg("movingAvg") + .window(windowSize) + .modelBuilder(new SingleExpModel.SingleExpModelBuilder().alpha(0.5)) + .gapPolicy(gapPolicy) + .setBucketsPaths("_count")) + .subAggregation(movingAvg("movavg_values") + .window(windowSize) + .modelBuilder(new SingleExpModel.SingleExpModelBuilder().alpha(0.5)) + .gapPolicy(gapPolicy) + .setBucketsPaths("the_sum")) + ).execute().actionGet(); + + assertSearchResponse(response); + + InternalHistogram histo = response.getAggregations().get("histo"); + assertThat(histo, notNullValue()); + assertThat(histo.getName(), equalTo("histo")); + List buckets = histo.getBuckets(); + assertThat(buckets.size(), equalTo(numValueBuckets)); + + for (int i = 0; i < numValueBuckets; ++i) { + Histogram.Bucket bucket = buckets.get(i); + checkBucketKeyAndDocCount("Bucket " + i, bucket, i * interval, docCounts[i]); + SimpleValue docCountMovAvg = bucket.getAggregations().get("movingAvg"); + assertThat(docCountMovAvg, notNullValue()); + assertThat(docCountMovAvg.value(), equalTo(singleExpMovAvgCounts[i])); + + SimpleValue valuesMovAvg = bucket.getAggregations().get("movavg_values"); + assertThat(valuesMovAvg, notNullValue()); + assertThat(valuesMovAvg.value(), equalTo(singleExpMovAvgCounts[i])); + } + } + + /** + * test double exponential moving average on single value field + */ + @Test + public void doubleExpSingleValuedField() { + + SearchResponse response = client() + .prepareSearch("idx") + .addAggregation( + histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0) + .extendedBounds(0L, (long) (interval * 
(numValueBuckets - 1))) + .subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME)) + .subAggregation(movingAvg("movingAvg") + .window(windowSize) + .modelBuilder(new DoubleExpModel.DoubleExpModelBuilder().alpha(0.5).beta(0.5)) + .gapPolicy(gapPolicy) + .setBucketsPaths("_count")) + .subAggregation(movingAvg("movavg_values") + .window(windowSize) + .modelBuilder(new DoubleExpModel.DoubleExpModelBuilder().alpha(0.5).beta(0.5)) + .gapPolicy(gapPolicy) + .setBucketsPaths("the_sum")) + ).execute().actionGet(); + + assertSearchResponse(response); + + InternalHistogram histo = response.getAggregations().get("histo"); + assertThat(histo, notNullValue()); + assertThat(histo.getName(), equalTo("histo")); + List buckets = histo.getBuckets(); + assertThat(buckets.size(), equalTo(numValueBuckets)); + + for (int i = 0; i < numValueBuckets; ++i) { + Histogram.Bucket bucket = buckets.get(i); + checkBucketKeyAndDocCount("Bucket " + i, bucket, i * interval, docCounts[i]); + SimpleValue docCountMovAvg = bucket.getAggregations().get("movingAvg"); + assertThat(docCountMovAvg, notNullValue()); + assertThat(docCountMovAvg.value(), equalTo(doubleExpMovAvgCounts[i])); + + SimpleValue valuesMovAvg = bucket.getAggregations().get("movavg_values"); + assertThat(valuesMovAvg, notNullValue()); + assertThat(valuesMovAvg.value(), equalTo(doubleExpMovAvgCounts[i])); + } + } + + @Test + public void testSizeZeroWindow() { + try { + client() + .prepareSearch("idx") + .addAggregation( + histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0) + .extendedBounds(0L, (long) (interval * (numValueBuckets - 1))) + .subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME)) + .subAggregation(movingAvg("movingAvg") + .window(0) + .modelBuilder(new SimpleModel.SimpleModelBuilder()) + .gapPolicy(gapPolicy) + .setBucketsPaths("the_sum")) + ).execute().actionGet(); + fail("MovingAvg should not accept a window that is zero"); + + } catch (SearchPhaseExecutionException exception) { + //Throwable rootCause = exception.unwrapCause(); + //assertThat(rootCause, instanceOf(SearchParseException.class)); + //assertThat("[window] value must be a positive, non-zero integer. 
Value supplied was [0] in [movingAvg].", equalTo(exception.getMessage())); + } + } + + @Test + public void testBadParent() { + try { + client() + .prepareSearch("idx") + .addAggregation( + range("histo").field(SINGLE_VALUED_FIELD_NAME).addRange(0,10) + .subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME)) + .subAggregation(movingAvg("movingAvg") + .window(windowSize) + .modelBuilder(new SimpleModel.SimpleModelBuilder()) + .gapPolicy(gapPolicy) + .setBucketsPaths("the_sum")) + ).execute().actionGet(); + fail("MovingAvg should not accept non-histogram as parent"); + + } catch (SearchPhaseExecutionException exception) { + // All good + } + } + + @Test + public void testNegativeWindow() { + try { + client() + .prepareSearch("idx") + .addAggregation( + histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0) + .extendedBounds(0L, (long) (interval * (numValueBuckets - 1))) + .subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME)) + .subAggregation(movingAvg("movingAvg") + .window(-10) + .modelBuilder(new SimpleModel.SimpleModelBuilder()) + .gapPolicy(gapPolicy) + .setBucketsPaths("_count")) + ).execute().actionGet(); + fail("MovingAvg should not accept a window that is negative"); + + } catch (SearchPhaseExecutionException exception) { + //Throwable rootCause = exception.unwrapCause(); + //assertThat(rootCause, instanceOf(SearchParseException.class)); + //assertThat("[window] value must be a positive, non-zero integer. Value supplied was [0] in [movingAvg].", equalTo(exception.getMessage())); + } + } + + @Test + public void testNoBucketsInHistogram() { + + SearchResponse response = client() + .prepareSearch("idx") + .addAggregation( + histogram("histo").field("test").interval(interval).minDocCount(0) + .extendedBounds(0L, (long) (interval * (numValueBuckets - 1))) + .subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME)) + .subAggregation(movingAvg("movingAvg") + .window(windowSize) + .modelBuilder(new SimpleModel.SimpleModelBuilder()) + .gapPolicy(gapPolicy) + .setBucketsPaths("the_sum")) + ).execute().actionGet(); + + assertSearchResponse(response); + + InternalHistogram histo = response.getAggregations().get("histo"); + assertThat(histo, notNullValue()); + assertThat(histo.getName(), equalTo("histo")); + List buckets = histo.getBuckets(); + assertThat(buckets.size(), equalTo(0)); + } + + @Test + public void testZeroPrediction() { + try { + client() + .prepareSearch("idx") + .addAggregation( + histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0) + .extendedBounds(0L, (long) (interval * (numValueBuckets - 1))) + .subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME)) + .subAggregation(movingAvg("movingAvg") + .window(windowSize) + .modelBuilder(randomModelBuilder()) + .gapPolicy(gapPolicy) + .predict(0) + .setBucketsPaths("the_sum")) + ).execute().actionGet(); + fail("MovingAvg should not accept a prediction size that is zero"); + + } catch (SearchPhaseExecutionException exception) { + // All Good + } + } + + @Test + public void testNegativePrediction() { + try { + client() + .prepareSearch("idx") + .addAggregation( + histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval).minDocCount(0) + .extendedBounds(0L, (long) (interval * (numValueBuckets - 1))) + .subAggregation(sum("the_sum").field(SINGLE_VALUED_VALUE_FIELD_NAME)) + .subAggregation(movingAvg("movingAvg") + .window(windowSize) + .modelBuilder(randomModelBuilder()) + .gapPolicy(gapPolicy) + .predict(-10) +
.setBucketsPaths("the_sum")) + ).execute().actionGet(); + fail("MovingAvg should not accept a prediction size that is negative"); + + } catch (SearchPhaseExecutionException exception) { + // All Good + } + } + + /** + * This test uses the "gap" dataset, which is simply a doc at the beginning and end of + * the `gap_test` histogram range. These docs have a value of 1 in the GAP_FIELD (`g_value`). + * This test verifies that large gaps don't break things, and that the mov avg roughly works + * in the correct manner (checks direction of change, but not actual values) + */ + @Test + public void testGiantGap() { + + SearchResponse response = client() + .prepareSearch("idx") + .addAggregation( + histogram("histo").field("gap_test").interval(interval).minDocCount(0) + .extendedBounds(0L, (long) (interval * (numValueBuckets - 1))) + .subAggregation(sum("the_sum").field(GAP_FIELD)) + .subAggregation(movingAvg("movingAvg") + .window(windowSize) + .modelBuilder(randomModelBuilder()) + .gapPolicy(gapPolicy) + .setBucketsPaths("the_sum")) + ).execute().actionGet(); + + assertSearchResponse(response); + + InternalHistogram histo = response.getAggregations().get("histo"); + assertThat(histo, notNullValue()); + assertThat(histo.getName(), equalTo("histo")); + List buckets = histo.getBuckets(); + assertThat(buckets.size(), equalTo(numValueBuckets)); + + double lastValue = ((SimpleValue)(buckets.get(0).getAggregations().get("movingAvg"))).value(); + assertThat(Double.compare(lastValue, 0.0d), greaterThanOrEqualTo(0)); + + double currentValue; + for (int i = 1; i < numValueBuckets - 2; i++) { + currentValue = ((SimpleValue)(buckets.get(i).getAggregations().get("movingAvg"))).value(); + + // Since there are only two values in this test, at the beginning and end, the moving average should + // decrease every step (until it reaches zero). Crude way to check that it's doing the right thing + // without actually verifying the computed values. Should work for all types of moving avgs and + // gap policies + assertThat(Double.compare(lastValue, currentValue), greaterThanOrEqualTo(0)); + lastValue = currentValue; + } + + // The last bucket has a real value, so this should always increase the moving avg + currentValue = ((SimpleValue)(buckets.get(numValueBuckets - 1).getAggregations().get("movingAvg"))).value(); + assertThat(Double.compare(lastValue, currentValue), equalTo(-1)); + } + + /** + * Big gap, but with prediction at the end.
+ */ + @Test + public void testGiantGapWithPredict() { + + MovAvgModelBuilder model = randomModelBuilder(); + int numPredictions = randomIntBetween(1, 10); + SearchResponse response = client() + .prepareSearch("idx") + .addAggregation( + histogram("histo").field("gap_test").interval(interval).minDocCount(0) + .extendedBounds(0L, (long) (interval * (numValueBuckets - 1))) + .subAggregation(sum("the_sum").field(GAP_FIELD)) + .subAggregation(movingAvg("movingAvg") + .window(windowSize) + .modelBuilder(model) + .gapPolicy(gapPolicy) + .predict(numPredictions) + .setBucketsPaths("the_sum")) + ).execute().actionGet(); + + assertSearchResponse(response); + + InternalHistogram histo = response.getAggregations().get("histo"); + assertThat(histo, notNullValue()); + assertThat(histo.getName(), equalTo("histo")); + List buckets = histo.getBuckets(); + assertThat(buckets.size(), equalTo(numValueBuckets + numPredictions)); + + double lastValue = ((SimpleValue)(buckets.get(0).getAggregations().get("movingAvg"))).value(); + assertThat(Double.compare(lastValue, 0.0d), greaterThanOrEqualTo(0)); + + double currentValue; + for (int i = 1; i < numValueBuckets - 2; i++) { + currentValue = ((SimpleValue)(buckets.get(i).getAggregations().get("movingAvg"))).value(); + + // Since there are only two values in this test, at the beginning and end, the moving average should + // decrease every step (until it reaches zero). Crude way to check that it's doing the right thing + // without actually verifying the computed values. Should work for all types of moving avgs and + // gap policies + assertThat(Double.compare(lastValue, currentValue), greaterThanOrEqualTo(0)); + lastValue = currentValue; + } + + // The last bucket has a real value, so this should always increase the moving avg + currentValue = ((SimpleValue)(buckets.get(numValueBuckets - 1).getAggregations().get("movingAvg"))).value(); + assertThat(Double.compare(lastValue, currentValue), equalTo(-1)); + + // Now check predictions + for (int i = numValueBuckets; i < numValueBuckets + numPredictions; i++) { + // Unclear at this point which direction the predictions will go, just verify they are + // not null, and that we don't have the_sum anymore + assertThat((buckets.get(i).getAggregations().get("movingAvg")), notNullValue()); + assertThat((buckets.get(i).getAggregations().get("the_sum")), nullValue()); + } + } + + /** + * This test filters the "gap" data so that the first doc is excluded. This leaves a long stretch of empty + * buckets until the final bucket. The moving avg should be zero up until the last bucket, and should work + * regardless of mov avg type or gap policy.
+
+    /**
+     * This test filters the "gap" data so that the first doc is excluded. This leaves a long stretch of empty
+     * buckets until the final bucket. The moving avg should be zero up until the last bucket, and should work
+     * regardless of mov avg type or gap policy.
+     */
+    @Test
+    public void testLeftGap() {
+
+        SearchResponse response = client()
+                .prepareSearch("idx")
+                .addAggregation(
+                        filter("filtered").filter(new RangeFilterBuilder("gap_test").from(1)).subAggregation(
+                                histogram("histo").field("gap_test").interval(interval).minDocCount(0)
+                                        .extendedBounds(0L, (long) (interval * (numValueBuckets - 1)))
+                                        .subAggregation(sum("the_sum").field(GAP_FIELD))
+                                        .subAggregation(movingAvg("movingAvg")
+                                                .window(windowSize)
+                                                .modelBuilder(randomModelBuilder())
+                                                .gapPolicy(gapPolicy)
+                                                .setBucketsPaths("the_sum"))
+                        )
+                ).execute().actionGet();
+
+        assertSearchResponse(response);
+
+        InternalFilter filtered = response.getAggregations().get("filtered");
+        assertThat(filtered, notNullValue());
+        assertThat(filtered.getName(), equalTo("filtered"));
+
+        InternalHistogram histo = filtered.getAggregations().get("histo");
+        assertThat(histo, notNullValue());
+        assertThat(histo.getName(), equalTo("histo"));
+        List<? extends Histogram.Bucket> buckets = histo.getBuckets();
+        assertThat(buckets.size(), equalTo(numValueBuckets));
+
+        double currentValue;
+        double lastValue = 0.0;
+        for (int i = 0; i < numValueBuckets - 1; i++) {
+            currentValue = ((SimpleValue) (buckets.get(i).getAggregations().get("movingAvg"))).value();
+
+            assertThat(Double.compare(lastValue, currentValue), lessThanOrEqualTo(0));
+            lastValue = currentValue;
+        }
+    }
+
+    @Test
+    public void testLeftGapWithPrediction() {
+
+        int numPredictions = randomIntBetween(0, 10);
+
+        SearchResponse response = client()
+                .prepareSearch("idx")
+                .addAggregation(
+                        filter("filtered").filter(new RangeFilterBuilder("gap_test").from(1)).subAggregation(
+                                histogram("histo").field("gap_test").interval(interval).minDocCount(0)
+                                        .extendedBounds(0L, (long) (interval * (numValueBuckets - 1)))
+                                        .subAggregation(sum("the_sum").field(GAP_FIELD))
+                                        .subAggregation(movingAvg("movingAvg")
+                                                .window(windowSize)
+                                                .modelBuilder(randomModelBuilder())
+                                                .gapPolicy(gapPolicy)
+                                                .predict(numPredictions)
+                                                .setBucketsPaths("the_sum"))
+                        )
+                ).execute().actionGet();
+
+        assertSearchResponse(response);
+
+        InternalFilter filtered = response.getAggregations().get("filtered");
+        assertThat(filtered, notNullValue());
+        assertThat(filtered.getName(), equalTo("filtered"));
+
+        InternalHistogram histo = filtered.getAggregations().get("histo");
+        assertThat(histo, notNullValue());
+        assertThat(histo.getName(), equalTo("histo"));
+        List<? extends Histogram.Bucket> buckets = histo.getBuckets();
+        assertThat(buckets.size(), equalTo(numValueBuckets + numPredictions));
+
+        double currentValue;
+        double lastValue = 0.0;
+        for (int i = 0; i < numValueBuckets - 1; i++) {
+            currentValue = ((SimpleValue) (buckets.get(i).getAggregations().get("movingAvg"))).value();
+
+            assertThat(Double.compare(lastValue, currentValue), lessThanOrEqualTo(0));
+            lastValue = currentValue;
+        }
+
+        // Now check predictions
+        for (int i = numValueBuckets; i < numValueBuckets + numPredictions; i++) {
+            // Unclear at this point which direction the predictions will go, just verify they are
+            // not null, and that we don't have the_sum anymore
+            assertThat((buckets.get(i).getAggregations().get("movingAvg")), notNullValue());
+            assertThat((buckets.get(i).getAggregations().get("the_sum")), nullValue());
+        }
+    }
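+
+    // With the first doc filtered out, every window the reducer sees is empty (zero) until the
+    // final real value enters it, so consecutive moving-average values can only stay flat or
+    // rise; lessThanOrEqualTo(0) on Double.compare(last, current) encodes exactly that.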
+
+    /**
+     * This test filters the "gap" data so that the last doc is excluded. This leaves a long stretch of empty
+     * buckets after the first bucket. The moving avg should be one at the beginning, then decay to zero for
+     * the rest, regardless of mov avg type or gap policy.
+     */
+    @Test
+    public void testRightGap() {
+
+        SearchResponse response = client()
+                .prepareSearch("idx")
+                .addAggregation(
+                        filter("filtered").filter(new RangeFilterBuilder("gap_test").to((interval * (numValueBuckets - 1) - interval))).subAggregation(
+                                histogram("histo").field("gap_test").interval(interval).minDocCount(0)
+                                        .extendedBounds(0L, (long) (interval * (numValueBuckets - 1)))
+                                        .subAggregation(sum("the_sum").field(GAP_FIELD))
+                                        .subAggregation(movingAvg("movingAvg")
+                                                .window(windowSize)
+                                                .modelBuilder(randomModelBuilder())
+                                                .gapPolicy(gapPolicy)
+                                                .setBucketsPaths("the_sum"))
+                        )
+                ).execute().actionGet();
+
+        assertSearchResponse(response);
+
+        InternalFilter filtered = response.getAggregations().get("filtered");
+        assertThat(filtered, notNullValue());
+        assertThat(filtered.getName(), equalTo("filtered"));
+
+        InternalHistogram histo = filtered.getAggregations().get("histo");
+        assertThat(histo, notNullValue());
+        assertThat(histo.getName(), equalTo("histo"));
+        List<? extends Histogram.Bucket> buckets = histo.getBuckets();
+        assertThat(buckets.size(), equalTo(numValueBuckets));
+
+        double currentValue;
+        double lastValue = ((SimpleValue) (buckets.get(0).getAggregations().get("movingAvg"))).value();
+        for (int i = 1; i < numValueBuckets - 1; i++) {
+            currentValue = ((SimpleValue) (buckets.get(i).getAggregations().get("movingAvg"))).value();
+
+            assertThat(Double.compare(lastValue, currentValue), greaterThanOrEqualTo(0));
+            lastValue = currentValue;
+        }
+    }
+
+    @Test
+    public void testRightGapWithPredictions() {
+
+        int numPredictions = randomIntBetween(0, 10);
+
+        SearchResponse response = client()
+                .prepareSearch("idx")
+                .addAggregation(
+                        filter("filtered").filter(new RangeFilterBuilder("gap_test").to((interval * (numValueBuckets - 1) - interval))).subAggregation(
+                                histogram("histo").field("gap_test").interval(interval).minDocCount(0)
+                                        .extendedBounds(0L, (long) (interval * (numValueBuckets - 1)))
+                                        .subAggregation(sum("the_sum").field(GAP_FIELD))
+                                        .subAggregation(movingAvg("movingAvg")
+                                                .window(windowSize)
+                                                .modelBuilder(randomModelBuilder())
+                                                .gapPolicy(gapPolicy)
+                                                .predict(numPredictions)
+                                                .setBucketsPaths("the_sum"))
+                        )
+                ).execute().actionGet();
+
+        assertSearchResponse(response);
+
+        InternalFilter filtered = response.getAggregations().get("filtered");
+        assertThat(filtered, notNullValue());
+        assertThat(filtered.getName(), equalTo("filtered"));
+
+        InternalHistogram histo = filtered.getAggregations().get("histo");
+        assertThat(histo, notNullValue());
+        assertThat(histo.getName(), equalTo("histo"));
+        List<? extends Histogram.Bucket> buckets = histo.getBuckets();
+        assertThat(buckets.size(), equalTo(numValueBuckets + numPredictions));
+
+        double currentValue;
+        double lastValue = ((SimpleValue) (buckets.get(0).getAggregations().get("movingAvg"))).value();
+        for (int i = 1; i < numValueBuckets - 1; i++) {
+            currentValue = ((SimpleValue) (buckets.get(i).getAggregations().get("movingAvg"))).value();
+
+            assertThat(Double.compare(lastValue, currentValue), greaterThanOrEqualTo(0));
+            lastValue = currentValue;
+        }
+
+        // Now check predictions
+        for (int i = numValueBuckets; i < numValueBuckets + numPredictions; i++) {
+            // Unclear at this point which direction the predictions will go, just verify they are
+            // not null, and that we don't have the_sum anymore
+            assertThat((buckets.get(i).getAggregations().get("movingAvg")), notNullValue());
+            assertThat((buckets.get(i).getAggregations().get("the_sum")), nullValue());
+        }
+    }
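+
+    // Mirror image of the left-gap case: the lone leading value ages out of the window, so the
+    // moving average is non-increasing (greaterThanOrEqualTo(0) on the comparison) no matter
+    // which model or gap policy was randomly selected.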
+
+    @Test
+    public void testPredictWithNoBuckets() {
+
+        int numPredictions = randomIntBetween(0, 10);
+
+        SearchResponse response = client()
+                .prepareSearch("idx")
+                .addAggregation(
+                        // Filter so we are above all values
+                        filter("filtered").filter(new RangeFilterBuilder("gap_test").from((interval * (numValueBuckets - 1) + interval))).subAggregation(
+                                histogram("histo").field("gap_test").interval(interval).minDocCount(0)
+                                        .subAggregation(sum("the_sum").field(GAP_FIELD))
+                                        .subAggregation(movingAvg("movingAvg")
+                                                .window(windowSize)
+                                                .modelBuilder(randomModelBuilder())
+                                                .gapPolicy(gapPolicy)
+                                                .predict(numPredictions)
+                                                .setBucketsPaths("the_sum"))
+                        )
+                ).execute().actionGet();
+
+        assertSearchResponse(response);
+
+        InternalFilter filtered = response.getAggregations().get("filtered");
+        assertThat(filtered, notNullValue());
+        assertThat(filtered.getName(), equalTo("filtered"));
+
+        InternalHistogram histo = filtered.getAggregations().get("histo");
+        assertThat(histo, notNullValue());
+        assertThat(histo.getName(), equalTo("histo"));
+        List<? extends Histogram.Bucket> buckets = histo.getBuckets();
+
+        // With no buckets to draw from, there is nothing to average and nothing to predict
+        assertThat(buckets.size(), equalTo(0));
+    }
+
+    private void checkBucketKeyAndDocCount(final String msg, final Histogram.Bucket bucket, final long expectedKey,
+                                           long expectedDocCount) {
+        if (expectedDocCount == -1) {
+            expectedDocCount = 0;
+        }
+        assertThat(msg, bucket, notNullValue());
+        assertThat(msg + " key", ((Number) bucket.getKey()).longValue(), equalTo(expectedKey));
+        assertThat(msg + " docCount", bucket.getDocCount(), equalTo(expectedDocCount));
+    }
+
+    private MovAvgModelBuilder randomModelBuilder() {
+        int rand = randomIntBetween(0, 3);
+
+        switch (rand) {
+            case 0:
+                return new SimpleModel.SimpleModelBuilder();
+            case 1:
+                return new LinearModel.LinearModelBuilder();
+            case 2:
+                return new SingleExpModel.SingleExpModelBuilder().alpha(randomDouble());
+            case 3:
+                return new DoubleExpModel.DoubleExpModelBuilder().alpha(randomDouble()).beta(randomDouble());
+            default:
+                return new SimpleModel.SimpleModelBuilder();
+        }
+    }
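+
+    // randomModelBuilder() rotates through all four model types with randomized parameters:
+    // simple (unweighted mean), linear (linearly weighted mean), single exponential (EWMA with
+    // a random alpha), and double exponential (Holt's linear trend with random alpha and beta).
+    // Randomizing the model keeps the directional assertions in the tests above model-agnostic.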
+
+}
diff --git a/src/test/java/org/elasticsearch/search/aggregations/reducers/moving/avg/MovAvgUnitTests.java b/src/test/java/org/elasticsearch/search/aggregations/reducers/moving/avg/MovAvgUnitTests.java
new file mode 100644
index 00000000000..156f4f873a7
--- /dev/null
+++ b/src/test/java/org/elasticsearch/search/aggregations/reducers/moving/avg/MovAvgUnitTests.java
@@ -0,0 +1,297 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.search.aggregations.reducers.moving.avg;
+
+import com.google.common.collect.EvictingQueue;
+
+import org.elasticsearch.search.aggregations.reducers.movavg.models.*;
+import org.elasticsearch.test.ElasticsearchTestCase;
+import org.junit.Test;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class MovAvgUnitTests extends ElasticsearchTestCase {
+
+    @Test
+    public void testSimpleMovAvgModel() {
+        MovAvgModel model = new SimpleModel();
+
+        int numValues = randomIntBetween(1, 100);
+        int windowSize = randomIntBetween(1, 50);
+
+        EvictingQueue<Double> window = EvictingQueue.create(windowSize);
+        for (int i = 0; i < numValues; i++) {
+
+            double randValue = randomDouble();
+            double expected = 0;
+
+            window.offer(randValue);
+
+            // A simple moving average is the unweighted mean of the current window
+            for (double value : window) {
+                expected += value;
+            }
+            expected /= window.size();
+
+            double actual = model.next(window);
+            assertThat(Double.compare(expected, actual), equalTo(0));
+        }
+    }
+
+    @Test
+    public void testSimplePredictionModel() {
+        MovAvgModel model = new SimpleModel();
+
+        int windowSize = randomIntBetween(1, 50);
+        int numPredictions = randomIntBetween(1, 50);
+
+        EvictingQueue<Double> window = EvictingQueue.create(windowSize);
+        for (int i = 0; i < windowSize; i++) {
+            window.offer(randomDouble());
+        }
+        double[] actual = model.predict(window, numPredictions);
+
+        // Each prediction is the mean of the window, with the prediction itself then fed
+        // back into the window before the next step
+        double[] expected = new double[numPredictions];
+        for (int i = 0; i < numPredictions; i++) {
+            for (double value : window) {
+                expected[i] += value;
+            }
+            expected[i] /= window.size();
+            window.offer(expected[i]);
+        }
+
+        for (int i = 0; i < numPredictions; i++) {
+            assertThat(Double.compare(expected[i], actual[i]), equalTo(0));
+        }
+    }
+
+    @Test
+    public void testLinearMovAvgModel() {
+        MovAvgModel model = new LinearModel();
+
+        int numValues = randomIntBetween(1, 100);
+        int windowSize = randomIntBetween(1, 50);
+
+        EvictingQueue<Double> window = EvictingQueue.create(windowSize);
+        for (int i = 0; i < numValues; i++) {
+            double randValue = randomDouble();
+            window.offer(randValue);
+
+            double avg = 0;
+            long totalWeight = 1;
+            long current = 1;
+
+            // Linearly weighted: weights 1, 2, 3... from the oldest to the newest value
+            for (double value : window) {
+                avg += value * current;
+                totalWeight += current;
+                current += 1;
+            }
+            double expected = avg / totalWeight;
+            double actual = model.next(window);
+            assertThat(Double.compare(expected, actual), equalTo(0));
+        }
+    }
+
+    @Test
+    public void testLinearPredictionModel() {
+        MovAvgModel model = new LinearModel();
+
+        int windowSize = randomIntBetween(1, 50);
+        int numPredictions = randomIntBetween(1, 50);
+
+        EvictingQueue<Double> window = EvictingQueue.create(windowSize);
+        for (int i = 0; i < windowSize; i++) {
+            window.offer(randomDouble());
+        }
+        double[] actual = model.predict(window, numPredictions);
+        double[] expected = new double[numPredictions];
+
+        for (int i = 0; i < numPredictions; i++) {
+            double avg = 0;
+            long totalWeight = 1;
+            long current = 1;
+
+            for (double value : window) {
+                avg += value * current;
+                totalWeight += current;
+                current += 1;
+            }
+            expected[i] = avg / totalWeight;
+            window.offer(expected[i]);
+        }
+
+        for (int i = 0; i < numPredictions; i++) {
+            assertThat(Double.compare(expected[i], actual[i]), equalTo(0));
+        }
+    }
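+
+    // The expected values above re-derive each model by hand: simple is the plain mean of the
+    // window, linear a weighted mean with weights 1..n from oldest to newest. Note the linear
+    // computation starts totalWeight at 1 rather than 0, presumably mirroring the LinearModel
+    // implementation, since the tests demand exact equality via Double.compare.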
+
+    @Test
+    public void testSingleExpMovAvgModel() {
+        double alpha = randomDouble();
+        MovAvgModel model = new SingleExpModel(alpha);
+
+        int numValues = randomIntBetween(1, 100);
+        int windowSize = randomIntBetween(1, 50);
+
+        EvictingQueue<Double> window = EvictingQueue.create(windowSize);
+        for (int i = 0; i < numValues; i++) {
+            double randValue = randomDouble();
+            window.offer(randValue);
+
+            double avg = 0;
+            boolean first = true;
+
+            // EWMA: seed with the oldest value, then avg = alpha * value + (1 - alpha) * avg
+            for (double value : window) {
+                if (first) {
+                    avg = value;
+                    first = false;
+                } else {
+                    avg = (value * alpha) + (avg * (1 - alpha));
+                }
+            }
+            double expected = avg;
+            double actual = model.next(window);
+            assertThat(Double.compare(expected, actual), equalTo(0));
+        }
+    }
+
+    @Test
+    public void testSinglePredictionModel() {
+        double alpha = randomDouble();
+        MovAvgModel model = new SingleExpModel(alpha);
+
+        int windowSize = randomIntBetween(1, 50);
+        int numPredictions = randomIntBetween(1, 50);
+
+        EvictingQueue<Double> window = EvictingQueue.create(windowSize);
+        for (int i = 0; i < windowSize; i++) {
+            window.offer(randomDouble());
+        }
+        double[] actual = model.predict(window, numPredictions);
+        double[] expected = new double[numPredictions];
+
+        for (int i = 0; i < numPredictions; i++) {
+            double avg = 0;
+            boolean first = true;
+
+            for (double value : window) {
+                if (first) {
+                    avg = value;
+                    first = false;
+                } else {
+                    avg = (value * alpha) + (avg * (1 - alpha));
+                }
+            }
+            expected[i] = avg;
+            window.offer(expected[i]);
+        }
+
+        for (int i = 0; i < numPredictions; i++) {
+            assertThat(Double.compare(expected[i], actual[i]), equalTo(0));
+        }
+    }
+
+    @Test
+    public void testDoubleExpMovAvgModel() {
+        double alpha = randomDouble();
+        double beta = randomDouble();
+        MovAvgModel model = new DoubleExpModel(alpha, beta);
+
+        int numValues = randomIntBetween(1, 100);
+        int windowSize = randomIntBetween(1, 50);
+
+        EvictingQueue<Double> window = EvictingQueue.create(windowSize);
+        for (int i = 0; i < numValues; i++) {
+            double randValue = randomDouble();
+            window.offer(randValue);
+
+            // Smoothed value
+            double s = 0;
+            double last_s = 0;
+
+            // Trend value
+            double b = 0;
+            double last_b = 0;
+            int counter = 0;
+
+            // This mirrors the model's arithmetic step for step (including the counter == 1
+            // branch, where b is zeroed because last has just been set to value), so that the
+            // comparison below can demand exact equality
+            double last;
+            for (double value : window) {
+                last = value;
+                if (counter == 1) {
+                    s = value;
+                    b = value - last;
+                } else {
+                    s = alpha * value + (1.0d - alpha) * (last_s + last_b);
+                    b = beta * (s - last_s) + (1 - beta) * last_b;
+                }
+
+                counter += 1;
+                last_s = s;
+                last_b = b;
+            }
+
+            // next() is effectively a zero-step-ahead forecast: s + (0 * b)
+            double expected = s + (0 * b);
+            double actual = model.next(window);
+            assertThat(Double.compare(expected, actual), equalTo(0));
+        }
+    }
+
+    @Test
+    public void testDoublePredictionModel() {
+        double alpha = randomDouble();
+        double beta = randomDouble();
+        MovAvgModel model = new DoubleExpModel(alpha, beta);
+
+        int windowSize = randomIntBetween(1, 50);
+        int numPredictions = randomIntBetween(1, 50);
+
+        EvictingQueue<Double> window = EvictingQueue.create(windowSize);
+        for (int i = 0; i < windowSize; i++) {
+            window.offer(randomDouble());
+        }
+        double[] actual = model.predict(window, numPredictions);
+        double[] expected = new double[numPredictions];
+
+        // Smoothed value
+        double s = 0;
+        double last_s = 0;
+
+        // Trend value
+        double b = 0;
+        double last_b = 0;
+        int counter = 0;
+
+        double last;
+        for (double value : window) {
+            last = value;
+            if (counter == 1) {
+                s = value;
+                b = value - last;
+            } else {
+                s = alpha * value + (1.0d - alpha) * (last_s + last_b);
+                b = beta * (s - last_s) + (1 - beta) * last_b;
+            }
+
+            counter += 1;
+            last_s = s;
+            last_b = b;
+        }
+
+        // An i-step-ahead forecast extrapolates the trend: s + (i * b)
+        for (int i = 0; i < numPredictions; i++) {
+            expected[i] = s + (i * b);
+            assertThat(Double.compare(expected[i], actual[i]), equalTo(0));
+        }
+    }
+}
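+
+// Summary of the exponential recurrences exercised above (notation informal): single
+// exponential smoothing is s_t = alpha * x_t + (1 - alpha) * s_(t-1), seeded with the first
+// value in the window; double exponential (Holt) adds a trend term,
+//     s_t = alpha * x_t + (1 - alpha) * (s_(t-1) + b_(t-1))
+//     b_t = beta * (s_t - s_(t-1)) + (1 - beta) * b_(t-1)
+// and forecasts i steps ahead as s + i * b, which is the expected[i] = s + (i * b) computed in
+// testDoublePredictionModel.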