Aggregations Refactor: Refactor Moving Average Aggregation

This commit is contained in:
Colin Goodheart-Smithe 2015-12-02 11:51:46 +00:00
parent 80e58e32a4
commit 2b5aa09ccf
10 changed files with 533 additions and 70 deletions

View File

@ -96,7 +96,6 @@ import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.SearchContext.Lifetime;
import org.elasticsearch.search.internal.ShardSearchLocalRequest;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.search.profile.Profilers;
import org.elasticsearch.search.query.QueryPhase;
import org.elasticsearch.search.query.QuerySearchRequest;
import org.elasticsearch.search.query.QuerySearchResult;
@ -755,7 +754,7 @@ public class SearchService extends AbstractLifecycleComponent<SearchService> imp
// ignore
}
XContentLocation location = completeAggregationsParser != null ? completeAggregationsParser.getTokenLocation() : null;
throw new SearchParseException(context, "failed to parse rescore source [" + sSource + "]", location, e);
throw new SearchParseException(context, "failed to parse aggregation source [" + sSource + "]", location, e);
}
}
if (source.suggest() != null) {

View File

@ -28,8 +28,6 @@ import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorFactory;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.MovAvgModel;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.MovAvgModelParserMapper;
import org.elasticsearch.search.aggregations.support.format.ValueFormat;
import org.elasticsearch.search.aggregations.support.format.ValueFormatter;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
@ -65,11 +63,11 @@ public class MovAvgParser implements PipelineAggregator.Parser {
String[] bucketsPaths = null;
String format = null;
GapPolicy gapPolicy = GapPolicy.SKIP;
int window = 5;
GapPolicy gapPolicy = null;
Integer window = null;
Map<String, Object> settings = null;
String model = "simple";
int predict = 0;
String model = null;
Integer predict = null;
Boolean minimize = null;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
@ -86,8 +84,8 @@ public class MovAvgParser implements PipelineAggregator.Parser {
} else if (context.parseFieldMatcher().match(currentFieldName, PREDICT)) {
predict = parser.intValue();
if (predict <= 0) {
throw new SearchParseException(context, "[" + currentFieldName + "] value must be a positive, "
+ "non-zero integer. Value supplied was [" + predict + "] in [" + pipelineAggregatorName + "].",
throw new SearchParseException(context, "[" + currentFieldName + "] value must be a positive integer."
+ " Value supplied was [" + predict + "] in [" + pipelineAggregatorName + "].",
parser.getTokenLocation());
}
} else {
@ -144,17 +142,25 @@ public class MovAvgParser implements PipelineAggregator.Parser {
+ "] for movingAvg aggregation [" + pipelineAggregatorName + "]", parser.getTokenLocation());
}
ValueFormatter formatter = null;
MovAvgPipelineAggregator.Factory factory = new MovAvgPipelineAggregator.Factory(pipelineAggregatorName, bucketsPaths);
if (format != null) {
formatter = ValueFormat.Patternable.Number.format(format).formatter();
} else {
formatter = ValueFormatter.RAW;
factory.format(format);
}
if (gapPolicy != null) {
factory.gapPolicy(gapPolicy);
}
if (window != null) {
factory.window(window);
}
if (predict != null) {
factory.predict(predict);
}
if (model != null) {
MovAvgModel.AbstractModelParser modelParser = movAvgModelParserMapper.get(model);
if (modelParser == null) {
throw new SearchParseException(context, "Unknown model [" + model + "] specified. Valid options are:"
+ movAvgModelParserMapper.getAllNames().toString(), parser.getTokenLocation());
throw new SearchParseException(context,
"Unknown model [" + model + "] specified. Valid options are:" + movAvgModelParserMapper.getAllNames().toString(),
parser.getTokenLocation());
}
MovAvgModel movAvgModel;
@ -163,24 +169,17 @@ public class MovAvgParser implements PipelineAggregator.Parser {
} catch (ParseException exception) {
throw new SearchParseException(context, "Could not parse settings for model [" + model + "].", null, exception);
}
// If the user doesn't set a preference for cost minimization, ask what the model prefers
if (minimize == null) {
minimize = movAvgModel.minimizeByDefault();
} else if (minimize && !movAvgModel.canBeMinimized()) {
// If the user asks to minimize, but this model doesn't support it, throw exception
throw new SearchParseException(context, "The [" + model + "] model cannot be minimized.", null);
factory.model(movAvgModel);
}
if (minimize != null) {
factory.minimize(minimize);
}
return factory;
}
return new MovAvgPipelineAggregator.Factory(pipelineAggregatorName, bucketsPaths, formatter, gapPolicy, window, predict,
movAvgModel, minimize);
}
// NORELEASE implement this method when refactoring this aggregation
@Override
public PipelineAggregatorFactory getFactoryPrototype() {
return null;
return new MovAvgPipelineAggregator.Factory(null, null);
}
}

View File

@ -22,6 +22,7 @@ package org.elasticsearch.search.aggregations.pipeline.movavg;
import org.elasticsearch.common.collect.EvictingQueue;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.InternalAggregation;
@ -37,6 +38,8 @@ import org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorFactory;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorStreams;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.MovAvgModel;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.MovAvgModelStreams;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.SimpleModel;
import org.elasticsearch.search.aggregations.support.format.ValueFormat;
import org.elasticsearch.search.aggregations.support.format.ValueFormatter;
import org.elasticsearch.search.aggregations.support.format.ValueFormatterStreams;
import org.joda.time.DateTime;
@ -46,6 +49,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
@ -276,32 +280,151 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
public static class Factory extends PipelineAggregatorFactory {
private final ValueFormatter formatter;
private GapPolicy gapPolicy;
private int window;
private MovAvgModel model;
private int predict;
private boolean minimize;
private String format;
private GapPolicy gapPolicy = GapPolicy.SKIP;
private int window = 5;
private MovAvgModel model = new SimpleModel();
private int predict = 0;
private Boolean minimize;
public Factory(String name, String[] bucketsPaths, ValueFormatter formatter, GapPolicy gapPolicy,
int window, int predict, MovAvgModel model, boolean minimize) {
public Factory(String name, String[] bucketsPaths) {
super(name, TYPE.name(), bucketsPaths);
this.formatter = formatter;
}
/**
* Sets the format to use on the output of this aggregation.
*/
public void format(String format) {
this.format = format;
}
/**
* Gets the format to use on the output of this aggregation.
*/
public String format() {
return format;
}
/**
* Sets the GapPolicy to use on the output of this aggregation.
*/
public void gapPolicy(GapPolicy gapPolicy) {
this.gapPolicy = gapPolicy;
}
/**
* Gets the GapPolicy to use on the output of this aggregation.
*/
public GapPolicy gapPolicy() {
return gapPolicy;
}
protected ValueFormatter formatter() {
if (format != null) {
return ValueFormat.Patternable.Number.format(format).formatter();
} else {
return ValueFormatter.RAW;
}
}
/**
* Sets the window size for the moving average. This window will "slide"
* across the series, and the values inside that window will be used to
* calculate the moving avg value
*
* @param window
* Size of window
*/
public void window(int window) {
this.window = window;
}
/**
* Gets the window size for the moving average. This window will "slide"
* across the series, and the values inside that window will be used to
* calculate the moving avg value
*/
public int window() {
return window;
}
/**
* Sets a MovAvgModel for the Moving Average. The model is used to
* define what type of moving average you want to use on the series
*
* @param model
* A MovAvgModel which has been prepopulated with settings
*/
public void model(MovAvgModel model) {
this.model = model;
}
/**
* Gets a MovAvgModel for the Moving Average. The model is used to
* define what type of moving average you want to use on the series
*/
public MovAvgModel model() {
return model;
}
/**
* Sets the number of predictions that should be returned. Each
* prediction will be spaced at the intervals specified in the
* histogram. E.g "predict: 2" will return two new buckets at the end of
* the histogram with the predicted values.
*
* @param predict
* Number of predictions to make
*/
public void predict(int predict) {
this.predict = predict;
}
/**
* Gets the number of predictions that should be returned. Each
* prediction will be spaced at the intervals specified in the
* histogram. E.g "predict: 2" will return two new buckets at the end of
* the histogram with the predicted values.
*/
public int predict() {
return predict;
}
/**
* Sets whether the model should be fit to the data using a cost
* minimizing algorithm.
*
* @param minimize
* If the model should be fit to the underlying data
*/
public void minimize(boolean minimize) {
this.minimize = minimize;
}
/**
* Gets whether the model should be fit to the data using a cost
* minimizing algorithm.
*/
public Boolean minimize() {
return minimize;
}
@Override
protected PipelineAggregator createInternal(Map<String, Object> metaData) throws IOException {
return new MovAvgPipelineAggregator(name, bucketsPaths, formatter, gapPolicy, window, predict, model, minimize, metaData);
// If the user doesn't set a preference for cost minimization, ask
// what the model prefers
boolean minimize = this.minimize == null ? model.minimizeByDefault() : this.minimize;
return new MovAvgPipelineAggregator(name, bucketsPaths, formatter(), gapPolicy, window, predict, model, minimize, metaData);
}
@Override
public void doValidate(AggregatorFactory parent, AggregatorFactory[] aggFactories,
List<PipelineAggregatorFactory> pipelineAggregatoractories) {
if (minimize != null && minimize && !model.canBeMinimized()) {
// If the user asks to minimize, but this model doesn't support
// it, throw exception
throw new IllegalStateException("The [" + model + "] model cannot be minimized for aggregation [" + name + "]");
}
if (bucketsPaths.length != 1) {
throw new IllegalStateException(PipelineAggregator.Parser.BUCKETS_PATH.getPreferredName()
+ " must contain a single entry for aggregation [" + name + "]");
@ -318,5 +441,60 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
}
}
@Override
protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
if (format != null) {
builder.field(MovAvgParser.FORMAT.getPreferredName(), format);
}
builder.field(MovAvgParser.GAP_POLICY.getPreferredName(), gapPolicy.getName());
model.toXContent(builder, params);
builder.field(MovAvgParser.WINDOW.getPreferredName(), window);
if (predict > 0) {
builder.field(MovAvgParser.PREDICT.getPreferredName(), predict);
}
if (minimize != null) {
builder.field(MovAvgParser.MINIMIZE.getPreferredName(), minimize);
}
return builder;
}
@Override
protected PipelineAggregatorFactory doReadFrom(String name, String[] bucketsPaths, StreamInput in) throws IOException {
Factory factory = new Factory(name, bucketsPaths);
factory.format = in.readOptionalString();
factory.gapPolicy = GapPolicy.readFrom(in);
factory.window = in.readVInt();
factory.model = MovAvgModelStreams.read(in);
factory.predict = in.readVInt();
factory.minimize = in.readOptionalBoolean();
return factory;
}
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeOptionalString(format);
gapPolicy.writeTo(out);
out.writeVInt(window);
model.writeTo(out);
out.writeVInt(predict);
out.writeOptionalBoolean(minimize);
}
@Override
protected int doHashCode() {
return Objects.hash(format, gapPolicy, window, model, predict, minimize);
}
@Override
protected boolean doEquals(Object obj) {
Factory other = (Factory) obj;
return Objects.equals(format, other.format)
&& Objects.equals(gapPolicy, other.gapPolicy)
&& Objects.equals(window, other.window)
&& Objects.equals(model, other.model)
&& Objects.equals(predict, other.predict)
&& Objects.equals(minimize, other.minimize);
}
}
}

