Aggregations: Add Holt-Winters model to `moving_avg` pipeline aggregation

Closes #11043
This commit is contained in:
Zachary Tong 2015-05-06 16:13:11 -04:00
parent cbb7b633f6
commit 491afbe01c
18 changed files with 1204 additions and 99 deletions

View File

@ -180,11 +180,11 @@ The default value of `alpha` is `0.5`, and the setting accepts any float from 0-
[[single_0.2alpha]]
.Single Exponential moving average with window of size 10, alpha = 0.2
.EWMA with window of size 10, alpha = 0.2
image::images/pipeline_movavg/single_0.2alpha.png[]
[[single_0.7alpha]]
.Single Exponential moving average with window of size 10, alpha = 0.7
.EWMA with window of size 10, alpha = 0.7
image::images/pipeline_movavg/single_0.7alpha.png[]
==== Holt-Linear
@ -223,13 +223,111 @@ to see. Small values emphasize long-term trends (such as a constant linear tren
values emphasize short-term trends. This will become more apparently when you are predicting values.
[[double_0.2beta]]
.Double Exponential moving average with window of size 100, alpha = 0.5, beta = 0.2
.Holt-Linear moving average with window of size 100, alpha = 0.5, beta = 0.2
image::images/pipeline_movavg/double_0.2beta.png[]
[[double_0.7beta]]
.Double Exponential moving average with window of size 100, alpha = 0.5, beta = 0.7
.Holt-Linear moving average with window of size 100, alpha = 0.5, beta = 0.7
image::images/pipeline_movavg/double_0.7beta.png[]
==== Holt-Winters
The `holt_winters` model (aka "triple exponential") incorporates a third exponential term which
tracks the seasonal aspect of your data. This aggregation therefore smooths based on three components: "level", "trend"
and "seasonality".
The level and trend calculation is identical to `holt` The seasonal calculation looks at the difference between
the current point, and the point one period earlier.
Holt-Winters requires a little more handholding than the other moving averages. You need to specify the "periodicity"
of your data: e.g. if your data has cyclic trends every 7 days, you would set `period: 7`. Similarly if there was
a monthly trend, you would set it to `30`. There is currently no periodicity detection, although that is planned
for future enhancements.
There are two varieties of Holt-Winters: additive and multiplicative.
===== "Cold Start"
Unfortunately, due to the nature of Holt-Winters, it requires two periods of data to "bootstrap" the algorithm. This
means that your `window` must always be *at least* twice the size of your period. An exception will be thrown if it
isn't. It also means that Holt-Winters will not emit a value for the first `2 * period` buckets; the current algorithm
does not backcast.
[[holt_winters_cold_start]]
.Holt-Winters showing a "cold" start where no values are emitted
image::images/reducers_movavg/triple_untruncated.png[]
Because the "cold start" obscures what the moving average looks like, the rest of the Holt-Winters images are truncated
to not show the "cold start". Just be aware this will always be present at the beginning of your moving averages!
===== Additive Holt-Winters
Additive seasonality is the default; it can also be specified by setting `"type": "add"`. This variety is preferred
when the seasonal affect is additive to your data. E.g. you could simply subtract the seasonal effect to "de-seasonalize"
your data into a flat trend.
The default value of `alpha`, `beta` and `gamma` is `0.5`, and the settings accept any float from 0-1 inclusive.
The default value of `period` is `1`.
[source,js]
--------------------------------------------------
{
"the_movavg":{
"moving_avg":{
"buckets_path": "the_sum",
"model" : "holt_winters",
"settings" : {
"type" : "add",
"alpha" : 0.5,
"beta" : 0.5,
"gamma" : 0.5,
"period" : 7
}
}
}
--------------------------------------------------
[[holt_winters_add]]
.Holt-Winters moving average with window of size 120, alpha = 0.5, beta = 0.7, gamma = 0.3, period = 30
image::images/reducers_movavg/triple.png[]
===== Multiplicative Holt-Winters
Multiplicative is specified by setting `"type": "mult"`. This variety is preferred when the seasonal affect is
multiplied against your data. E.g. if the seasonal affect is x5 the data, rather than simply adding to it.
The default value of `alpha`, `beta` and `gamma` is `0.5`, and the settings accept any float from 0-1 inclusive.
The default value of `period` is `1`.
[WARNING]
======
Multiplicative Holt-Winters works by dividing each data point by the seasonal value. This is problematic if any of
your data is zero, or if there are gaps in the data (since this results in a divid-by-zero). To combat this, the
`mult` Holt-Winters pads all values by a very small amount (1*10^-10^) so that all values are non-zero. This affects
the result, but only minimally. If your data is non-zero, or you prefer to see `NaN` when zero's are encountered,
you can disable this behavior with `pad: false`
======
[source,js]
--------------------------------------------------
{
"the_movavg":{
"moving_avg":{
"buckets_path": "the_sum",
"model" : "holt_winters",
"settings" : {
"type" : "mult",
"alpha" : 0.5,
"beta" : 0.5,
"gamma" : 0.5,
"period" : 7,
"pad" : true
}
}
}
--------------------------------------------------
==== Prediction
All the moving average model support a "prediction" mode, which will attempt to extrapolate into the future given the
@ -263,7 +361,7 @@ value, we can extrapolate based on local constant trends (in this case the predi
of the series was heading in a downward direction):
[[double_prediction_local]]
.Double Exponential moving average with window of size 100, predict = 20, alpha = 0.5, beta = 0.8
.Holt-Linear moving average with window of size 100, predict = 20, alpha = 0.5, beta = 0.8
image::images/pipeline_movavg/double_prediction_local.png[]
In contrast, if we choose a small `beta`, the predictions are based on the global constant trend. In this series, the
@ -272,3 +370,10 @@ global trend is slightly positive, so the prediction makes a sharp u-turn and be
[[double_prediction_global]]
.Double Exponential moving average with window of size 100, predict = 20, alpha = 0.5, beta = 0.1
image::images/pipeline_movavg/double_prediction_global.png[]
The `holt_winters` model has the potential to deliver the best predictions, since it also incorporates seasonal
fluctuations into the model:
[[holt_winters_prediction_global]]
.Holt-Winters moving average with window of size 120, predict = 25, alpha = 0.8, beta = 0.2, gamma = 0.7, period = 30
image::images/pipeline_movavg/triple_prediction.png[]

Binary file not shown.

After

Width:  |  Height:  |  Size: 96 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

View File

@ -27,7 +27,6 @@ import org.elasticsearch.search.aggregations.pipeline.BucketHelpers.GapPolicy;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorFactory;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.MovAvgModel;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.MovAvgModelParser;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.MovAvgModelParserMapper;
import org.elasticsearch.search.aggregations.support.format.ValueFormat;
import org.elasticsearch.search.aggregations.support.format.ValueFormatter;
@ -140,12 +139,12 @@ public class MovAvgParser implements PipelineAggregator.Parser {
formatter = ValueFormat.Patternable.Number.format(format).formatter();
}
MovAvgModelParser modelParser = movAvgModelParserMapper.get(model);
MovAvgModel.AbstractModelParser modelParser = movAvgModelParserMapper.get(model);
if (modelParser == null) {
throw new SearchParseException(context, "Unknown model [" + model + "] specified. Valid options are:"
+ movAvgModelParserMapper.getAllNames().toString(), parser.getTokenLocation());
}
MovAvgModel movAvgModel = modelParser.parse(settings);
MovAvgModel movAvgModel = modelParser.parse(settings, pipelineAggregatorName, context, window);
return new MovAvgPipelineAggregator.Factory(pipelineAggregatorName, bucketsPaths, formatter, gapPolicy, window, predict,
movAvgModel);

View File

@ -117,21 +117,26 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
Double thisBucketValue = resolveBucketValue(histo, bucket, bucketsPaths()[0], gapPolicy);
currentKey = bucket.getKey();
// Default is to reuse existing bucket. Simplifies the rest of the logic,
// since we only change newBucket if we can add to it
InternalHistogram.Bucket newBucket = bucket;
if (!(thisBucketValue == null || thisBucketValue.equals(Double.NaN))) {
values.offer(thisBucketValue);
double movavg = model.next(values);
// Some models (e.g. HoltWinters) have certain preconditions that must be met
if (model.hasValue(values.size())) {
double movavg = model.next(values);
List<InternalAggregation> aggs = new ArrayList<>(Lists.transform(bucket.getAggregations().asList(), FUNCTION));
aggs.add(new InternalSimpleValue(name(), movavg, formatter, new ArrayList<PipelineAggregator>(), metaData()));
InternalHistogram.Bucket newBucket = factory.createBucket(currentKey, bucket.getDocCount(), new InternalAggregations(
aggs), bucket.getKeyed(), bucket.getFormatter());
newBuckets.add(newBucket);
} else {
newBuckets.add(bucket);
List<InternalAggregation> aggs = new ArrayList<>(Lists.transform(bucket.getAggregations().asList(), AGGREGATION_TRANFORM_FUNCTION));
aggs.add(new InternalSimpleValue(name(), movavg, formatter, new ArrayList<PipelineAggregator>(), metaData()));
newBucket = factory.createBucket(currentKey, bucket.getDocCount(), new InternalAggregations(
aggs), bucket.getKeyed(), bucket.getFormatter());
}
}
newBuckets.add(newBucket);
if (predict > 0) {
if (currentKey instanceof Number) {
lastKey = ((Number) bucket.getKey()).longValue();

View File

@ -25,6 +25,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.search.aggregations.pipeline.movavg.MovAvgParser;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.Collection;
@ -83,7 +84,7 @@ public class EwmaModel extends MovAvgModel {
out.writeDouble(alpha);
}
public static class SingleExpModelParser implements MovAvgModelParser {
public static class SingleExpModelParser extends AbstractModelParser {
@Override
public String getName() {
@ -91,15 +92,13 @@ public class EwmaModel extends MovAvgModel {
}
@Override
public MovAvgModel parse(@Nullable Map<String, Object> settings) {
public MovAvgModel parse(@Nullable Map<String, Object> settings, String pipelineName, SearchContext context, int windowSize) {
Double alpha;
if (settings == null || (alpha = (Double)settings.get("alpha")) == null) {
alpha = 0.5;
}
double alpha = parseDoubleParam(context, settings, "alpha", 0.5);
return new EwmaModel(alpha);
}
}
public static class EWMAModelBuilder implements MovAvgModelBuilder {

View File

@ -25,6 +25,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.search.aggregations.pipeline.movavg.MovAvgParser;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.*;
@ -142,7 +143,7 @@ public class HoltLinearModel extends MovAvgModel {
out.writeDouble(beta);
}
public static class DoubleExpModelParser implements MovAvgModelParser {
public static class DoubleExpModelParser extends AbstractModelParser {
@Override
public String getName() {
@ -150,19 +151,10 @@ public class HoltLinearModel extends MovAvgModel {
}
@Override
public MovAvgModel parse(@Nullable Map<String, Object> settings) {
Double alpha;
Double beta;
if (settings == null || (alpha = (Double)settings.get("alpha")) == null) {
alpha = 0.5;
}
if (settings == null || (beta = (Double)settings.get("beta")) == null) {
beta = 0.5;
}
public MovAvgModel parse(@Nullable Map<String, Object> settings, String pipelineName, SearchContext context, int windowSize) {
double alpha = parseDoubleParam(context, settings, "alpha", 0.5);
double beta = parseDoubleParam(context, settings, "beta", 0.5);
return new HoltLinearModel(alpha, beta);
}
}

View File

@ -0,0 +1,422 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.pipeline.movavg.models;
import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.search.SearchParseException;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
import org.elasticsearch.search.aggregations.pipeline.movavg.MovAvgParser;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.*;
/**
* Calculate a triple exponential weighted moving average
*/
public class HoltWintersModel extends MovAvgModel {
protected static final ParseField NAME_FIELD = new ParseField("holt_winters");
public enum SeasonalityType {
ADDITIVE((byte) 0, "add"), MULTIPLICATIVE((byte) 1, "mult");
/**
* Parse a string SeasonalityType into the byte enum
*
* @param text SeasonalityType in string format (e.g. "add")
* @return SeasonalityType enum
*/
@Nullable
public static SeasonalityType parse(String text) {
if (text == null) {
return null;
}
SeasonalityType result = null;
for (SeasonalityType policy : values()) {
if (policy.parseField.match(text)) {
if (result == null) {
result = policy;
} else {
throw new IllegalStateException("Text can be parsed to 2 different seasonality types: text=[" + text
+ "], " + "policies=" + Arrays.asList(result, policy));
}
}
}
if (result == null) {
final List<String> validNames = new ArrayList<>();
for (SeasonalityType policy : values()) {
validNames.add(policy.getName());
}
throw new ElasticsearchParseException("Invalid seasonality type: [" + text + "], accepted values: " + validNames);
}
return result;
}
private final byte id;
private final ParseField parseField;
SeasonalityType(byte id, String name) {
this.id = id;
this.parseField = new ParseField(name);
}
/**
* Serialize the SeasonalityType to the output stream
*
* @param out
* @throws IOException
*/
public void writeTo(StreamOutput out) throws IOException {
out.writeByte(id);
}
/**
* Deserialize the SeasonalityType from the input stream
*
* @param in
* @return SeasonalityType Enum
* @throws IOException
*/
public static SeasonalityType readFrom(StreamInput in) throws IOException {
byte id = in.readByte();
for (SeasonalityType seasonalityType : values()) {
if (id == seasonalityType.id) {
return seasonalityType;
}
}
throw new IllegalStateException("Unknown Seasonality Type with id [" + id + "]");
}
/**
* Return the english-formatted name of the SeasonalityType
*
* @return English representation of SeasonalityType
*/
public String getName() {
return parseField.getPreferredName();
}
}
/**
* Controls smoothing of data. Alpha = 1 retains no memory of past values
* (e.g. random walk), while alpha = 0 retains infinite memory of past values (e.g.
* mean of the series). Useful values are somewhere in between
*/
private double alpha;
/**
* Equivalent to <code>alpha</code>, but controls the smoothing of the trend instead of the data
*/
private double beta;
private double gamma;
private int period;
private SeasonalityType seasonalityType;
private boolean pad;
private double padding;
public HoltWintersModel(double alpha, double beta, double gamma, int period, SeasonalityType seasonalityType, boolean pad) {
this.alpha = alpha;
this.beta = beta;
this.gamma = gamma;
this.period = period;
this.seasonalityType = seasonalityType;
this.pad = pad;
// Only pad if we are multiplicative and padding is enabled
// The padding amount is not currently user-configurable...i dont see a reason to expose it?
this.padding = seasonalityType.equals(SeasonalityType.MULTIPLICATIVE) && pad ? 0.0000000001 : 0;
}
@Override
public boolean hasValue(int windowLength) {
// We need at least (period * 2) data-points (e.g. two "seasons")
return windowLength >= period * 2;
}
/**
* Predicts the next `n` values in the series, using the smoothing model to generate new values.
* Unlike the other moving averages, HoltWinters has forecasting/prediction built into the algorithm.
* Prediction is more than simply adding the next prediction to the window and repeating. HoltWinters
* will extrapolate into the future by applying the trend and seasonal information to the smoothed data.
*
* @param values Collection of numerics to movingAvg, usually windowed
* @param numPredictions Number of newly generated predictions to return
* @param <T> Type of numeric
* @return Returns an array of doubles, since most smoothing methods operate on floating points
*/
@Override
public <T extends Number> double[] predict(Collection<T> values, int numPredictions) {
return next(values, numPredictions);
}
@Override
public <T extends Number> double next(Collection<T> values) {
return next(values, 1)[0];
}
/**
* Calculate a doubly exponential weighted moving average
*
* @param values Collection of values to calculate avg for
* @param numForecasts number of forecasts into the future to return
*
* @param <T> Type T extending Number
* @return Returns a Double containing the moving avg for the window
*/
public <T extends Number> double[] next(Collection<T> values, int numForecasts) {
if (values.size() < period * 2) {
// We need at least two full "seasons" to use HW
// This should have been caught earlier, we can't do anything now...bail
throw new AggregationExecutionException("Holt-Winters aggregation requires at least (2 * period == 2 * "
+ period + " == "+(2 * period)+") data-points to function. Only [" + values.size() + "] were provided.");
}
// Smoothed value
double s = 0;
double last_s = 0;
// Trend value
double b = 0;
double last_b = 0;
// Seasonal value
double[] seasonal = new double[values.size()];
int counter = 0;
double[] vs = new double[values.size()];
for (T v : values) {
vs[counter] = v.doubleValue() + padding;
counter += 1;
}
// Initial level value is average of first season
// Calculate the slopes between first and second season for each period
for (int i = 0; i < period; i++) {
s += vs[i];
b += (vs[i] - vs[i + period]) / 2;
}
s /= (double) period;
b /= (double) period;
last_s = s;
last_b = b;
// Calculate first seasonal
if (Double.compare(s, 0.0) == 0 || Double.compare(s, -0.0) == 0) {
Arrays.fill(seasonal, 0.0);
} else {
for (int i = 0; i < period; i++) {
seasonal[i] = vs[i] / s;
}
}
for (int i = period; i < vs.length; i++) {
// TODO if perf is a problem, we can specialize a subclass to avoid conditionals on each iteration
if (seasonalityType.equals(SeasonalityType.MULTIPLICATIVE)) {
s = alpha * (vs[i] / seasonal[i - period]) + (1.0d - alpha) * (last_s + last_b);
} else {
s = alpha * (vs[i] - seasonal[i - period]) + (1.0d - alpha) * (last_s + last_b);
}
b = beta * (s - last_s) + (1 - beta) * last_b;
if (seasonalityType.equals(SeasonalityType.MULTIPLICATIVE)) {
seasonal[i] = gamma * (vs[i] / (last_s + last_b )) + (1 - gamma) * seasonal[i - period];
} else {
seasonal[i] = gamma * (vs[i] - (last_s + last_b )) + (1 - gamma) * seasonal[i - period];
}
last_s = s;
last_b = b;
}
double[] forecastValues = new double[numForecasts];
int seasonCounter = (values.size() - 1) - period;
for (int i = 0; i < numForecasts; i++) {
// TODO perhaps pad out seasonal to a power of 2 and use a mask instead of modulo?
if (seasonalityType.equals(SeasonalityType.MULTIPLICATIVE)) {
forecastValues[i] = s + (i * b) * seasonal[seasonCounter % values.size()];
} else {
forecastValues[i] = s + (i * b) + seasonal[seasonCounter % values.size()];
}
seasonCounter += 1;
}
return forecastValues;
}
public static final MovAvgModelStreams.Stream STREAM = new MovAvgModelStreams.Stream() {
@Override
public MovAvgModel readResult(StreamInput in) throws IOException {
double alpha = in.readDouble();
double beta = in.readDouble();
double gamma = in.readDouble();
int period = in.readVInt();
SeasonalityType type = SeasonalityType.readFrom(in);
boolean pad = in.readBoolean();
return new HoltWintersModel(alpha, beta, gamma, period, type, pad);
}
@Override
public String getName() {
return NAME_FIELD.getPreferredName();
}
};
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(STREAM.getName());
out.writeDouble(alpha);
out.writeDouble(beta);
out.writeDouble(gamma);
out.writeVInt(period);
seasonalityType.writeTo(out);
out.writeBoolean(pad);
}
public static class HoltWintersModelParser extends AbstractModelParser {
@Override
public String getName() {
return NAME_FIELD.getPreferredName();
}
@Override
public MovAvgModel parse(@Nullable Map<String, Object> settings, String pipelineName, SearchContext context, int windowSize) {
double alpha = parseDoubleParam(context, settings, "alpha", 0.5);
double beta = parseDoubleParam(context, settings, "beta", 0.5);
double gamma = parseDoubleParam(context, settings, "gamma", 0.5);
int period = parseIntegerParam(context, settings, "period", 1);
if (windowSize < 2 * period) {
throw new SearchParseException(context, "Field [window] must be at least twice as large as the period when " +
"using Holt-Winters. Value provided was [" + windowSize + "], which is less than (2*period) == "
+ (2 * period), null);
}
SeasonalityType seasonalityType = SeasonalityType.ADDITIVE;
if (settings != null) {
Object value = settings.get("type");
if (value != null) {
if (value instanceof String) {
seasonalityType = SeasonalityType.parse((String)value);
} else {
throw new SearchParseException(context, "Parameter [type] must be a String, type `"
+ value.getClass().getSimpleName() + "` provided instead", null);
}
}
}
boolean pad = parseBoolParam(context, settings, "pad", seasonalityType.equals(SeasonalityType.MULTIPLICATIVE));
return new HoltWintersModel(alpha, beta, gamma, period, seasonalityType, pad);
}
}
public static class HoltWintersModelBuilder implements MovAvgModelBuilder {
private double alpha = 0.5;
private double beta = 0.5;
private double gamma = 0.5;
private int period = 1;
private SeasonalityType seasonalityType = SeasonalityType.ADDITIVE;
private boolean pad = true;
/**
* Alpha controls the smoothing of the data. Alpha = 1 retains no memory of past values
* (e.g. a random walk), while alpha = 0 retains infinite memory of past values (e.g.
* the series mean). Useful values are somewhere in between. Defaults to 0.5.
*
* @param alpha A double between 0-1 inclusive, controls data smoothing
*
* @return The builder to continue chaining
*/
public HoltWintersModelBuilder alpha(double alpha) {
this.alpha = alpha;
return this;
}
/**
* Equivalent to <code>alpha</code>, but controls the smoothing of the trend instead of the data
*
* @param beta a double between 0-1 inclusive, controls trend smoothing
*
* @return The builder to continue chaining
*/
public HoltWintersModelBuilder beta(double beta) {
this.beta = beta;
return this;
}
public HoltWintersModelBuilder gamma(double gamma) {
this.gamma = gamma;
return this;
}
public HoltWintersModelBuilder period(int period) {
this.period = period;
return this;
}
public HoltWintersModelBuilder seasonalityType(SeasonalityType type) {
this.seasonalityType = type;
return this;
}
public HoltWintersModelBuilder pad(boolean pad) {
this.pad = pad;
return this;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(MovAvgParser.MODEL.getPreferredName(), NAME_FIELD.getPreferredName());
builder.startObject(MovAvgParser.SETTINGS.getPreferredName());
builder.field("alpha", alpha);
builder.field("beta", beta);
builder.field("gamma", gamma);
builder.field("period", period);
builder.field("type", seasonalityType.getName());
builder.field("pad", pad);
builder.endObject();
return builder;
}
}
}

View File

@ -26,6 +26,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.search.aggregations.pipeline.movavg.MovAvgParser;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.Collection;
@ -70,7 +71,7 @@ public class LinearModel extends MovAvgModel {
out.writeString(STREAM.getName());
}
public static class LinearModelParser implements MovAvgModelParser {
public static class LinearModelParser extends AbstractModelParser {
@Override
public String getName() {
@ -78,7 +79,7 @@ public class LinearModel extends MovAvgModel {
}
@Override
public MovAvgModel parse(@Nullable Map<String, Object> settings) {
public MovAvgModel parse(@Nullable Map<String, Object> settings, String pipelineName, SearchContext context, int windowSize) {
return new LinearModel();
}
}

View File

@ -21,14 +21,31 @@ package org.elasticsearch.search.aggregations.pipeline.movavg.models;
import com.google.common.collect.EvictingQueue;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.SearchParseException;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
public abstract class MovAvgModel {
/**
* Checks to see this model can produce a new value, without actually running the algo.
* This can be used for models that have certain preconditions that need to be met in order
* to short-circuit execution
*
* @param windowLength Length of current window
* @return Returns `true` if calling next() will produce a value, `false` otherwise
*/
public boolean hasValue(int windowLength) {
// Default implementation can always provide a next() value
return true;
}
/**
* Returns the next value in the series, according to the underlying smoothing model
*
@ -90,6 +107,122 @@ public abstract class MovAvgModel {
* @throws IOException
*/
public abstract void writeTo(StreamOutput out) throws IOException;
/**
* Abstract class which also provides some concrete parsing functionality.
*/
public abstract static class AbstractModelParser {
/**
* Returns the name of the model
*
* @return The model's name
*/
public abstract String getName();
/**
* Parse a settings hash that is specific to this model
*
* @param settings Map of settings, extracted from the request
* @param pipelineName Name of the parent pipeline agg
* @param context The parser context that we are in
* @param windowSize Size of the window for this moving avg
* @return A fully built moving average model
*/
public abstract MovAvgModel parse(@Nullable Map<String, Object> settings, String pipelineName, SearchContext context, int windowSize);
/**
* Extracts a 0-1 inclusive double from the settings map, otherwise throws an exception
*
* @param context Search query context
* @param settings Map of settings provided to this model
* @param name Name of parameter we are attempting to extract
* @param defaultValue Default value to be used if value does not exist in map
*
* @throws SearchParseException
*
* @return Double value extracted from settings map
*/
protected double parseDoubleParam(SearchContext context, @Nullable Map<String, Object> settings, String name, double defaultValue) {
if (settings == null) {
return defaultValue;
}
Object value = settings.get(name);
if (value == null) {
return defaultValue;
} else if (value instanceof Double) {
double v = (Double)value;
if (v >= 0 && v <= 1) {
return v;
}
throw new SearchParseException(context, "Parameter [" + name + "] must be between 0-1 inclusive. Provided"
+ "value was [" + v + "]", null);
}
throw new SearchParseException(context, "Parameter [" + name + "] must be a double, type `"
+ value.getClass().getSimpleName() + "` provided instead", null);
}
/**
* Extracts an integer from the settings map, otherwise throws an exception
*
* @param context Search query context
* @param settings Map of settings provided to this model
* @param name Name of parameter we are attempting to extract
* @param defaultValue Default value to be used if value does not exist in map
*
* @throws SearchParseException
*
* @return Integer value extracted from settings map
*/
protected int parseIntegerParam(SearchContext context, @Nullable Map<String, Object> settings, String name, int defaultValue) {
if (settings == null) {
return defaultValue;
}
Object value = settings.get(name);
if (value == null) {
return defaultValue;
} else if (value instanceof Integer) {
return (Integer)value;
}
throw new SearchParseException(context, "Parameter [" + name + "] must be an integer, type `"
+ value.getClass().getSimpleName() + "` provided instead", null);
}
/**
* Extracts a boolean from the settings map, otherwise throws an exception
*
* @param context Search query context
* @param settings Map of settings provided to this model
* @param name Name of parameter we are attempting to extract
* @param defaultValue Default value to be used if value does not exist in map
*
* @throws SearchParseException
*
* @return Boolean value extracted from settings map
*/
protected boolean parseBoolParam(SearchContext context, @Nullable Map<String, Object> settings, String name, boolean defaultValue) {
if (settings == null) {
return defaultValue;
}
Object value = settings.get(name);
if (value == null) {
return defaultValue;
} else if (value instanceof Boolean) {
return (Boolean)value;
}
throw new SearchParseException(context, "Parameter [" + name + "] must be a boolean, type `"
+ value.getClass().getSimpleName() + "` provided instead", null);
}
}
}

View File

@ -31,23 +31,24 @@ import java.util.List;
*/
public class MovAvgModelModule extends AbstractModule {
private List<Class<? extends MovAvgModelParser>> parsers = Lists.newArrayList();
private List<Class<? extends MovAvgModel.AbstractModelParser>> parsers = Lists.newArrayList();
public MovAvgModelModule() {
registerParser(SimpleModel.SimpleModelParser.class);
registerParser(LinearModel.LinearModelParser.class);
registerParser(EwmaModel.SingleExpModelParser.class);
registerParser(HoltLinearModel.DoubleExpModelParser.class);
registerParser(HoltWintersModel.HoltWintersModelParser.class);
}
public void registerParser(Class<? extends MovAvgModelParser> parser) {
public void registerParser(Class<? extends MovAvgModel.AbstractModelParser> parser) {
parsers.add(parser);
}
@Override
protected void configure() {
Multibinder<MovAvgModelParser> parserMapBinder = Multibinder.newSetBinder(binder(), MovAvgModelParser.class);
for (Class<? extends MovAvgModelParser> clazz : parsers) {
Multibinder<MovAvgModel.AbstractModelParser> parserMapBinder = Multibinder.newSetBinder(binder(), MovAvgModel.AbstractModelParser.class);
for (Class<? extends MovAvgModel.AbstractModelParser> clazz : parsers) {
parserMapBinder.addBinding().to(clazz);
}
bind(MovAvgModelParserMapper.class);

View File

@ -1,34 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.pipeline.movavg.models;
import org.elasticsearch.common.Nullable;
import java.util.Map;
/**
* Common interface for parsers used by the various Moving Average models
*/
public interface MovAvgModelParser {
public MovAvgModel parse(@Nullable Map<String, Object> settings);
public String getName();
}

View File

@ -32,19 +32,19 @@ import java.util.Set;
*/
public class MovAvgModelParserMapper {
protected ImmutableMap<String, MovAvgModelParser> movAvgParsers;
protected ImmutableMap<String, MovAvgModel.AbstractModelParser> movAvgParsers;
@Inject
public MovAvgModelParserMapper(Set<MovAvgModelParser> parsers) {
MapBuilder<String, MovAvgModelParser> builder = MapBuilder.newMapBuilder();
for (MovAvgModelParser parser : parsers) {
public MovAvgModelParserMapper(Set<MovAvgModel.AbstractModelParser> parsers) {
MapBuilder<String, MovAvgModel.AbstractModelParser> builder = MapBuilder.newMapBuilder();
for (MovAvgModel.AbstractModelParser parser : parsers) {
builder.put(parser.getName(), parser);
}
movAvgParsers = builder.immutableMap();
}
public @Nullable
MovAvgModelParser get(String parserName) {
MovAvgModel.AbstractModelParser get(String parserName) {
return movAvgParsers.get(parserName);
}

View File

@ -25,6 +25,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.search.aggregations.pipeline.movavg.MovAvgParser;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.Collection;
@ -63,7 +64,7 @@ public class SimpleModel extends MovAvgModel {
out.writeString(STREAM.getName());
}
public static class SimpleModelParser implements MovAvgModelParser {
public static class SimpleModelParser extends AbstractModelParser {
@Override
public String getName() {
@ -71,7 +72,7 @@ public class SimpleModel extends MovAvgModel {
}
@Override
public MovAvgModel parse(@Nullable Map<String, Object> settings) {
public MovAvgModel parse(@Nullable Map<String, Object> settings, String pipelineName, SearchContext context, int windowSize) {
return new SimpleModel();
}
}

View File

@ -36,6 +36,7 @@ public class TransportMovAvgModelModule extends AbstractModule {
registerStream(LinearModel.STREAM);
registerStream(EwmaModel.STREAM);
registerStream(HoltLinearModel.STREAM);
registerStream(HoltWintersModel.STREAM);
}
public void registerStream(MovAvgModelStreams.Stream stream) {

View File

@ -35,21 +35,12 @@ import org.elasticsearch.search.aggregations.metrics.avg.Avg;
import org.elasticsearch.search.aggregations.pipeline.BucketHelpers;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregationHelperTests;
import org.elasticsearch.search.aggregations.pipeline.SimpleValue;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.EwmaModel;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.HoltLinearModel;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.LinearModel;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.MovAvgModelBuilder;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.SimpleModel;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.*;
import org.elasticsearch.test.ElasticsearchIntegrationTest;
import org.hamcrest.Matchers;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.*;
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.search.aggregations.AggregationBuilders.avg;
@ -79,6 +70,9 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
static int windowSize;
static double alpha;
static double beta;
static double gamma;
static int period;
static HoltWintersModel.SeasonalityType seasonalityType;
static BucketHelpers.GapPolicy gapPolicy;
static ValuesSourceMetricsAggregationBuilder metric;
static List<PipelineAggregationHelperTests.MockBucket> mockHisto;
@ -87,7 +81,7 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
enum MovAvgType {
SIMPLE ("simple"), LINEAR("linear"), EWMA("ewma"), HOLT("holt");
SIMPLE ("simple"), LINEAR("linear"), EWMA("ewma"), HOLT("holt"), HOLT_WINTERS("holt_winters");
private final String name;
@ -124,9 +118,13 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
interval = 5;
numBuckets = randomIntBetween(6, 80);
windowSize = randomIntBetween(3, 10);
period = randomIntBetween(1, 5);
windowSize = randomIntBetween(period * 2, 10); // start must be 2*period to play nice with HW
alpha = randomDouble();
beta = randomDouble();
gamma = randomDouble();
seasonalityType = randomBoolean() ? HoltWintersModel.SeasonalityType.ADDITIVE : HoltWintersModel.SeasonalityType.MULTIPLICATIVE;
gapPolicy = randomBoolean() ? BucketHelpers.GapPolicy.SKIP : BucketHelpers.GapPolicy.INSERT_ZEROS;
metric = randomMetric("the_metric", VALUE_FIELD);
@ -212,6 +210,15 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
break;
case HOLT:
values.add(holt(window));
break;
case HOLT_WINTERS:
// HW needs at least 2 periods of data to start
if (window.size() >= period * 2) {
values.add(holtWinters(window));
} else {
values.add(null);
}
break;
}
@ -308,7 +315,79 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
return s + (0 * b) ;
}
/**
* Holt winters (triple exponential) moving avg
* @param window Window of values to compute movavg for
* @return
*/
private double holtWinters(Collection<Double> window) {
// Smoothed value
double s = 0;
double last_s = 0;
// Trend value
double b = 0;
double last_b = 0;
// Seasonal value
double[] seasonal = new double[window.size()];
double padding = seasonalityType.equals(HoltWintersModel.SeasonalityType.MULTIPLICATIVE) ? 0.0000000001 : 0;
int counter = 0;
double[] vs = new double[window.size()];
for (double v : window) {
vs[counter] = v + padding;
counter += 1;
}
// Initial level value is average of first season
// Calculate the slopes between first and second season for each period
for (int i = 0; i < period; i++) {
s += vs[i];
b += (vs[i] - vs[i + period]) / 2;
}
s /= (double) period;
b /= (double) period;
last_s = s;
last_b = b;
// Calculate first seasonal
if (Double.compare(s, 0.0) == 0 || Double.compare(s, -0.0) == 0) {
Arrays.fill(seasonal, 0.0);
} else {
for (int i = 0; i < period; i++) {
seasonal[i] = vs[i] / s;
}
}
for (int i = period; i < vs.length; i++) {
if (seasonalityType.equals(HoltWintersModel.SeasonalityType.MULTIPLICATIVE)) {
s = alpha * (vs[i] / seasonal[i - period]) + (1.0d - alpha) * (last_s + last_b);
} else {
s = alpha * (vs[i] - seasonal[i - period]) + (1.0d - alpha) * (last_s + last_b);
}
b = beta * (s - last_s) + (1 - beta) * last_b;
if (seasonalityType.equals(HoltWintersModel.SeasonalityType.MULTIPLICATIVE)) {
seasonal[i] = gamma * (vs[i] / (last_s + last_b )) + (1 - gamma) * seasonal[i - period];
} else {
seasonal[i] = gamma * (vs[i] - (last_s + last_b )) + (1 - gamma) * seasonal[i - period];
}
last_s = s;
last_b = b;
}
int seasonCounter = (window.size() - 1) - period;
if (seasonalityType.equals(HoltWintersModel.SeasonalityType.MULTIPLICATIVE)) {
return s + (0 * b) * seasonal[seasonCounter % window.size()];
} else {
return s + (0 * b) + seasonal[seasonCounter % window.size()];
}
}
/**
@ -522,6 +601,60 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
}
}
@Test
public void HoltWintersValuedField() {
SearchResponse response = client()
.prepareSearch("idx").setTypes("type")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, (long) (interval * (numBuckets - 1)))
.subAggregation(metric)
.subAggregation(movingAvg("movavg_counts")
.window(windowSize)
.modelBuilder(new HoltWintersModel.HoltWintersModelBuilder()
.alpha(alpha).beta(beta).gamma(gamma).period(period).seasonalityType(seasonalityType))
.gapPolicy(gapPolicy)
.setBucketsPaths("_count"))
.subAggregation(movingAvg("movavg_values")
.window(windowSize)
.modelBuilder(new HoltWintersModel.HoltWintersModelBuilder()
.alpha(alpha).beta(beta).gamma(gamma).period(period).seasonalityType(seasonalityType))
.gapPolicy(gapPolicy)
.setBucketsPaths("the_metric"))
).execute().actionGet();
assertSearchResponse(response);
InternalHistogram<Bucket> histo = response.getAggregations().get("histo");
assertThat(histo, notNullValue());
assertThat(histo.getName(), equalTo("histo"));
List<? extends Bucket> buckets = histo.getBuckets();
assertThat("Size of buckets array is not correct.", buckets.size(), equalTo(mockHisto.size()));
List<Double> expectedCounts = testValues.get(MovAvgType.HOLT_WINTERS.toString() + "_" + MetricTarget.COUNT.toString());
List<Double> expectedValues = testValues.get(MovAvgType.HOLT_WINTERS.toString() + "_" + MetricTarget.VALUE.toString());
Iterator<? extends Histogram.Bucket> actualIter = buckets.iterator();
Iterator<PipelineAggregationHelperTests.MockBucket> expectedBucketIter = mockHisto.iterator();
Iterator<Double> expectedCountsIter = expectedCounts.iterator();
Iterator<Double> expectedValuesIter = expectedValues.iterator();
while (actualIter.hasNext()) {
assertValidIterators(expectedBucketIter, expectedCountsIter, expectedValuesIter);
Histogram.Bucket actual = actualIter.next();
PipelineAggregationHelperTests.MockBucket expected = expectedBucketIter.next();
Double expectedCount = expectedCountsIter.next();
Double expectedValue = expectedValuesIter.next();
assertThat("keys do not match", ((Number) actual.getKey()).longValue(), equalTo(expected.key));
assertThat("doc counts do not match", actual.getDocCount(), equalTo((long)expected.count));
assertBucketContents(actual, expectedCount, expectedValue);
}
}
@Test
public void testPredictNegativeKeysAtStart() {
@ -572,6 +705,7 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
}
}
@Test
public void testSizeZeroWindow() {
try {
@ -1070,6 +1204,55 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
}
}
@Test
public void testHoltWintersNotEnoughData() {
try {
SearchResponse response = client()
.prepareSearch("idx").setTypes("type")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, (long) (interval * (numBuckets - 1)))
.subAggregation(metric)
.subAggregation(movingAvg("movavg_counts")
.window(10)
.modelBuilder(new HoltWintersModel.HoltWintersModelBuilder()
.alpha(alpha).beta(beta).gamma(gamma).period(20).seasonalityType(seasonalityType))
.gapPolicy(gapPolicy)
.setBucketsPaths("_count"))
.subAggregation(movingAvg("movavg_values")
.window(windowSize)
.modelBuilder(new HoltWintersModel.HoltWintersModelBuilder()
.alpha(alpha).beta(beta).gamma(gamma).period(20).seasonalityType(seasonalityType))
.gapPolicy(gapPolicy)
.setBucketsPaths("the_metric"))
).execute().actionGet();
} catch (SearchPhaseExecutionException e) {
// All good
}
}
@Test
public void testBadModelParams() {
try {
SearchResponse response = client()
.prepareSearch("idx").setTypes("type")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, (long) (interval * (numBuckets - 1)))
.subAggregation(metric)
.subAggregation(movingAvg("movavg_counts")
.window(10)
.modelBuilder(randomModelBuilder(100))
.gapPolicy(gapPolicy)
.setBucketsPaths("_count"))
).execute().actionGet();
} catch (SearchPhaseExecutionException e) {
// All good
}
}
private void assertValidIterators(Iterator expectedBucketIter, Iterator expectedCountsIter, Iterator expectedValuesIter) {
if (!expectedBucketIter.hasNext()) {
@ -1088,6 +1271,8 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
SimpleValue countMovAvg = actual.getAggregations().get("movavg_counts");
if (expectedCount == null) {
assertThat("[_count] movavg is not null", countMovAvg, nullValue());
} else if (Double.isNaN(expectedCount)) {
assertThat("[_count] movavg should be NaN, but is ["+countMovAvg.value()+"] instead", countMovAvg.value(), equalTo(Double.NaN));
} else {
assertThat("[_count] movavg is null", countMovAvg, notNullValue());
assertThat("[_count] movavg does not match expected ["+countMovAvg.value()+" vs "+expectedCount+"]",
@ -1098,6 +1283,8 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
SimpleValue valuesMovAvg = actual.getAggregations().get("movavg_values");
if (expectedValue == null) {
assertThat("[value] movavg is not null", valuesMovAvg, Matchers.nullValue());
} else if (Double.isNaN(expectedValue)) {
assertThat("[value] movavg should be NaN, but is ["+valuesMovAvg.value()+"] instead", valuesMovAvg.value(), equalTo(Double.NaN));
} else {
assertThat("[value] movavg is null", valuesMovAvg, notNullValue());
assertThat("[value] movavg does not match expected ["+valuesMovAvg.value()+" vs "+expectedValue+"]",
@ -1106,17 +1293,24 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
}
private MovAvgModelBuilder randomModelBuilder() {
return randomModelBuilder(0);
}
private MovAvgModelBuilder randomModelBuilder(double padding) {
int rand = randomIntBetween(0,3);
// HoltWinters is excluded from random generation, because it's "cold start" behavior makes
// randomized testing too tricky. Should probably add dedicated, randomized tests just for HoltWinters,
// which can compensate for the idiosyncrasies
switch (rand) {
case 0:
return new SimpleModel.SimpleModelBuilder();
case 1:
return new LinearModel.LinearModelBuilder();
case 2:
return new EwmaModel.EWMAModelBuilder().alpha(alpha);
return new EwmaModel.EWMAModelBuilder().alpha(alpha + padding);
case 3:
return new HoltLinearModel.HoltLinearModelBuilder().alpha(alpha).beta(beta);
return new HoltLinearModel.HoltLinearModelBuilder().alpha(alpha + padding).beta(beta + padding);
default:
return new SimpleModel.SimpleModelBuilder();
}

View File

@ -28,6 +28,8 @@ import static org.hamcrest.Matchers.equalTo;
import org.junit.Test;
import java.util.Arrays;
public class MovAvgUnitTests extends ElasticsearchTestCase {
@Test
@ -259,7 +261,7 @@ public class MovAvgUnitTests extends ElasticsearchTestCase {
MovAvgModel model = new HoltLinearModel(alpha, beta);
int windowSize = randomIntBetween(1, 50);
int numPredictions = randomIntBetween(1,50);
int numPredictions = randomIntBetween(1, 50);
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
for (int i = 0; i < windowSize; i++) {
@ -297,4 +299,288 @@ public class MovAvgUnitTests extends ElasticsearchTestCase {
assertThat(Double.compare(expected[i], actual[i]), equalTo(0));
}
}
@Test
public void testHoltWintersMultiplicativePadModel() {
double alpha = randomDouble();
double beta = randomDouble();
double gamma = randomDouble();
int period = randomIntBetween(1,10);
MovAvgModel model = new HoltWintersModel(alpha, beta, gamma, period, HoltWintersModel.SeasonalityType.MULTIPLICATIVE, true);
int windowSize = randomIntBetween(period * 2, 50); // HW requires at least two periods of data
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
for (int i = 0; i < windowSize; i++) {
window.offer(randomDouble());
}
// Smoothed value
double s = 0;
double last_s = 0;
// Trend value
double b = 0;
double last_b = 0;
// Seasonal value
double[] seasonal = new double[windowSize];
int counter = 0;
double[] vs = new double[windowSize];
for (double v : window) {
vs[counter] = v + 0.0000000001;
counter += 1;
}
// Initial level value is average of first season
// Calculate the slopes between first and second season for each period
for (int i = 0; i < period; i++) {
s += vs[i];
b += (vs[i] - vs[i + period]) / 2;
}
s /= (double) period;
b /= (double) period;
last_s = s;
last_b = b;
// Calculate first seasonal
if (Double.compare(s, 0.0) == 0 || Double.compare(s, -0.0) == 0) {
Arrays.fill(seasonal, 0.0);
} else {
for (int i = 0; i < period; i++) {
seasonal[i] = vs[i] / s;
}
}
for (int i = period; i < vs.length; i++) {
s = alpha * (vs[i] / seasonal[i - period]) + (1.0d - alpha) * (last_s + last_b);
b = beta * (s - last_s) + (1 - beta) * last_b;
//seasonal[i] = gamma * (vs[i] / s) + ((1 - gamma) * seasonal[i - period]);
seasonal[i] = gamma * (vs[i] / (last_s + last_b )) + (1 - gamma) * seasonal[i - period];
last_s = s;
last_b = b;
}
int seasonCounter = (windowSize - 1) - period;
double expected = s + (0 * b) * seasonal[seasonCounter % windowSize];;
double actual = model.next(window);
assertThat(Double.compare(expected, actual), equalTo(0));
}
@Test
public void testHoltWintersMultiplicativePadPredictionModel() {
double alpha = randomDouble();
double beta = randomDouble();
double gamma = randomDouble();
int period = randomIntBetween(1,10);
MovAvgModel model = new HoltWintersModel(alpha, beta, gamma, period, HoltWintersModel.SeasonalityType.MULTIPLICATIVE, true);
int windowSize = randomIntBetween(period * 2, 50); // HW requires at least two periods of data
int numPredictions = randomIntBetween(1, 50);
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
for (int i = 0; i < windowSize; i++) {
window.offer(randomDouble());
}
double actual[] = model.predict(window, numPredictions);
double expected[] = new double[numPredictions];
// Smoothed value
double s = 0;
double last_s = 0;
// Trend value
double b = 0;
double last_b = 0;
// Seasonal value
double[] seasonal = new double[windowSize];
int counter = 0;
double[] vs = new double[windowSize];
for (double v : window) {
vs[counter] = v + 0.0000000001;
counter += 1;
}
// Initial level value is average of first season
// Calculate the slopes between first and second season for each period
for (int i = 0; i < period; i++) {
s += vs[i];
b += (vs[i] - vs[i + period]) / 2;
}
s /= (double) period;
b /= (double) period;
last_s = s;
last_b = b;
for (int i = 0; i < period; i++) {
// Calculate first seasonal
seasonal[i] = vs[i] / s;
}
for (int i = period; i < vs.length; i++) {
s = alpha * (vs[i] / seasonal[i - period]) + (1.0d - alpha) * (last_s + last_b);
b = beta * (s - last_s) + (1 - beta) * last_b;
//seasonal[i] = gamma * (vs[i] / s) + ((1 - gamma) * seasonal[i - period]);
seasonal[i] = gamma * (vs[i] / (last_s + last_b )) + (1 - gamma) * seasonal[i - period];
last_s = s;
last_b = b;
}
int seasonCounter = (windowSize - 1) - period;
for (int i = 0; i < numPredictions; i++) {
expected[i] = s + (i * b) * seasonal[seasonCounter % windowSize];
assertThat(Double.compare(expected[i], actual[i]), equalTo(0));
seasonCounter += 1;
}
}
@Test
public void testHoltWintersAdditiveModel() {
double alpha = randomDouble();
double beta = randomDouble();
double gamma = randomDouble();
int period = randomIntBetween(1,10);
MovAvgModel model = new HoltWintersModel(alpha, beta, gamma, period, HoltWintersModel.SeasonalityType.ADDITIVE, false);
int windowSize = randomIntBetween(period * 2, 50); // HW requires at least two periods of data
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
for (int i = 0; i < windowSize; i++) {
window.offer(randomDouble());
}
// Smoothed value
double s = 0;
double last_s = 0;
// Trend value
double b = 0;
double last_b = 0;
// Seasonal value
double[] seasonal = new double[windowSize];
int counter = 0;
double[] vs = new double[windowSize];
for (double v : window) {
vs[counter] = v;
counter += 1;
}
// Initial level value is average of first season
// Calculate the slopes between first and second season for each period
for (int i = 0; i < period; i++) {
s += vs[i];
b += (vs[i] - vs[i + period]) / 2;
}
s /= (double) period;
b /= (double) period;
last_s = s;
last_b = b;
for (int i = 0; i < period; i++) {
// Calculate first seasonal
seasonal[i] = vs[i] / s;
}
for (int i = period; i < vs.length; i++) {
s = alpha * (vs[i] - seasonal[i - period]) + (1.0d - alpha) * (last_s + last_b);
b = beta * (s - last_s) + (1 - beta) * last_b;
//seasonal[i] = gamma * (vs[i] / s) + ((1 - gamma) * seasonal[i - period]);
seasonal[i] = gamma * (vs[i] - (last_s + last_b )) + (1 - gamma) * seasonal[i - period];
last_s = s;
last_b = b;
}
int seasonCounter = (windowSize - 1) - period;
double expected = s + (0 * b) + seasonal[seasonCounter % windowSize];;
double actual = model.next(window);
assertThat(Double.compare(expected, actual), equalTo(0));
}
@Test
public void testHoltWintersAdditivePredictionModel() {
double alpha = randomDouble();
double beta = randomDouble();
double gamma = randomDouble();
int period = randomIntBetween(1,10);
MovAvgModel model = new HoltWintersModel(alpha, beta, gamma, period, HoltWintersModel.SeasonalityType.ADDITIVE, false);
int windowSize = randomIntBetween(period * 2, 50); // HW requires at least two periods of data
int numPredictions = randomIntBetween(1, 50);
EvictingQueue<Double> window = EvictingQueue.create(windowSize);
for (int i = 0; i < windowSize; i++) {
window.offer(randomDouble());
}
double actual[] = model.predict(window, numPredictions);
double expected[] = new double[numPredictions];
// Smoothed value
double s = 0;
double last_s = 0;
// Trend value
double b = 0;
double last_b = 0;
// Seasonal value
double[] seasonal = new double[windowSize];
int counter = 0;
double[] vs = new double[windowSize];
for (double v : window) {
vs[counter] = v;
counter += 1;
}
// Initial level value is average of first season
// Calculate the slopes between first and second season for each period
for (int i = 0; i < period; i++) {
s += vs[i];
b += (vs[i] - vs[i + period]) / 2;
}
s /= (double) period;
b /= (double) period;
last_s = s;
last_b = b;
for (int i = 0; i < period; i++) {
// Calculate first seasonal
seasonal[i] = vs[i] / s;
}
for (int i = period; i < vs.length; i++) {
s = alpha * (vs[i] - seasonal[i - period]) + (1.0d - alpha) * (last_s + last_b);
b = beta * (s - last_s) + (1 - beta) * last_b;
//seasonal[i] = gamma * (vs[i] / s) + ((1 - gamma) * seasonal[i - period]);
seasonal[i] = gamma * (vs[i] - (last_s + last_b )) + (1 - gamma) * seasonal[i - period];
last_s = s;
last_b = b;
}
int seasonCounter = (windowSize - 1) - period;
for (int i = 0; i < numPredictions; i++) {
expected[i] = s + (i * b) + seasonal[seasonCounter % windowSize];
assertThat(Double.compare(expected[i], actual[i]), equalTo(0));
seasonCounter += 1;
}
}
}