Fix bug where predictions append to the previous prediction

Fixes #11454

parent 6812ed0bb6
commit d435fae067
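The core of the change: the aggregator now remembers the key and position of the last valid (non-gap) bucket (lastValidKey, lastValidPosition), and each prediction either overwrites the bucket that already exists at that offset in newBuckets or appends a genuinely new one. Previously every prediction was unconditionally appended, so a second moving average over the same histogram stacked its predictions after the first one's. Below is a minimal, self-contained sketch of that overwrite-or-append logic; Bucket, placePredictions, and the interval parameter are hypothetical stand-ins for the Elasticsearch classes and for histo.getRounding().nextRoundingValue(...), not the actual API:

    import java.util.ArrayList;
    import java.util.List;

    // Minimal sketch (hypothetical names, plain collections in place of
    // InternalHistogram.Bucket et al.) of the overwrite-or-append behavior
    // this commit introduces in MovAvgPipelineAggregator.
    public class PredictionPlacementSketch {

        static class Bucket {
            final long key;
            final List<Double> values = new ArrayList<>();

            Bucket(long key) {
                this.key = key;
            }
        }

        static void placePredictions(List<Bucket> newBuckets, double[] predictions,
                                     int lastValidPosition, long lastValidKey, long interval) {
            long newKey = lastValidKey;
            for (int i = 0; i < predictions.length; i++) {
                // Stand-in for histo.getRounding().nextRoundingValue(lastValidKey)
                newKey += interval;
                int pos = lastValidPosition + i + 1;
                if (pos < newBuckets.size()) {
                    // A bucket already exists here (e.g. created by another mov_avg's
                    // predictions): copy its data so we don't clobber it, add this
                    // prediction, and overwrite the bucket in place.
                    Bucket merged = new Bucket(newKey);
                    merged.values.addAll(newBuckets.get(pos).values);
                    merged.values.add(predictions[i]);
                    newBuckets.set(pos, merged);
                } else {
                    // Not seen before, create fresh and simply append.
                    Bucket fresh = new Bucket(newKey);
                    fresh.values.add(predictions[i]);
                    newBuckets.add(fresh);
                }
            }
        }

        public static void main(String[] args) {
            List<Bucket> buckets = new ArrayList<>();
            buckets.add(new Bucket(0)); // one real bucket at key 0

            placePredictions(buckets, new double[]{10, 10}, 0, 0, 1); // first mov_avg
            placePredictions(buckets, new double[]{5, 5}, 0, 0, 1);   // second mov_avg

            // Prints 3: the second call merges its predictions into the buckets the
            // first call created, instead of appending two more.
            System.out.println(buckets.size());
        }
    }

Run as-is this prints 3 rather than 5, which mirrors what the new testTwoMovAvgsWithPredictions below asserts: 24 buckets for 12 documents plus 12 predictions, with both avg_movavg and deriv_movavg present in each prediction bucket.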
MovAvgPipelineAggregator.java

@@ -47,9 +47,7 @@ import org.elasticsearch.search.aggregations.support.format.ValueFormatterStream
 import org.joda.time.DateTime;
 
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
 
 import static org.elasticsearch.search.aggregations.pipeline.BucketHelpers.resolveBucketValue;
 
@@ -110,12 +108,12 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
         List newBuckets = new ArrayList<>();
         EvictingQueue<Double> values = EvictingQueue.create(this.window);
 
-        long lastKey = 0;
-        Object currentKey;
+        long lastValidKey = 0;
+        int lastValidPosition = 0;
+        int counter = 0;
 
         for (InternalHistogram.Bucket bucket : buckets) {
             Double thisBucketValue = resolveBucketValue(histo, bucket, bucketsPaths()[0], gapPolicy);
-            currentKey = bucket.getKey();
 
             // Default is to reuse existing bucket. Simplifies the rest of the logic,
             // since we only change newBucket if we can add to it
@@ -130,22 +128,23 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
 
                 List<InternalAggregation> aggs = new ArrayList<>(Lists.transform(bucket.getAggregations().asList(), AGGREGATION_TRANFORM_FUNCTION));
                 aggs.add(new InternalSimpleValue(name(), movavg, formatter, new ArrayList<PipelineAggregator>(), metaData()));
-                newBucket = factory.createBucket(currentKey, bucket.getDocCount(), new InternalAggregations(
+                newBucket = factory.createBucket(bucket.getKey(), bucket.getDocCount(), new InternalAggregations(
                         aggs), bucket.getKeyed(), bucket.getFormatter());
                 }
-            }
 
-            newBuckets.add(newBucket);
 
                 if (predict > 0) {
-                if (currentKey instanceof Number) {
-                    lastKey = ((Number) bucket.getKey()).longValue();
-                } else if (currentKey instanceof DateTime) {
-                    lastKey = ((DateTime) bucket.getKey()).getMillis();
+                    if (bucket.getKey() instanceof Number) {
+                        lastValidKey = ((Number) bucket.getKey()).longValue();
+                    } else if (bucket.getKey() instanceof DateTime) {
+                        lastValidKey = ((DateTime) bucket.getKey()).getMillis();
                     } else {
-                    throw new AggregationExecutionException("Expected key of type Number or DateTime but got [" + currentKey + "]");
+                        throw new AggregationExecutionException("Expected key of type Number or DateTime but got [" + lastValidKey + "]");
+                    }
+                    lastValidPosition = counter;
                 }
             }
 
+            counter += 1;
+            newBuckets.add(newBucket);
 
         }
 
@@ -158,13 +157,35 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
 
             double[] predictions = model.predict(values, predict);
             for (int i = 0; i < predictions.length; i++) {
-                List<InternalAggregation> aggs = new ArrayList<>();
+                List<InternalAggregation> aggs;
+                long newKey = histo.getRounding().nextRoundingValue(lastValidKey);
+
+                if (lastValidPosition + i + 1 < newBuckets.size()) {
+                    InternalHistogram.Bucket bucket = (InternalHistogram.Bucket) newBuckets.get(lastValidPosition + i + 1);
+
+                    // Get the existing aggs in the bucket so we don't clobber data
+                    aggs = new ArrayList<>(Lists.transform(bucket.getAggregations().asList(), AGGREGATION_TRANFORM_FUNCTION));
                 aggs.add(new InternalSimpleValue(name(), predictions[i], formatter, new ArrayList<PipelineAggregator>(), metaData()));
-                long newKey = histo.getRounding().nextRoundingValue(lastKey);
                 InternalHistogram.Bucket newBucket = factory.createBucket(newKey, 0, new InternalAggregations(
                         aggs), keyed, formatter);
+
+                    // Overwrite the existing bucket with the new version
+                    newBuckets.set(lastValidPosition + i + 1, newBucket);
+
+                } else {
+                    // Not seen before, create fresh
+                    aggs = new ArrayList<>();
+                    aggs.add(new InternalSimpleValue(name(), predictions[i], formatter, new ArrayList<PipelineAggregator>(), metaData()));
+
+                    InternalHistogram.Bucket newBucket = factory.createBucket(newKey, 0, new InternalAggregations(
+                            aggs), keyed, formatter);
+
+                    // Since this is a new bucket, simply append it
                 newBuckets.add(newBucket);
-                lastKey = newKey;
+                }
+
+                lastValidKey = newKey;
             }
         }
 
MovAvgTests.java

@@ -35,6 +35,7 @@ import org.elasticsearch.search.aggregations.metrics.avg.Avg;
 import org.elasticsearch.search.aggregations.pipeline.BucketHelpers;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregationHelperTests;
 import org.elasticsearch.search.aggregations.pipeline.SimpleValue;
+import org.elasticsearch.search.aggregations.pipeline.derivative.Derivative;
 import org.elasticsearch.search.aggregations.pipeline.movavg.models.*;
 import org.elasticsearch.test.ElasticsearchIntegrationTest;
 import org.hamcrest.Matchers;
@@ -49,6 +50,7 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.histogra
 import static org.elasticsearch.search.aggregations.AggregationBuilders.max;
 import static org.elasticsearch.search.aggregations.AggregationBuilders.min;
 import static org.elasticsearch.search.aggregations.AggregationBuilders.range;
+import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorBuilders.derivative;
 import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorBuilders.movingAvg;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
 import static org.hamcrest.Matchers.closeTo;
@@ -160,6 +162,11 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
                     jsonBuilder().startObject().field(INTERVAL_FIELD, i).field(VALUE_FIELD, 10).endObject()));
         }
 
+        for (int i = 0; i < 12; i++) {
+            builders.add(client().prepareIndex("double_predict", "type").setSource(
+                    jsonBuilder().startObject().field(INTERVAL_FIELD, i).field(VALUE_FIELD, 10).endObject()));
+        }
+
         indexRandom(true, builders);
         ensureSearchable();
     }
@@ -957,8 +964,10 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
         assertThat(histo, notNullValue());
         assertThat(histo.getName(), equalTo("histo"));
         List<? extends Bucket> buckets = histo.getBuckets();
 
         assertThat("Size of buckets array is not correct.", buckets.size(), equalTo(50 + numPredictions));
+
+
         double lastValue = ((SimpleValue)(buckets.get(0).getAggregations().get("movavg_values"))).value();
         assertThat(Double.compare(lastValue, 0.0d), greaterThanOrEqualTo(0));
 
@@ -1073,8 +1082,10 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
         assertThat(histo, notNullValue());
         assertThat(histo.getName(), equalTo("histo"));
         List<? extends Bucket> buckets = histo.getBuckets();
 
         assertThat("Size of buckets array is not correct.", buckets.size(), equalTo(50 + numPredictions));
+
+
         double lastValue = 0;
 
         double currentValue;
@@ -1099,8 +1110,7 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
 
     /**
      * This test filters the "gap" data so that the last doc is excluded. This leaves a long stretch of empty
-     * buckets after the first bucket. The moving avg should be one at the beginning, then zero for the rest
-     * regardless of mov avg type or gap policy.
+     * buckets after the first bucket.
      */
     @Test
     public void testRightGap() {
@@ -1176,34 +1186,41 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
         assertThat(histo, notNullValue());
         assertThat(histo.getName(), equalTo("histo"));
         List<? extends Bucket> buckets = histo.getBuckets();
 
+        // If we are skipping, there will only be predictions at the very beginning and won't append any new buckets
+        if (gapPolicy.equals(BucketHelpers.GapPolicy.SKIP)) {
+            assertThat("Size of buckets array is not correct.", buckets.size(), equalTo(50));
+        } else {
         assertThat("Size of buckets array is not correct.", buckets.size(), equalTo(50 + numPredictions));
+        }
 
+        // Unlike left-gap tests, we cannot check the slope of prediction for right-gap. E.g. linear will
+        // converge on zero, but holt-linear may trend upwards based on the first value
+        // Just check for non-nullness
         SimpleValue current = buckets.get(0).getAggregations().get("movavg_values");
         assertThat(current, notNullValue());
 
-        double lastValue = current.value();
-        double currentValue;
-        for (int i = 1; i < 50; i++) {
-            current = buckets.get(i).getAggregations().get("movavg_values");
-            if (current != null) {
-                currentValue = current.value();
-
-                assertThat(Double.compare(lastValue, currentValue), greaterThanOrEqualTo(0));
-                lastValue = currentValue;
-            }
-        }
-
+        // If we are skipping, there will only be predictions at the very beginning and won't append any new buckets
+        if (gapPolicy.equals(BucketHelpers.GapPolicy.SKIP)) {
             // Now check predictions
+            for (int i = 1; i < 1 + numPredictions; i++) {
+                // Unclear at this point which direction the predictions will go, just verify they are
+                // not null
+                assertThat(buckets.get(i).getDocCount(), equalTo(0L));
+                assertThat((buckets.get(i).getAggregations().get("movavg_values")), notNullValue());
+            }
+        } else {
+            // Otherwise we'll have some predictions at the end
             for (int i = 50; i < 50 + numPredictions; i++) {
                 // Unclear at this point which direction the predictions will go, just verify they are
-                // not null, and that we don't have the_metric anymore
+                // not null
+                assertThat(buckets.get(i).getDocCount(), equalTo(0L));
                 assertThat((buckets.get(i).getAggregations().get("movavg_values")), notNullValue());
-                assertThat((buckets.get(i).getAggregations().get("the_metric")), nullValue());
             }
         }
 
+    }
+
     @Test
     public void testHoltWintersNotEnoughData() {
         try {
@@ -1232,6 +1249,100 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
 
     }
 
+    @Test
+    public void testTwoMovAvgsWithPredictions() {
+
+        SearchResponse response = client()
+                .prepareSearch("double_predict")
+                .setTypes("type")
+                .addAggregation(
+                        histogram("histo")
+                                .field(INTERVAL_FIELD)
+                                .interval(1)
+                                .subAggregation(avg("avg").field(VALUE_FIELD))
+                                .subAggregation(derivative("deriv")
+                                        .setBucketsPaths("avg").gapPolicy(gapPolicy))
+                                .subAggregation(
+                                        movingAvg("avg_movavg").window(windowSize).modelBuilder(new SimpleModel.SimpleModelBuilder())
+                                                .gapPolicy(gapPolicy).predict(12).setBucketsPaths("avg"))
+                                .subAggregation(
+                                        movingAvg("deriv_movavg").window(windowSize).modelBuilder(new SimpleModel.SimpleModelBuilder())
+                                                .gapPolicy(gapPolicy).predict(12).setBucketsPaths("deriv"))
+                ).execute().actionGet();
+
+        assertSearchResponse(response);
+
+        InternalHistogram<Bucket> histo = response.getAggregations().get("histo");
+        assertThat(histo, notNullValue());
+        assertThat(histo.getName(), equalTo("histo"));
+        List<? extends Bucket> buckets = histo.getBuckets();
+        assertThat("Size of buckets array is not correct.", buckets.size(), equalTo(24));
+
+        Bucket bucket = buckets.get(0);
+        assertThat(bucket, notNullValue());
+        assertThat((long) bucket.getKey(), equalTo((long) 0));
+        assertThat(bucket.getDocCount(), equalTo(1l));
+
+        Avg avgAgg = bucket.getAggregations().get("avg");
+        assertThat(avgAgg, notNullValue());
+        assertThat(avgAgg.value(), equalTo(10d));
+
+        SimpleValue movAvgAgg = bucket.getAggregations().get("avg_movavg");
+        assertThat(movAvgAgg, notNullValue());
+        assertThat(movAvgAgg.value(), equalTo(10d));
+
+        Derivative deriv = bucket.getAggregations().get("deriv");
+        assertThat(deriv, nullValue());
+
+        SimpleValue derivMovAvg = bucket.getAggregations().get("deriv_movavg");
+        assertThat(derivMovAvg, nullValue());
+
+        for (int i = 1; i < 12; i++) {
+            bucket = buckets.get(i);
+            assertThat(bucket, notNullValue());
+            assertThat((long) bucket.getKey(), equalTo((long) i));
+            assertThat(bucket.getDocCount(), equalTo(1l));
+
+            avgAgg = bucket.getAggregations().get("avg");
+            assertThat(avgAgg, notNullValue());
+            assertThat(avgAgg.value(), equalTo(10d));
+
+            deriv = bucket.getAggregations().get("deriv");
+            assertThat(deriv, notNullValue());
+            assertThat(deriv.value(), equalTo(0d));
+
+            movAvgAgg = bucket.getAggregations().get("avg_movavg");
+            assertThat(movAvgAgg, notNullValue());
+            assertThat(movAvgAgg.value(), equalTo(10d));
+
+            derivMovAvg = bucket.getAggregations().get("deriv_movavg");
+            assertThat(derivMovAvg, notNullValue());
+            assertThat(derivMovAvg.value(), equalTo(0d));
+        }
+
+        // Predictions
+        for (int i = 12; i < 24; i++) {
+            bucket = buckets.get(i);
+            assertThat(bucket, notNullValue());
+            assertThat((long) bucket.getKey(), equalTo((long) i));
+            assertThat(bucket.getDocCount(), equalTo(0l));
+
+            avgAgg = bucket.getAggregations().get("avg");
+            assertThat(avgAgg, nullValue());
+
+            deriv = bucket.getAggregations().get("deriv");
+            assertThat(deriv, nullValue());
+
+            movAvgAgg = bucket.getAggregations().get("avg_movavg");
+            assertThat(movAvgAgg, notNullValue());
+            assertThat(movAvgAgg.value(), equalTo(10d));
+
+            derivMovAvg = bucket.getAggregations().get("deriv_movavg");
+            assertThat(derivMovAvg, notNullValue());
+            assertThat(derivMovAvg.value(), equalTo(0d));
+        }
+    }
+
     @Test
     public void testBadModelParams() {
         try {