Better HoltWinters parameter validation (#38747)

We validate HW parameters (namely, window > 2 * period) when parsing
the XContent... but that means transport clients can configure bad
params.

This change allows model to validate the window and throw an 
exception if they wish.

It also makes some test changes:

- removes testBadModelParams(), which was a junk test (didn't do
anything), and bad param checking is done elsewhere in units tests
- Fixes one of the windows in testHoltWintersNotEnoughData()
- Ensures the period in testHoltWintersNotEnoughData() is >> window
- Removes `setTypes()` since that's deprecated
This commit is contained in:
Zachary Tong 2019-02-22 15:25:26 -05:00 committed by GitHub
parent 931953a3ee
commit c7516b03b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 71 additions and 59 deletions

View File

@ -304,12 +304,6 @@ public class HoltWintersModel extends MovAvgModel {
double gamma = parseDoubleParam(settings, "gamma", DEFAULT_GAMMA);
int period = parseIntegerParam(settings, "period", DEFAULT_PERIOD);
if (windowSize < 2 * period) {
throw new ParseException("Field [window] must be at least twice as large as the period when " +
"using Holt-Winters. Value provided was [" + windowSize + "], which is less than (2*period) == "
+ (2 * period), 0);
}
SeasonalityType seasonalityType = DEFAULT_SEASONALITY_TYPE;
if (settings != null) {
@ -332,6 +326,21 @@ public class HoltWintersModel extends MovAvgModel {
}
};
/**
* If the model is a HoltWinters, we need to ensure the window and period are compatible.
* This is verified in the XContent parsing, but transport clients need these checks since they
* skirt XContent parsing
*/
@Override
protected void validate(long window, String aggregationName) {
super.validate(window, aggregationName);
if (window < 2 * period) {
throw new IllegalArgumentException("Field [window] must be at least twice as large as the period when " +
"using Holt-Winters. Value provided was [" + window + "], which is less than (2*period) == "
+ (2 * period));
}
}
@Override
public int hashCode() {
return Objects.hash(alpha, beta, gamma, period, seasonalityType, pad);

View File

@ -99,6 +99,15 @@ public abstract class MovAvgModel implements NamedWriteable, ToXContentFragment
*/
protected abstract double[] doPredict(Collection<Double> values, int numPredictions);
/**
* This method allows models to validate the window size if required
*/
protected void validate(long window, String aggregationName) {
if (window <= 0) {
throw new IllegalArgumentException("[window] must be a positive integer in aggregation [" + aggregationName + "]");
}
}
/**
* Returns an empty set of predictions, filled with NaNs
* @param numPredictions Number of empty predictions to generate

View File

@ -147,6 +147,10 @@ public class MovAvgPipelineAggregationBuilder extends AbstractPipelineAggregatio
if (window <= 0) {
throw new IllegalArgumentException("[window] must be a positive integer: [" + name + "]");
}
// If we have a model we can validate the window now
if (model != null) {
model.validate(window, name);
}
this.window = window;
return this;
}
@ -265,7 +269,8 @@ public class MovAvgPipelineAggregationBuilder extends AbstractPipelineAggregatio
throw new IllegalStateException(PipelineAggregator.Parser.BUCKETS_PATH.getPreferredName()
+ " must contain a single entry for aggregation [" + name + "]");
}
// Validate any model-specific window requirements
model.validate(window, name);
validateSequentiallyOrderedParentAggs(parent, NAME, name);
}

View File

@ -401,7 +401,7 @@ public class MovAvgIT extends ESIntegTestCase {
*/
public void testSimpleSingleValuedField() {
SearchResponse response = client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
@ -449,7 +449,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testLinearSingleValuedField() {
SearchResponse response = client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
@ -497,7 +497,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testEwmaSingleValuedField() {
SearchResponse response = client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
@ -545,7 +545,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testHoltSingleValuedField() {
SearchResponse response = client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
@ -594,7 +594,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testHoltWintersValuedField() {
SearchResponse response = client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
@ -648,7 +648,7 @@ public class MovAvgIT extends ESIntegTestCase {
SearchResponse response = client()
.prepareSearch("neg_idx")
.setTypes("type")
.addAggregation(
histogram("histo")
.field(INTERVAL_FIELD)
@ -701,7 +701,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testSizeZeroWindow() {
try {
client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
@ -720,7 +720,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testBadParent() {
try {
client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
range("histo").field(INTERVAL_FIELD).addRange(0, 10)
.subAggregation(randomMetric("the_metric", VALUE_FIELD))
@ -739,7 +739,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testNegativeWindow() {
try {
client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
@ -758,7 +758,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testNoBucketsInHistogram() {
SearchResponse response = client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field("test").interval(interval)
.subAggregation(randomMetric("the_metric", VALUE_FIELD))
@ -780,7 +780,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testNoBucketsInHistogramWithPredict() {
int numPredictions = randomIntBetween(1,10);
SearchResponse response = client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field("test").interval(interval)
.subAggregation(randomMetric("the_metric", VALUE_FIELD))
@ -803,7 +803,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testZeroPrediction() {
try {
client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
@ -824,7 +824,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testNegativePrediction() {
try {
client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
@ -842,10 +842,9 @@ public class MovAvgIT extends ESIntegTestCase {
}
}
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/34046")
public void testHoltWintersNotEnoughData() {
Client client = client();
expectThrows(SearchPhaseExecutionException.class, () -> client.prepareSearch("idx").setTypes("type")
expectThrows(SearchPhaseExecutionException.class, () -> client.prepareSearch("idx")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
@ -853,20 +852,25 @@ public class MovAvgIT extends ESIntegTestCase {
.subAggregation(movingAvg("movavg_counts", "_count")
.window(10)
.modelBuilder(new HoltWintersModel.HoltWintersModelBuilder()
.alpha(alpha).beta(beta).gamma(gamma).period(20).seasonalityType(seasonalityType))
.alpha(alpha).beta(beta).gamma(gamma)
.period(interval * 10)
.seasonalityType(seasonalityType))
.gapPolicy(gapPolicy))
.subAggregation(movingAvg("movavg_values", "the_metric")
.window(windowSize)
.window(10)
.modelBuilder(new HoltWintersModel.HoltWintersModelBuilder()
.alpha(alpha).beta(beta).gamma(gamma).period(20).seasonalityType(seasonalityType))
.alpha(alpha).beta(beta).gamma(gamma)
.period(interval * 10)
.seasonalityType(seasonalityType))
.gapPolicy(gapPolicy))
).get());
}
public void testTwoMovAvgsWithPredictions() {
SearchResponse response = client()
.prepareSearch("double_predict")
.setTypes("type")
.addAggregation(
histogram("histo")
.field(INTERVAL_FIELD)
@ -980,24 +984,9 @@ public class MovAvgIT extends ESIntegTestCase {
}
}
@AwaitsFix(bugUrl="https://github.com/elastic/elasticsearch/issues/34046")
public void testBadModelParams() {
expectThrows(SearchPhaseExecutionException.class, () -> client()
.prepareSearch("idx").setTypes("type")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
.subAggregation(metric)
.subAggregation(movingAvg("movavg_counts", "_count")
.window(10)
.modelBuilder(randomModelBuilder(100))
.gapPolicy(gapPolicy))
).get());
}
public void testHoltWintersMinimization() {
SearchResponse response = client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
@ -1083,7 +1072,7 @@ public class MovAvgIT extends ESIntegTestCase {
*/
public void testMinimizeNotEnoughData() {
SearchResponse response = client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
@ -1137,7 +1126,7 @@ public class MovAvgIT extends ESIntegTestCase {
public void testCheckIfNonTunableCanBeMinimized() {
try {
client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
@ -1155,7 +1144,7 @@ public class MovAvgIT extends ESIntegTestCase {
try {
client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
@ -1185,7 +1174,7 @@ public class MovAvgIT extends ESIntegTestCase {
for (MovAvgModelBuilder builder : builders) {
try {
client()
.prepareSearch("idx").setTypes("type")
.prepareSearch("idx")
.addAggregation(
histogram("histo").field(INTERVAL_FIELD).interval(interval)
.extendedBounds(0L, interval * (numBuckets - 1))
@ -1225,7 +1214,7 @@ public class MovAvgIT extends ESIntegTestCase {
SearchResponse response = client()
.prepareSearch("predict_non_empty")
.setTypes("type")
.addAggregation(
histogram("histo")
.field(INTERVAL_FIELD)