Better HoltWinters parameter validation (#38747)

We validate HW parameters (namely, window > 2 * period) when parsing
the XContent... but that means transport clients can configure bad
params.

This change allows model to validate the window and throw an 
exception if they wish.

It also makes some test changes:

- removes testBadModelParams(), which was a junk test (didn't do
anything), and bad param checking is done elsewhere in units tests
- Fixes one of the windows in testHoltWintersNotEnoughData()
- Ensures the period in testHoltWintersNotEnoughData() is >> window
- Removes `setTypes()` since that's deprecated
This commit is contained in:
Zachary Tong 2019-02-22 15:25:26 -05:00 committed by GitHub
parent 931953a3ee
commit c7516b03b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 71 additions and 59 deletions

View File

@ -304,12 +304,6 @@ public class HoltWintersModel extends MovAvgModel {
double gamma = parseDoubleParam(settings, "gamma", DEFAULT_GAMMA); double gamma = parseDoubleParam(settings, "gamma", DEFAULT_GAMMA);
int period = parseIntegerParam(settings, "period", DEFAULT_PERIOD); int period = parseIntegerParam(settings, "period", DEFAULT_PERIOD);
if (windowSize < 2 * period) {
throw new ParseException("Field [window] must be at least twice as large as the period when " +
"using Holt-Winters. Value provided was [" + windowSize + "], which is less than (2*period) == "
+ (2 * period), 0);
}
SeasonalityType seasonalityType = DEFAULT_SEASONALITY_TYPE; SeasonalityType seasonalityType = DEFAULT_SEASONALITY_TYPE;
if (settings != null) { if (settings != null) {
@ -332,6 +326,21 @@ public class HoltWintersModel extends MovAvgModel {
} }
}; };
/**
* If the model is a HoltWinters, we need to ensure the window and period are compatible.
* This is verified in the XContent parsing, but transport clients need these checks since they
* skirt XContent parsing
*/
@Override
protected void validate(long window, String aggregationName) {
super.validate(window, aggregationName);
if (window < 2 * period) {
throw new IllegalArgumentException("Field [window] must be at least twice as large as the period when " +
"using Holt-Winters. Value provided was [" + window + "], which is less than (2*period) == "
+ (2 * period));
}
}
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(alpha, beta, gamma, period, seasonalityType, pad); return Objects.hash(alpha, beta, gamma, period, seasonalityType, pad);

View File

@ -99,6 +99,15 @@ public abstract class MovAvgModel implements NamedWriteable, ToXContentFragment
*/ */
protected abstract double[] doPredict(Collection<Double> values, int numPredictions); protected abstract double[] doPredict(Collection<Double> values, int numPredictions);
/**
* This method allows models to validate the window size if required
*/
protected void validate(long window, String aggregationName) {
if (window <= 0) {
throw new IllegalArgumentException("[window] must be a positive integer in aggregation [" + aggregationName + "]");
}
}
/** /**
* Returns an empty set of predictions, filled with NaNs * Returns an empty set of predictions, filled with NaNs
* @param numPredictions Number of empty predictions to generate * @param numPredictions Number of empty predictions to generate

View File

@ -147,6 +147,10 @@ public class MovAvgPipelineAggregationBuilder extends AbstractPipelineAggregatio
if (window <= 0) { if (window <= 0) {
throw new IllegalArgumentException("[window] must be a positive integer: [" + name + "]"); throw new IllegalArgumentException("[window] must be a positive integer: [" + name + "]");
} }
// If we have a model we can validate the window now
if (model != null) {
model.validate(window, name);
}
this.window = window; this.window = window;
return this; return this;
} }
@ -265,7 +269,8 @@ public class MovAvgPipelineAggregationBuilder extends AbstractPipelineAggregatio
throw new IllegalStateException(PipelineAggregator.Parser.BUCKETS_PATH.getPreferredName() throw new IllegalStateException(PipelineAggregator.Parser.BUCKETS_PATH.getPreferredName()
+ " must contain a single entry for aggregation [" + name + "]"); + " must contain a single entry for aggregation [" + name + "]");
} }
// Validate any model-specific window requirements
model.validate(window, name);
validateSequentiallyOrderedParentAggs(parent, NAME, name); validateSequentiallyOrderedParentAggs(parent, NAME, name);
} }

View File

@ -401,7 +401,7 @@ public class MovAvgIT extends ESIntegTestCase {
*/ */
public void testSimpleSingleValuedField() { public void testSimpleSingleValuedField() {
SearchResponse response = client() SearchResponse response = client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval) histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1)) .extendedBounds(0L, interval * (numBuckets - 1))
@ -449,7 +449,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testLinearSingleValuedField() { public void testLinearSingleValuedField() {
SearchResponse response = client() SearchResponse response = client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval) histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1)) .extendedBounds(0L, interval * (numBuckets - 1))
@ -497,7 +497,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testEwmaSingleValuedField() { public void testEwmaSingleValuedField() {
SearchResponse response = client() SearchResponse response = client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval) histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1)) .extendedBounds(0L, interval * (numBuckets - 1))
@ -545,7 +545,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testHoltSingleValuedField() { public void testHoltSingleValuedField() {
SearchResponse response = client() SearchResponse response = client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval) histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1)) .extendedBounds(0L, interval * (numBuckets - 1))
@ -594,7 +594,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testHoltWintersValuedField() { public void testHoltWintersValuedField() {
SearchResponse response = client() SearchResponse response = client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval) histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1)) .extendedBounds(0L, interval * (numBuckets - 1))
@ -648,7 +648,7 @@ public class MovAvgIT extends ESIntegTestCase {
SearchResponse response = client() SearchResponse response = client()
.prepareSearch("neg_idx") .prepareSearch("neg_idx")
.setTypes("type")
.addAggregation( .addAggregation(
histogram("histo") histogram("histo")
.field(INTERVAL_FIELD) .field(INTERVAL_FIELD)
@ -701,7 +701,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testSizeZeroWindow() { public void testSizeZeroWindow() {
try { try {
client() client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval) histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1)) .extendedBounds(0L, interval * (numBuckets - 1))
@ -720,7 +720,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testBadParent() { public void testBadParent() {
try { try {
client() client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
range("histo").field(INTERVAL_FIELD).addRange(0, 10) range("histo").field(INTERVAL_FIELD).addRange(0, 10)
.subAggregation(randomMetric("the_metric", VALUE_FIELD)) .subAggregation(randomMetric("the_metric", VALUE_FIELD))
@ -739,7 +739,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testNegativeWindow() { public void testNegativeWindow() {
try { try {
client() client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval) histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1)) .extendedBounds(0L, interval * (numBuckets - 1))
@ -758,7 +758,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testNoBucketsInHistogram() { public void testNoBucketsInHistogram() {
SearchResponse response = client() SearchResponse response = client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field("test").interval(interval) histogram("histo").field("test").interval(interval)
.subAggregation(randomMetric("the_metric", VALUE_FIELD)) .subAggregation(randomMetric("the_metric", VALUE_FIELD))
@ -780,7 +780,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testNoBucketsInHistogramWithPredict() { public void testNoBucketsInHistogramWithPredict() {
int numPredictions = randomIntBetween(1,10); int numPredictions = randomIntBetween(1,10);
SearchResponse response = client() SearchResponse response = client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field("test").interval(interval) histogram("histo").field("test").interval(interval)
.subAggregation(randomMetric("the_metric", VALUE_FIELD)) .subAggregation(randomMetric("the_metric", VALUE_FIELD))
@ -803,7 +803,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testZeroPrediction() { public void testZeroPrediction() {
try { try {
client() client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval) histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1)) .extendedBounds(0L, interval * (numBuckets - 1))
@ -824,7 +824,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testNegativePrediction() { public void testNegativePrediction() {
try { try {
client() client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval) histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1)) .extendedBounds(0L, interval * (numBuckets - 1))
@ -842,31 +842,35 @@ public class MovAvgIT extends ESIntegTestCase {
} }
} }
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/34046")
public void testHoltWintersNotEnoughData() { public void testHoltWintersNotEnoughData() {
Client client = client(); Client client = client();
expectThrows(SearchPhaseExecutionException.class, () -> client.prepareSearch("idx").setTypes("type") expectThrows(SearchPhaseExecutionException.class, () -> client.prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval) histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1)) .extendedBounds(0L, interval * (numBuckets - 1))
.subAggregation(metric) .subAggregation(metric)
.subAggregation(movingAvg("movavg_counts", "_count") .subAggregation(movingAvg("movavg_counts", "_count")
.window(10) .window(10)
.modelBuilder(new HoltWintersModel.HoltWintersModelBuilder() .modelBuilder(new HoltWintersModel.HoltWintersModelBuilder()
.alpha(alpha).beta(beta).gamma(gamma).period(20).seasonalityType(seasonalityType)) .alpha(alpha).beta(beta).gamma(gamma)
.gapPolicy(gapPolicy)) .period(interval * 10)
.subAggregation(movingAvg("movavg_values", "the_metric") .seasonalityType(seasonalityType))
.window(windowSize) .gapPolicy(gapPolicy))
.modelBuilder(new HoltWintersModel.HoltWintersModelBuilder() .subAggregation(movingAvg("movavg_values", "the_metric")
.alpha(alpha).beta(beta).gamma(gamma).period(20).seasonalityType(seasonalityType)) .window(10)
.gapPolicy(gapPolicy)) .modelBuilder(new HoltWintersModel.HoltWintersModelBuilder()
).get()); .alpha(alpha).beta(beta).gamma(gamma)
.period(interval * 10)
.seasonalityType(seasonalityType))
.gapPolicy(gapPolicy))
).get());
} }
public void testTwoMovAvgsWithPredictions() { public void testTwoMovAvgsWithPredictions() {
SearchResponse response = client() SearchResponse response = client()
.prepareSearch("double_predict") .prepareSearch("double_predict")
.setTypes("type")
.addAggregation( .addAggregation(
histogram("histo") histogram("histo")
.field(INTERVAL_FIELD) .field(INTERVAL_FIELD)
@ -980,24 +984,9 @@ public class MovAvgIT extends ESIntegTestCase {
} }
} }
@AwaitsFix(bugUrl="https://github.com/elastic/elasticsearch/issues/34046")
public void testBadModelParams() {
expectThrows(SearchPhaseExecutionException.class, () -> client()
.prepareSearch("idx").setTypes("type")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
.subAggregation(metric)
.subAggregation(movingAvg("movavg_counts", "_count")
.window(10)
.modelBuilder(randomModelBuilder(100))
.gapPolicy(gapPolicy))
).get());
}
public void testHoltWintersMinimization() { public void testHoltWintersMinimization() {
SearchResponse response = client() SearchResponse response = client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval) histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1)) .extendedBounds(0L, interval * (numBuckets - 1))
@ -1083,7 +1072,7 @@ public class MovAvgIT extends ESIntegTestCase {
*/ */
public void testMinimizeNotEnoughData() { public void testMinimizeNotEnoughData() {
SearchResponse response = client() SearchResponse response = client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval) histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1)) .extendedBounds(0L, interval * (numBuckets - 1))
@ -1137,7 +1126,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testCheckIfNonTunableCanBeMinimized() { public void testCheckIfNonTunableCanBeMinimized() {
try { try {
client() client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval) histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1)) .extendedBounds(0L, interval * (numBuckets - 1))
@ -1155,7 +1144,7 @@ public class MovAvgIT extends ESIntegTestCase {
try { try {
client() client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval) histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1)) .extendedBounds(0L, interval * (numBuckets - 1))
@ -1185,7 +1174,7 @@ public class MovAvgIT extends ESIntegTestCase {
for (MovAvgModelBuilder builder : builders) { for (MovAvgModelBuilder builder : builders) {
try { try {
client() client()
.prepareSearch("idx").setTypes("type") .prepareSearch("idx")
.addAggregation( .addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval) histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1)) .extendedBounds(0L, interval * (numBuckets - 1))
@ -1225,7 +1214,7 @@ public class MovAvgIT extends ESIntegTestCase {
SearchResponse response = client() SearchResponse response = client()
.prepareSearch("predict_non_empty") .prepareSearch("predict_non_empty")
.setTypes("type")
.addAggregation( .addAggregation(
histogram("histo") histogram("histo")
.field(INTERVAL_FIELD) .field(INTERVAL_FIELD)