diff --git a/core/src/main/java/org/elasticsearch/search/SearchService.java b/core/src/main/java/org/elasticsearch/search/SearchService.java index 6bfd3f08a33..351fc3e3c7e 100644 --- a/core/src/main/java/org/elasticsearch/search/SearchService.java +++ b/core/src/main/java/org/elasticsearch/search/SearchService.java @@ -96,7 +96,6 @@ import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext.Lifetime; import org.elasticsearch.search.internal.ShardSearchLocalRequest; import org.elasticsearch.search.internal.ShardSearchRequest; -import org.elasticsearch.search.profile.Profilers; import org.elasticsearch.search.query.QueryPhase; import org.elasticsearch.search.query.QuerySearchRequest; import org.elasticsearch.search.query.QuerySearchResult; @@ -755,7 +754,7 @@ public class SearchService extends AbstractLifecycleComponent imp // ignore } XContentLocation location = completeAggregationsParser != null ? completeAggregationsParser.getTokenLocation() : null; - throw new SearchParseException(context, "failed to parse rescore source [" + sSource + "]", location, e); + throw new SearchParseException(context, "failed to parse aggregation source [" + sSource + "]", location, e); } } if (source.suggest() != null) { diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/MovAvgParser.java b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/MovAvgParser.java index aeb739bd27c..70ee30d02d8 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/MovAvgParser.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/MovAvgParser.java @@ -28,8 +28,6 @@ import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorFactory; import org.elasticsearch.search.aggregations.pipeline.movavg.models.MovAvgModel; import org.elasticsearch.search.aggregations.pipeline.movavg.models.MovAvgModelParserMapper; -import org.elasticsearch.search.aggregations.support.format.ValueFormat; -import org.elasticsearch.search.aggregations.support.format.ValueFormatter; import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; @@ -65,11 +63,11 @@ public class MovAvgParser implements PipelineAggregator.Parser { String[] bucketsPaths = null; String format = null; - GapPolicy gapPolicy = GapPolicy.SKIP; - int window = 5; + GapPolicy gapPolicy = null; + Integer window = null; Map settings = null; - String model = "simple"; - int predict = 0; + String model = null; + Integer predict = null; Boolean minimize = null; while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { @@ -86,8 +84,8 @@ public class MovAvgParser implements PipelineAggregator.Parser { } else if (context.parseFieldMatcher().match(currentFieldName, PREDICT)) { predict = parser.intValue(); if (predict <= 0) { - throw new SearchParseException(context, "[" + currentFieldName + "] value must be a positive, " - + "non-zero integer. Value supplied was [" + predict + "] in [" + pipelineAggregatorName + "].", + throw new SearchParseException(context, "[" + currentFieldName + "] value must be a positive integer." + + " Value supplied was [" + predict + "] in [" + pipelineAggregatorName + "].", parser.getTokenLocation()); } } else { @@ -144,43 +142,44 @@ public class MovAvgParser implements PipelineAggregator.Parser { + "] for movingAvg aggregation [" + pipelineAggregatorName + "]", parser.getTokenLocation()); } - ValueFormatter formatter = null; + MovAvgPipelineAggregator.Factory factory = new MovAvgPipelineAggregator.Factory(pipelineAggregatorName, bucketsPaths); if (format != null) { - formatter = ValueFormat.Patternable.Number.format(format).formatter(); - } else { - formatter = ValueFormatter.RAW; + factory.format(format); } - - MovAvgModel.AbstractModelParser modelParser = movAvgModelParserMapper.get(model); - if (modelParser == null) { - throw new SearchParseException(context, "Unknown model [" + model + "] specified. Valid options are:" - + movAvgModelParserMapper.getAllNames().toString(), parser.getTokenLocation()); + if (gapPolicy != null) { + factory.gapPolicy(gapPolicy); } - - MovAvgModel movAvgModel; - try { - movAvgModel = modelParser.parse(settings, pipelineAggregatorName, window, context.parseFieldMatcher()); - } catch (ParseException exception) { - throw new SearchParseException(context, "Could not parse settings for model [" + model + "].", null, exception); + if (window != null) { + factory.window(window); } - - // If the user doesn't set a preference for cost minimization, ask what the model prefers - if (minimize == null) { - minimize = movAvgModel.minimizeByDefault(); - } else if (minimize && !movAvgModel.canBeMinimized()) { - // If the user asks to minimize, but this model doesn't support it, throw exception - throw new SearchParseException(context, "The [" + model + "] model cannot be minimized.", null); + if (predict != null) { + factory.predict(predict); } + if (model != null) { + MovAvgModel.AbstractModelParser modelParser = movAvgModelParserMapper.get(model); + if (modelParser == null) { + throw new SearchParseException(context, + "Unknown model [" + model + "] specified. Valid options are:" + movAvgModelParserMapper.getAllNames().toString(), + parser.getTokenLocation()); + } - - return new MovAvgPipelineAggregator.Factory(pipelineAggregatorName, bucketsPaths, formatter, gapPolicy, window, predict, - movAvgModel, minimize); + MovAvgModel movAvgModel; + try { + movAvgModel = modelParser.parse(settings, pipelineAggregatorName, window, context.parseFieldMatcher()); + } catch (ParseException exception) { + throw new SearchParseException(context, "Could not parse settings for model [" + model + "].", null, exception); + } + factory.model(movAvgModel); + } + if (minimize != null) { + factory.minimize(minimize); + } + return factory; } - // NORELEASE implement this method when refactoring this aggregation @Override public PipelineAggregatorFactory getFactoryPrototype() { - return null; + return new MovAvgPipelineAggregator.Factory(null, null); } } diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/MovAvgPipelineAggregator.java b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/MovAvgPipelineAggregator.java index 4f7034b633f..1faea425b20 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/MovAvgPipelineAggregator.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/MovAvgPipelineAggregator.java @@ -22,6 +22,7 @@ package org.elasticsearch.search.aggregations.pipeline.movavg; import org.elasticsearch.common.collect.EvictingQueue; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.search.aggregations.AggregationExecutionException; import org.elasticsearch.search.aggregations.AggregatorFactory; import org.elasticsearch.search.aggregations.InternalAggregation; @@ -37,6 +38,8 @@ import org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorFactory; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorStreams; import org.elasticsearch.search.aggregations.pipeline.movavg.models.MovAvgModel; import org.elasticsearch.search.aggregations.pipeline.movavg.models.MovAvgModelStreams; +import org.elasticsearch.search.aggregations.pipeline.movavg.models.SimpleModel; +import org.elasticsearch.search.aggregations.support.format.ValueFormat; import org.elasticsearch.search.aggregations.support.format.ValueFormatter; import org.elasticsearch.search.aggregations.support.format.ValueFormatterStreams; import org.joda.time.DateTime; @@ -46,6 +49,7 @@ import java.util.ArrayList; import java.util.List; import java.util.ListIterator; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; import java.util.stream.StreamSupport; @@ -276,32 +280,151 @@ public class MovAvgPipelineAggregator extends PipelineAggregator { public static class Factory extends PipelineAggregatorFactory { - private final ValueFormatter formatter; - private GapPolicy gapPolicy; - private int window; - private MovAvgModel model; - private int predict; - private boolean minimize; + private String format; + private GapPolicy gapPolicy = GapPolicy.SKIP; + private int window = 5; + private MovAvgModel model = new SimpleModel(); + private int predict = 0; + private Boolean minimize; - public Factory(String name, String[] bucketsPaths, ValueFormatter formatter, GapPolicy gapPolicy, - int window, int predict, MovAvgModel model, boolean minimize) { + public Factory(String name, String[] bucketsPaths) { super(name, TYPE.name(), bucketsPaths); - this.formatter = formatter; + } + + /** + * Sets the format to use on the output of this aggregation. + */ + public void format(String format) { + this.format = format; + } + + /** + * Gets the format to use on the output of this aggregation. + */ + public String format() { + return format; + } + + /** + * Sets the GapPolicy to use on the output of this aggregation. + */ + public void gapPolicy(GapPolicy gapPolicy) { this.gapPolicy = gapPolicy; + } + + /** + * Gets the GapPolicy to use on the output of this aggregation. + */ + public GapPolicy gapPolicy() { + return gapPolicy; + } + + protected ValueFormatter formatter() { + if (format != null) { + return ValueFormat.Patternable.Number.format(format).formatter(); + } else { + return ValueFormatter.RAW; + } + } + + /** + * Sets the window size for the moving average. This window will "slide" + * across the series, and the values inside that window will be used to + * calculate the moving avg value + * + * @param window + * Size of window + */ + public void window(int window) { this.window = window; + } + + /** + * Gets the window size for the moving average. This window will "slide" + * across the series, and the values inside that window will be used to + * calculate the moving avg value + */ + public int window() { + return window; + } + + /** + * Sets a MovAvgModel for the Moving Average. The model is used to + * define what type of moving average you want to use on the series + * + * @param model + * A MovAvgModel which has been prepopulated with settings + */ + public void model(MovAvgModel model) { this.model = model; + } + + /** + * Gets a MovAvgModel for the Moving Average. The model is used to + * define what type of moving average you want to use on the series + */ + public MovAvgModel model() { + return model; + } + + /** + * Sets the number of predictions that should be returned. Each + * prediction will be spaced at the intervals specified in the + * histogram. E.g "predict: 2" will return two new buckets at the end of + * the histogram with the predicted values. + * + * @param predict + * Number of predictions to make + */ + public void predict(int predict) { this.predict = predict; + } + + /** + * Gets the number of predictions that should be returned. Each + * prediction will be spaced at the intervals specified in the + * histogram. E.g "predict: 2" will return two new buckets at the end of + * the histogram with the predicted values. + */ + public int predict() { + return predict; + } + + /** + * Sets whether the model should be fit to the data using a cost + * minimizing algorithm. + * + * @param minimize + * If the model should be fit to the underlying data + */ + public void minimize(boolean minimize) { this.minimize = minimize; } + /** + * Gets whether the model should be fit to the data using a cost + * minimizing algorithm. + */ + public Boolean minimize() { + return minimize; + } + @Override protected PipelineAggregator createInternal(Map metaData) throws IOException { - return new MovAvgPipelineAggregator(name, bucketsPaths, formatter, gapPolicy, window, predict, model, minimize, metaData); + // If the user doesn't set a preference for cost minimization, ask + // what the model prefers + boolean minimize = this.minimize == null ? model.minimizeByDefault() : this.minimize; + return new MovAvgPipelineAggregator(name, bucketsPaths, formatter(), gapPolicy, window, predict, model, minimize, metaData); } @Override public void doValidate(AggregatorFactory parent, AggregatorFactory[] aggFactories, List pipelineAggregatoractories) { + if (minimize != null && minimize && !model.canBeMinimized()) { + // If the user asks to minimize, but this model doesn't support + // it, throw exception + throw new IllegalStateException("The [" + model + "] model cannot be minimized for aggregation [" + name + "]"); + } if (bucketsPaths.length != 1) { throw new IllegalStateException(PipelineAggregator.Parser.BUCKETS_PATH.getPreferredName() + " must contain a single entry for aggregation [" + name + "]"); @@ -318,5 +441,60 @@ public class MovAvgPipelineAggregator extends PipelineAggregator { } } + @Override + protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException { + if (format != null) { + builder.field(MovAvgParser.FORMAT.getPreferredName(), format); + } + builder.field(MovAvgParser.GAP_POLICY.getPreferredName(), gapPolicy.getName()); + model.toXContent(builder, params); + builder.field(MovAvgParser.WINDOW.getPreferredName(), window); + if (predict > 0) { + builder.field(MovAvgParser.PREDICT.getPreferredName(), predict); + } + if (minimize != null) { + builder.field(MovAvgParser.MINIMIZE.getPreferredName(), minimize); + } + return builder; + } + + @Override + protected PipelineAggregatorFactory doReadFrom(String name, String[] bucketsPaths, StreamInput in) throws IOException { + Factory factory = new Factory(name, bucketsPaths); + factory.format = in.readOptionalString(); + factory.gapPolicy = GapPolicy.readFrom(in); + factory.window = in.readVInt(); + factory.model = MovAvgModelStreams.read(in); + factory.predict = in.readVInt(); + factory.minimize = in.readOptionalBoolean(); + return factory; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeOptionalString(format); + gapPolicy.writeTo(out); + out.writeVInt(window); + model.writeTo(out); + out.writeVInt(predict); + out.writeOptionalBoolean(minimize); + } + + @Override + protected int doHashCode() { + return Objects.hash(format, gapPolicy, window, model, predict, minimize); + } + + @Override + protected boolean doEquals(Object obj) { + Factory other = (Factory) obj; + return Objects.equals(format, other.format) + && Objects.equals(gapPolicy, other.gapPolicy) + && Objects.equals(window, other.window) + && Objects.equals(model, other.model) + && Objects.equals(predict, other.predict) + && Objects.equals(minimize, other.minimize); + } + } } diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/EwmaModel.java b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/EwmaModel.java index 84de794ceed..c424de86aa1 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/EwmaModel.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/EwmaModel.java @@ -32,13 +32,16 @@ import java.text.ParseException; import java.util.Arrays; import java.util.Collection; import java.util.Map; +import java.util.Objects; /** * Calculate a exponentially weighted moving average */ public class EwmaModel extends MovAvgModel { + private static final EwmaModel PROTOTYPE = new EwmaModel(); protected static final ParseField NAME_FIELD = new ParseField("ewma"); + public static final double DEFAULT_ALPHA = 0.3; /** * Controls smoothing of data. Also known as "level" value. @@ -48,6 +51,10 @@ public class EwmaModel extends MovAvgModel { */ private final double alpha; + public EwmaModel() { + this(DEFAULT_ALPHA); + } + public EwmaModel(double alpha) { this.alpha = alpha; } @@ -97,7 +104,7 @@ public class EwmaModel extends MovAvgModel { public static final MovAvgModelStreams.Stream STREAM = new MovAvgModelStreams.Stream() { @Override public MovAvgModel readResult(StreamInput in) throws IOException { - return new EwmaModel(in.readDouble()); + return PROTOTYPE.readFrom(in); } @Override @@ -106,12 +113,43 @@ public class EwmaModel extends MovAvgModel { } }; + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(MovAvgParser.MODEL.getPreferredName(), NAME_FIELD.getPreferredName()); + builder.startObject(MovAvgParser.SETTINGS.getPreferredName()); + builder.field("alpha", alpha); + builder.endObject(); + return builder; + } + + @Override + public MovAvgModel readFrom(StreamInput in) throws IOException { + return new EwmaModel(in.readDouble()); + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(STREAM.getName()); out.writeDouble(alpha); } + @Override + public int hashCode() { + return Objects.hash(alpha); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + EwmaModel other = (EwmaModel) obj; + return Objects.equals(alpha, other.alpha); + } + public static class SingleExpModelParser extends AbstractModelParser { @Override @@ -123,7 +161,7 @@ public class EwmaModel extends MovAvgModel { public MovAvgModel parse(@Nullable Map settings, String pipelineName, int windowSize, ParseFieldMatcher parseFieldMatcher) throws ParseException { - double alpha = parseDoubleParam(settings, "alpha", 0.3); + double alpha = parseDoubleParam(settings, "alpha", DEFAULT_ALPHA); checkUnrecognizedParams(settings); return new EwmaModel(alpha); } diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/HoltLinearModel.java b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/HoltLinearModel.java index fe0321bf0fc..8734b71ec4e 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/HoltLinearModel.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/HoltLinearModel.java @@ -31,13 +31,17 @@ import java.io.IOException; import java.text.ParseException; import java.util.Collection; import java.util.Map; +import java.util.Objects; /** * Calculate a doubly exponential weighted moving average */ public class HoltLinearModel extends MovAvgModel { + private static final HoltLinearModel PROTOTYPE = new HoltLinearModel(); protected static final ParseField NAME_FIELD = new ParseField("holt"); + public static final double DEFAULT_ALPHA = 0.3; + public static final double DEFAULT_BETA = 0.1; /** * Controls smoothing of data. Also known as "level" value. @@ -55,6 +59,10 @@ public class HoltLinearModel extends MovAvgModel { */ private final double beta; + public HoltLinearModel() { + this(DEFAULT_ALPHA, DEFAULT_BETA); + } + public HoltLinearModel(double alpha, double beta) { this.alpha = alpha; this.beta = beta; @@ -157,7 +165,7 @@ public class HoltLinearModel extends MovAvgModel { public static final MovAvgModelStreams.Stream STREAM = new MovAvgModelStreams.Stream() { @Override public MovAvgModel readResult(StreamInput in) throws IOException { - return new HoltLinearModel(in.readDouble(), in.readDouble()); + return PROTOTYPE.readFrom(in); } @Override @@ -166,6 +174,21 @@ public class HoltLinearModel extends MovAvgModel { } }; + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(MovAvgParser.MODEL.getPreferredName(), NAME_FIELD.getPreferredName()); + builder.startObject(MovAvgParser.SETTINGS.getPreferredName()); + builder.field("alpha", alpha); + builder.field("beta", beta); + builder.endObject(); + return builder; + } + + @Override + public MovAvgModel readFrom(StreamInput in) throws IOException { + return new HoltLinearModel(in.readDouble(), in.readDouble()); + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(STREAM.getName()); @@ -173,6 +196,24 @@ public class HoltLinearModel extends MovAvgModel { out.writeDouble(beta); } + @Override + public int hashCode() { + return Objects.hash(alpha, beta); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + HoltLinearModel other = (HoltLinearModel) obj; + return Objects.equals(alpha, other.alpha) + && Objects.equals(beta, other.beta); + } + public static class DoubleExpModelParser extends AbstractModelParser { @Override @@ -184,8 +225,8 @@ public class HoltLinearModel extends MovAvgModel { public MovAvgModel parse(@Nullable Map settings, String pipelineName, int windowSize, ParseFieldMatcher parseFieldMatcher) throws ParseException { - double alpha = parseDoubleParam(settings, "alpha", 0.3); - double beta = parseDoubleParam(settings, "beta", 0.1); + double alpha = parseDoubleParam(settings, "alpha", DEFAULT_ALPHA); + double beta = parseDoubleParam(settings, "beta", DEFAULT_BETA); checkUnrecognizedParams(settings); return new HoltLinearModel(alpha, beta); } diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/HoltWintersModel.java b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/HoltWintersModel.java index 55cf6be073c..9f5ecad4b1b 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/HoltWintersModel.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/HoltWintersModel.java @@ -37,6 +37,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Objects; /** * Calculate a triple exponential weighted moving average @@ -44,6 +45,13 @@ import java.util.Map; public class HoltWintersModel extends MovAvgModel { protected static final ParseField NAME_FIELD = new ParseField("holt_winters"); + public static final double DEFAULT_ALPHA = 0.3; + public static final double DEFAULT_BETA = 0.1; + public static final double DEFAULT_GAMMA = 0.3; + public static final int DEFAULT_PERIOD = 1; + public static final SeasonalityType DEFAULT_SEASONALITY_TYPE = SeasonalityType.ADDITIVE; + public static final boolean DEFAULT_PAD = false; + private static final HoltWintersModel PROTOTYPE = new HoltWintersModel(); /** * Controls smoothing of data. Also known as "level" value. @@ -159,6 +167,9 @@ public class HoltWintersModel extends MovAvgModel { } } + public HoltWintersModel() { + this(DEFAULT_ALPHA, DEFAULT_BETA, DEFAULT_GAMMA, DEFAULT_PERIOD, DEFAULT_SEASONALITY_TYPE, DEFAULT_PAD); + } public HoltWintersModel(double alpha, double beta, double gamma, int period, SeasonalityType seasonalityType, boolean pad) { this.alpha = alpha; @@ -273,8 +284,8 @@ public class HoltWintersModel extends MovAvgModel { s += vs[i]; b += (vs[i + period] - vs[i]) / period; } - s /= (double) period; - b /= (double) period; + s /= period; + b /= period; last_s = s; // Calculate first seasonal @@ -324,14 +335,7 @@ public class HoltWintersModel extends MovAvgModel { public static final MovAvgModelStreams.Stream STREAM = new MovAvgModelStreams.Stream() { @Override public MovAvgModel readResult(StreamInput in) throws IOException { - double alpha = in.readDouble(); - double beta = in.readDouble(); - double gamma = in.readDouble(); - int period = in.readVInt(); - SeasonalityType type = SeasonalityType.readFrom(in); - boolean pad = in.readBoolean(); - - return new HoltWintersModel(alpha, beta, gamma, period, type, pad); + return PROTOTYPE.readFrom(in); } @Override @@ -340,6 +344,26 @@ public class HoltWintersModel extends MovAvgModel { } }; + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(MovAvgParser.MODEL.getPreferredName(), NAME_FIELD.getPreferredName()); + builder.startObject(MovAvgParser.SETTINGS.getPreferredName()); + builder.field("alpha", alpha); + builder.field("beta", beta); + builder.field("gamma", gamma); + builder.field("period", period); + builder.field("pad", pad); + builder.field("type", seasonalityType.getName()); + builder.endObject(); + return builder; + } + + @Override + public MovAvgModel readFrom(StreamInput in) throws IOException { + return new HoltWintersModel(in.readDouble(), in.readDouble(), in.readDouble(), in.readVInt(), SeasonalityType.readFrom(in), + in.readBoolean()); + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(STREAM.getName()); @@ -351,6 +375,28 @@ public class HoltWintersModel extends MovAvgModel { out.writeBoolean(pad); } + @Override + public int hashCode() { + return Objects.hash(alpha, beta, gamma, period, seasonalityType, pad); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + HoltWintersModel other = (HoltWintersModel) obj; + return Objects.equals(alpha, other.alpha) + && Objects.equals(beta, other.beta) + && Objects.equals(gamma, other.gamma) + && Objects.equals(period, other.period) + && Objects.equals(seasonalityType, other.seasonalityType) + && Objects.equals(pad, other.pad); + } + public static class HoltWintersModelParser extends AbstractModelParser { @Override @@ -362,10 +408,10 @@ public class HoltWintersModel extends MovAvgModel { public MovAvgModel parse(@Nullable Map settings, String pipelineName, int windowSize, ParseFieldMatcher parseFieldMatcher) throws ParseException { - double alpha = parseDoubleParam(settings, "alpha", 0.3); - double beta = parseDoubleParam(settings, "beta", 0.1); - double gamma = parseDoubleParam(settings, "gamma", 0.3); - int period = parseIntegerParam(settings, "period", 1); + double alpha = parseDoubleParam(settings, "alpha", DEFAULT_ALPHA); + double beta = parseDoubleParam(settings, "beta", DEFAULT_BETA); + double gamma = parseDoubleParam(settings, "gamma", DEFAULT_GAMMA); + int period = parseIntegerParam(settings, "period", DEFAULT_PERIOD); if (windowSize < 2 * period) { throw new ParseException("Field [window] must be at least twice as large as the period when " + @@ -373,7 +419,7 @@ public class HoltWintersModel extends MovAvgModel { + (2 * period), 0); } - SeasonalityType seasonalityType = SeasonalityType.ADDITIVE; + SeasonalityType seasonalityType = DEFAULT_SEASONALITY_TYPE; if (settings != null) { Object value = settings.get("type"); diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/LinearModel.java b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/LinearModel.java index 264a42509b7..a5dfddf3e90 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/LinearModel.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/LinearModel.java @@ -40,6 +40,7 @@ import java.util.Map; */ public class LinearModel extends MovAvgModel { + private static final LinearModel PROTOTYPE = new LinearModel(); protected static final ParseField NAME_FIELD = new ParseField("linear"); @@ -85,7 +86,7 @@ public class LinearModel extends MovAvgModel { public static final MovAvgModelStreams.Stream STREAM = new MovAvgModelStreams.Stream() { @Override public MovAvgModel readResult(StreamInput in) throws IOException { - return new LinearModel(); + return PROTOTYPE.readFrom(in); } @Override @@ -94,6 +95,17 @@ public class LinearModel extends MovAvgModel { } }; + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(MovAvgParser.MODEL.getPreferredName(), NAME_FIELD.getPreferredName()); + return builder; + } + + @Override + public MovAvgModel readFrom(StreamInput in) throws IOException { + return new LinearModel(); + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(STREAM.getName()); @@ -121,4 +133,20 @@ public class LinearModel extends MovAvgModel { return builder; } } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + return true; + } } diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/MovAvgModel.java b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/MovAvgModel.java index 4bfac9d44cb..92f46156618 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/MovAvgModel.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/MovAvgModel.java @@ -22,6 +22,8 @@ package org.elasticsearch.search.aggregations.pipeline.movavg.models; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseFieldMatcher; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContent; import java.io.IOException; import java.text.ParseException; @@ -29,7 +31,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Map; -public abstract class MovAvgModel { +public abstract class MovAvgModel implements Writeable, ToXContent { /** * Should this model be fit to the data via a cost minimizing algorithm by default? @@ -116,13 +118,21 @@ public abstract class MovAvgModel { * * @param out Output stream */ + @Override public abstract void writeTo(StreamOutput out) throws IOException; /** * Clone the model, returning an exact copy */ + @Override public abstract MovAvgModel clone(); + @Override + public abstract int hashCode(); + + @Override + public abstract boolean equals(Object obj); + /** * Abstract class which also provides some concrete parsing functionality. */ diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/SimpleModel.java b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/SimpleModel.java index e0c7781ec4a..619654e44f1 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/SimpleModel.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/models/SimpleModel.java @@ -38,6 +38,7 @@ import java.util.Map; */ public class SimpleModel extends MovAvgModel { + private static final SimpleModel PROTOTYPE = new SimpleModel(); protected static final ParseField NAME_FIELD = new ParseField("simple"); @@ -78,7 +79,7 @@ public class SimpleModel extends MovAvgModel { public static final MovAvgModelStreams.Stream STREAM = new MovAvgModelStreams.Stream() { @Override public MovAvgModel readResult(StreamInput in) throws IOException { - return new SimpleModel(); + return PROTOTYPE.readFrom(in); } @Override @@ -87,6 +88,17 @@ public class SimpleModel extends MovAvgModel { } }; + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(MovAvgParser.MODEL.getPreferredName(), NAME_FIELD.getPreferredName()); + return builder; + } + + @Override + public MovAvgModel readFrom(StreamInput in) throws IOException { + return new SimpleModel(); + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(STREAM.getName()); @@ -114,4 +126,20 @@ public class SimpleModel extends MovAvgModel { return builder; } } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + return true; + } } diff --git a/core/src/test/java/org/elasticsearch/search/aggregations/pipeline/moving/avg/MovAvgTests.java b/core/src/test/java/org/elasticsearch/search/aggregations/pipeline/moving/avg/MovAvgTests.java new file mode 100644 index 00000000000..6767a305577 --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/aggregations/pipeline/moving/avg/MovAvgTests.java @@ -0,0 +1,96 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.pipeline.moving.avg; + +import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase; +import org.elasticsearch.search.aggregations.pipeline.BucketHelpers.GapPolicy; +import org.elasticsearch.search.aggregations.pipeline.movavg.MovAvgPipelineAggregator; +import org.elasticsearch.search.aggregations.pipeline.movavg.MovAvgPipelineAggregator.Factory; +import org.elasticsearch.search.aggregations.pipeline.movavg.models.EwmaModel; +import org.elasticsearch.search.aggregations.pipeline.movavg.models.HoltLinearModel; +import org.elasticsearch.search.aggregations.pipeline.movavg.models.HoltWintersModel; +import org.elasticsearch.search.aggregations.pipeline.movavg.models.HoltWintersModel.SeasonalityType; +import org.elasticsearch.search.aggregations.pipeline.movavg.models.LinearModel; +import org.elasticsearch.search.aggregations.pipeline.movavg.models.SimpleModel;; + +public class MovAvgTests extends BasePipelineAggregationTestCase { + + @Override + protected Factory createTestAggregatorFactory() { + String name = randomAsciiOfLengthBetween(3, 20); + String[] bucketsPaths = new String[1]; + bucketsPaths[0] = randomAsciiOfLengthBetween(3, 20); + Factory factory = new Factory(name, bucketsPaths); + if (randomBoolean()) { + factory.format(randomAsciiOfLengthBetween(1, 10)); + } + if (randomBoolean()) { + factory.gapPolicy(randomFrom(GapPolicy.values())); + } + if (randomBoolean()) { + switch (randomInt(4)) { + case 0: + factory.model(new SimpleModel()); + factory.window(randomIntBetween(1, 100)); + break; + case 1: + factory.model(new LinearModel()); + factory.window(randomIntBetween(1, 100)); + break; + case 2: + if (randomBoolean()) { + factory.model(new EwmaModel()); + factory.window(randomIntBetween(1, 100)); + } else { + factory.model(new EwmaModel(randomDouble())); + factory.window(randomIntBetween(1, 100)); + } + break; + case 3: + if (randomBoolean()) { + factory.model(new HoltLinearModel()); + factory.window(randomIntBetween(1, 100)); + } else { + factory.model(new HoltLinearModel(randomDouble(), randomDouble())); + factory.window(randomIntBetween(1, 100)); + } + break; + case 4: + default: + if (randomBoolean()) { + factory.model(new HoltWintersModel()); + factory.window(randomIntBetween(2, 100)); + } else { + int period = randomIntBetween(1, 100); + factory.model(new HoltWintersModel(randomDouble(), randomDouble(), randomDouble(), period, + randomFrom(SeasonalityType.values()), randomBoolean())); + factory.window(randomIntBetween(2 * period, 200 * period)); + } + break; + } + } + factory.predict(randomIntBetween(1, 50)); + if (factory.model().canBeMinimized() && randomBoolean()) { + factory.minimize(randomBoolean()); + } + return factory; + } + +}