View File

@ -32,13 +32,16 @@ import java.text.ParseException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;
/**
* Calculate a exponentially weighted moving average
*/
public class EwmaModel extends MovAvgModel {
private static final EwmaModel PROTOTYPE = new EwmaModel();
protected static final ParseField NAME_FIELD = new ParseField("ewma");
public static final double DEFAULT_ALPHA = 0.3;
/**
* Controls smoothing of data. Also known as "level" value.
@ -48,6 +51,10 @@ public class EwmaModel extends MovAvgModel {
*/
private final double alpha;
public EwmaModel() {
this(DEFAULT_ALPHA);
}
public EwmaModel(double alpha) {
this.alpha = alpha;
}
@ -97,7 +104,7 @@ public class EwmaModel extends MovAvgModel {
public static final MovAvgModelStreams.Stream STREAM = new MovAvgModelStreams.Stream() {
@Override
public MovAvgModel readResult(StreamInput in) throws IOException {
return new EwmaModel(in.readDouble());
return PROTOTYPE.readFrom(in);
}
@Override
@ -106,12 +113,43 @@ public class EwmaModel extends MovAvgModel {
}
};
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(MovAvgParser.MODEL.getPreferredName(), NAME_FIELD.getPreferredName());
builder.startObject(MovAvgParser.SETTINGS.getPreferredName());
builder.field("alpha", alpha);
builder.endObject();
return builder;
}
@Override
public MovAvgModel readFrom(StreamInput in) throws IOException {
return new EwmaModel(in.readDouble());
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(STREAM.getName());
out.writeDouble(alpha);
}
@Override
public int hashCode() {
return Objects.hash(alpha);
}
@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
EwmaModel other = (EwmaModel) obj;
return Objects.equals(alpha, other.alpha);
}
public static class SingleExpModelParser extends AbstractModelParser {
@Override
@ -123,7 +161,7 @@ public class EwmaModel extends MovAvgModel {
public MovAvgModel parse(@Nullable Map<String, Object> settings, String pipelineName, int windowSize,
ParseFieldMatcher parseFieldMatcher) throws ParseException {
double alpha = parseDoubleParam(settings, "alpha", 0.3);
double alpha = parseDoubleParam(settings, "alpha", DEFAULT_ALPHA);
checkUnrecognizedParams(settings);
return new EwmaModel(alpha);
}

View File

@ -31,13 +31,17 @@ import java.io.IOException;
import java.text.ParseException;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;
/**
* Calculate a doubly exponential weighted moving average
*/
public class HoltLinearModel extends MovAvgModel {
private static final HoltLinearModel PROTOTYPE = new HoltLinearModel();
protected static final ParseField NAME_FIELD = new ParseField("holt");
public static final double DEFAULT_ALPHA = 0.3;
public static final double DEFAULT_BETA = 0.1;
/**
* Controls smoothing of data. Also known as "level" value.
@ -55,6 +59,10 @@ public class HoltLinearModel extends MovAvgModel {
*/
private final double beta;
public HoltLinearModel() {
this(DEFAULT_ALPHA, DEFAULT_BETA);
}
public HoltLinearModel(double alpha, double beta) {
this.alpha = alpha;
this.beta = beta;
@ -157,7 +165,7 @@ public class HoltLinearModel extends MovAvgModel {
public static final MovAvgModelStreams.Stream STREAM = new MovAvgModelStreams.Stream() {
@Override
public MovAvgModel readResult(StreamInput in) throws IOException {
return new HoltLinearModel(in.readDouble(), in.readDouble());
return PROTOTYPE.readFrom(in);
}
@Override
@ -166,6 +174,21 @@ public class HoltLinearModel extends MovAvgModel {
}
};
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(MovAvgParser.MODEL.getPreferredName(), NAME_FIELD.getPreferredName());
builder.startObject(MovAvgParser.SETTINGS.getPreferredName());
builder.field("alpha", alpha);
builder.field("beta", beta);
builder.endObject();
return builder;
}
@Override
public MovAvgModel readFrom(StreamInput in) throws IOException {
return new HoltLinearModel(in.readDouble(), in.readDouble());
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(STREAM.getName());
@ -173,6 +196,24 @@ public class HoltLinearModel extends MovAvgModel {
out.writeDouble(beta);
}
@Override
public int hashCode() {
return Objects.hash(alpha, beta);
}
@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
HoltLinearModel other = (HoltLinearModel) obj;
return Objects.equals(alpha, other.alpha)
&& Objects.equals(beta, other.beta);
}
public static class DoubleExpModelParser extends AbstractModelParser {
@Override
@ -184,8 +225,8 @@ public class HoltLinearModel extends MovAvgModel {
public MovAvgModel parse(@Nullable Map<String, Object> settings, String pipelineName, int windowSize,
ParseFieldMatcher parseFieldMatcher) throws ParseException {
double alpha = parseDoubleParam(settings, "alpha", 0.3);
double beta = parseDoubleParam(settings, "beta", 0.1);
double alpha = parseDoubleParam(settings, "alpha", DEFAULT_ALPHA);
double beta = parseDoubleParam(settings, "beta", DEFAULT_BETA);
checkUnrecognizedParams(settings);
return new HoltLinearModel(alpha, beta);
}

View File

@ -37,6 +37,7 @@ import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
/**
* Calculate a triple exponential weighted moving average
@ -44,6 +45,13 @@ import java.util.Map;
public class HoltWintersModel extends MovAvgModel {
protected static final ParseField NAME_FIELD = new ParseField("holt_winters");
public static final double DEFAULT_ALPHA = 0.3;
public static final double DEFAULT_BETA = 0.1;
public static final double DEFAULT_GAMMA = 0.3;
public static final int DEFAULT_PERIOD = 1;
public static final SeasonalityType DEFAULT_SEASONALITY_TYPE = SeasonalityType.ADDITIVE;
public static final boolean DEFAULT_PAD = false;
private static final HoltWintersModel PROTOTYPE = new HoltWintersModel();
/**
* Controls smoothing of data. Also known as "level" value.
@ -159,6 +167,9 @@ public class HoltWintersModel extends MovAvgModel {
}
}
public HoltWintersModel() {
this(DEFAULT_ALPHA, DEFAULT_BETA, DEFAULT_GAMMA, DEFAULT_PERIOD, DEFAULT_SEASONALITY_TYPE, DEFAULT_PAD);
}
public HoltWintersModel(double alpha, double beta, double gamma, int period, SeasonalityType seasonalityType, boolean pad) {
this.alpha = alpha;
@ -273,8 +284,8 @@ public class HoltWintersModel extends MovAvgModel {
s += vs[i];
b += (vs[i + period] - vs[i]) / period;
}
s /= (double) period;
b /= (double) period;
s /= period;
b /= period;
last_s = s;
// Calculate first seasonal
@ -324,14 +335,7 @@ public class HoltWintersModel extends MovAvgModel {
public static final MovAvgModelStreams.Stream STREAM = new MovAvgModelStreams.Stream() {
@Override
public MovAvgModel readResult(StreamInput in) throws IOException {
double alpha = in.readDouble();
double beta = in.readDouble();
double gamma = in.readDouble();
int period = in.readVInt();
SeasonalityType type = SeasonalityType.readFrom(in);
boolean pad = in.readBoolean();
return new HoltWintersModel(alpha, beta, gamma, period, type, pad);
return PROTOTYPE.readFrom(in);
}
@Override
@ -340,6 +344,26 @@ public class HoltWintersModel extends MovAvgModel {
}
};
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(MovAvgParser.MODEL.getPreferredName(), NAME_FIELD.getPreferredName());
builder.startObject(MovAvgParser.SETTINGS.getPreferredName());
builder.field("alpha", alpha);
builder.field("beta", beta);
builder.field("gamma", gamma);
builder.field("period", period);
builder.field("pad", pad);
builder.field("type", seasonalityType.getName());
builder.endObject();
return builder;
}
@Override
public MovAvgModel readFrom(StreamInput in) throws IOException {
return new HoltWintersModel(in.readDouble(), in.readDouble(), in.readDouble(), in.readVInt(), SeasonalityType.readFrom(in),
in.readBoolean());
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(STREAM.getName());
@ -351,6 +375,28 @@ public class HoltWintersModel extends MovAvgModel {
out.writeBoolean(pad);
}
@Override
public int hashCode() {
return Objects.hash(alpha, beta, gamma, period, seasonalityType, pad);
}
@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
HoltWintersModel other = (HoltWintersModel) obj;
return Objects.equals(alpha, other.alpha)
&& Objects.equals(beta, other.beta)
&& Objects.equals(gamma, other.gamma)
&& Objects.equals(period, other.period)
&& Objects.equals(seasonalityType, other.seasonalityType)
&& Objects.equals(pad, other.pad);
}
public static class HoltWintersModelParser extends AbstractModelParser {
@Override
@ -362,10 +408,10 @@ public class HoltWintersModel extends MovAvgModel {
public MovAvgModel parse(@Nullable Map<String, Object> settings, String pipelineName, int windowSize,
ParseFieldMatcher parseFieldMatcher) throws ParseException {
double alpha = parseDoubleParam(settings, "alpha", 0.3);
double beta = parseDoubleParam(settings, "beta", 0.1);
double gamma = parseDoubleParam(settings, "gamma", 0.3);
int period = parseIntegerParam(settings, "period", 1);
double alpha = parseDoubleParam(settings, "alpha", DEFAULT_ALPHA);
double beta = parseDoubleParam(settings, "beta", DEFAULT_BETA);
double gamma = parseDoubleParam(settings, "gamma", DEFAULT_GAMMA);
int period = parseIntegerParam(settings, "period", DEFAULT_PERIOD);
if (windowSize < 2 * period) {
throw new ParseException("Field [window] must be at least twice as large as the period when " +
@ -373,7 +419,7 @@ public class HoltWintersModel extends MovAvgModel {
+ (2 * period), 0);
}
SeasonalityType seasonalityType = SeasonalityType.ADDITIVE;
SeasonalityType seasonalityType = DEFAULT_SEASONALITY_TYPE;
if (settings != null) {
Object value = settings.get("type");

View File

@ -40,6 +40,7 @@ import java.util.Map;
*/
public class LinearModel extends MovAvgModel {
private static final LinearModel PROTOTYPE = new LinearModel();
protected static final ParseField NAME_FIELD = new ParseField("linear");
@ -85,7 +86,7 @@ public class LinearModel extends MovAvgModel {
public static final MovAvgModelStreams.Stream STREAM = new MovAvgModelStreams.Stream() {
@Override
public MovAvgModel readResult(StreamInput in) throws IOException {
return new LinearModel();
return PROTOTYPE.readFrom(in);
}
@Override
@ -94,6 +95,17 @@ public class LinearModel extends MovAvgModel {
}
};
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(MovAvgParser.MODEL.getPreferredName(), NAME_FIELD.getPreferredName());
return builder;
}
@Override
public MovAvgModel readFrom(StreamInput in) throws IOException {
return new LinearModel();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(STREAM.getName());
@ -121,4 +133,20 @@ public class LinearModel extends MovAvgModel {
return builder;
}
}
@Override
public int hashCode() {
return 0;
}
@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
return true;
}
}

View File

@ -22,6 +22,8 @@ package org.elasticsearch.search.aggregations.pipeline.movavg.models;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import java.io.IOException;
import java.text.ParseException;
@ -29,7 +31,7 @@ import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
public abstract class MovAvgModel {
public abstract class MovAvgModel implements Writeable<MovAvgModel>, ToXContent {
/**
* Should this model be fit to the data via a cost minimizing algorithm by default?
@ -116,13 +118,21 @@ public abstract class MovAvgModel {
*
* @param out Output stream
*/
@Override
public abstract void writeTo(StreamOutput out) throws IOException;
/**
* Clone the model, returning an exact copy
*/
@Override
public abstract MovAvgModel clone();
@Override
public abstract int hashCode();
@Override
public abstract boolean equals(Object obj);
/**
* Abstract class which also provides some concrete parsing functionality.
*/

View File

@ -38,6 +38,7 @@ import java.util.Map;
*/
public class SimpleModel extends MovAvgModel {
private static final SimpleModel PROTOTYPE = new SimpleModel();
protected static final ParseField NAME_FIELD = new ParseField("simple");
@ -78,7 +79,7 @@ public class SimpleModel extends MovAvgModel {
public static final MovAvgModelStreams.Stream STREAM = new MovAvgModelStreams.Stream() {
@Override
public MovAvgModel readResult(StreamInput in) throws IOException {
return new SimpleModel();
return PROTOTYPE.readFrom(in);
}
@Override
@ -87,6 +88,17 @@ public class SimpleModel extends MovAvgModel {
}
};
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(MovAvgParser.MODEL.getPreferredName(), NAME_FIELD.getPreferredName());
return builder;
}
@Override
public MovAvgModel readFrom(StreamInput in) throws IOException {
return new SimpleModel();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(STREAM.getName());
@ -114,4 +126,20 @@ public class SimpleModel extends MovAvgModel {
return builder;
}
}
@Override
public int hashCode() {
return 0;
}
@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
return true;
}
}

View File

@ -0,0 +1,96 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.pipeline.moving.avg;
import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase;
import org.elasticsearch.search.aggregations.pipeline.BucketHelpers.GapPolicy;
import org.elasticsearch.search.aggregations.pipeline.movavg.MovAvgPipelineAggregator;
import org.elasticsearch.search.aggregations.pipeline.movavg.MovAvgPipelineAggregator.Factory;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.EwmaModel;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.HoltLinearModel;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.HoltWintersModel;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.HoltWintersModel.SeasonalityType;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.LinearModel;
import org.elasticsearch.search.aggregations.pipeline.movavg.models.SimpleModel;;
public class MovAvgTests extends BasePipelineAggregationTestCase<MovAvgPipelineAggregator.Factory> {
@Override
protected Factory createTestAggregatorFactory() {
String name = randomAsciiOfLengthBetween(3, 20);
String[] bucketsPaths = new String[1];
bucketsPaths[0] = randomAsciiOfLengthBetween(3, 20);
Factory factory = new Factory(name, bucketsPaths);
if (randomBoolean()) {
factory.format(randomAsciiOfLengthBetween(1, 10));
}
if (randomBoolean()) {
factory.gapPolicy(randomFrom(GapPolicy.values()));
}
if (randomBoolean()) {
switch (randomInt(4)) {
case 0:
factory.model(new SimpleModel());
factory.window(randomIntBetween(1, 100));
break;
case 1:
factory.model(new LinearModel());
factory.window(randomIntBetween(1, 100));
break;
case 2:
if (randomBoolean()) {
factory.model(new EwmaModel());
factory.window(randomIntBetween(1, 100));
} else {
factory.model(new EwmaModel(randomDouble()));
factory.window(randomIntBetween(1, 100));
}
break;
case 3:
if (randomBoolean()) {
factory.model(new HoltLinearModel());
factory.window(randomIntBetween(1, 100));
} else {
factory.model(new HoltLinearModel(randomDouble(), randomDouble()));
factory.window(randomIntBetween(1, 100));
}
break;
case 4:
default:
if (randomBoolean()) {
factory.model(new HoltWintersModel());
factory.window(randomIntBetween(2, 100));
} else {
int period = randomIntBetween(1, 100);
factory.model(new HoltWintersModel(randomDouble(), randomDouble(), randomDouble(), period,
randomFrom(SeasonalityType.values()), randomBoolean()));
factory.window(randomIntBetween(2 * period, 200 * period));
}
break;
}
}
factory.predict(randomIntBetween(1, 50));
if (factory.model().canBeMinimized() && randomBoolean()) {
factory.minimize(randomBoolean());
}
return factory;
}
}