Aggregations: Fixed Moving Average prediction to calculate the correct keys

The Moving average predict code generated incorrect keys if the key for the first bucket of the histogram was < 0. This fix makes the moving average use the rounding class from the histogram to generate the keys for the new buckets.

Closes #11369
This commit is contained in:
Colin Goodheart-Smithe 2015-05-27 14:34:05 +01:00
parent fc224a0de8
commit 7fbd86aa97
2 changed files with 64 additions and 8 deletions

View File

@ -111,7 +111,6 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
EvictingQueue<Double> values = EvictingQueue.create(this.window);
long lastKey = 0;
long interval = Long.MAX_VALUE;
Object currentKey;
for (InternalHistogram.Bucket bucket : buckets) {
@ -135,10 +134,8 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
if (predict > 0) {
if (currentKey instanceof Number) {
interval = Math.min(interval, ((Number) bucket.getKey()).longValue() - lastKey);
lastKey = ((Number) bucket.getKey()).longValue();
} else if (currentKey instanceof DateTime) {
interval = Math.min(interval, ((DateTime) bucket.getKey()).getMillis() - lastKey);
lastKey = ((DateTime) bucket.getKey()).getMillis();
} else {
throw new AggregationExecutionException("Expected key of type Number or DateTime but got [" + currentKey + "]");
@ -147,7 +144,6 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
}
if (buckets.size() > 0 && predict > 0) {
boolean keyed;
@ -159,9 +155,11 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
for (int i = 0; i < predictions.length; i++) {
List<InternalAggregation> aggs = new ArrayList<>();
aggs.add(new InternalSimpleValue(name(), predictions[i], formatter, new ArrayList<PipelineAggregator>(), metaData()));
InternalHistogram.Bucket newBucket = factory.createBucket(lastKey + (interval * (i + 1)), 0, new InternalAggregations(
long newKey = histo.getRounding().nextRoundingValue(lastKey);
InternalHistogram.Bucket newBucket = factory.createBucket(newKey, 0, new InternalAggregations(
aggs), keyed, formatter);
newBuckets.add(newBucket);
lastKey = newKey;
}
}

View File

@ -22,7 +22,6 @@ package org.elasticsearch.search.aggregations.pipeline.moving.avg;
import com.google.common.collect.EvictingQueue;
import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchResponse;
@ -32,6 +31,7 @@ import org.elasticsearch.search.aggregations.bucket.histogram.Histogram;
import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram;
import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram.Bucket;
import org.elasticsearch.search.aggregations.metrics.ValuesSourceMetricsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.avg.Avg;
import org.elasticsearch.search.aggregations.pipeline.BucketHelpers;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregationHelperTests;
import org.elasticsearch.search.aggregations.pipeline.SimpleValue;
@ -51,7 +51,6 @@ import java.util.Iterator;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorBuilders.movingAvg;
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.search.aggregations.AggregationBuilders.avg;
import static org.elasticsearch.search.aggregations.AggregationBuilders.filter;
@ -59,8 +58,12 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.histogra
import static org.elasticsearch.search.aggregations.AggregationBuilders.max;
import static org.elasticsearch.search.aggregations.AggregationBuilders.min;
import static org.elasticsearch.search.aggregations.AggregationBuilders.range;
import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorBuilders.movingAvg;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
import static org.hamcrest.Matchers.*;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.core.IsNull.notNullValue;
import static org.hamcrest.core.IsNull.nullValue;
@ -154,6 +157,11 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
.field(INTERVAL_FIELD, 49)
.field(GAP_FIELD, 1).endObject()));
for (int i = -10; i < 10; i++) {
builders.add(client().prepareIndex("neg_idx", "type").setSource(
jsonBuilder().startObject().field(INTERVAL_FIELD, i).field(VALUE_FIELD, 10).endObject()));
}
indexRandom(true, builders);
ensureSearchable();
}
@ -514,6 +522,56 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
}
}
@Test
public void testPredictNegativeKeysAtStart() {
SearchResponse response = client()
.prepareSearch("neg_idx")
.setTypes("type")
.addAggregation(
histogram("histo")
.field(INTERVAL_FIELD)
.interval(1)
.subAggregation(avg("avg").field(VALUE_FIELD))
.subAggregation(
movingAvg("movavg_values").window(windowSize).modelBuilder(new SimpleModel.SimpleModelBuilder())
.gapPolicy(gapPolicy).predict(5).setBucketsPaths("avg"))).execute().actionGet();
assertSearchResponse(response);
InternalHistogram<Bucket> histo = response.getAggregations().get("histo");
assertThat(histo, notNullValue());
assertThat(histo.getName(), equalTo("histo"));
List<? extends Bucket> buckets = histo.getBuckets();
assertThat("Size of buckets array is not correct.", buckets.size(), equalTo(25));
for (int i = 0; i < 20; i++) {
Bucket bucket = buckets.get(i);
assertThat(bucket, notNullValue());
assertThat((long) bucket.getKey(), equalTo((long) i - 10));
assertThat(bucket.getDocCount(), equalTo(1l));
Avg avgAgg = bucket.getAggregations().get("avg");
assertThat(avgAgg, notNullValue());
assertThat(avgAgg.value(), equalTo(10d));
SimpleValue movAvgAgg = bucket.getAggregations().get("movavg_values");
assertThat(movAvgAgg, notNullValue());
assertThat(movAvgAgg.value(), equalTo(10d));
}
for (int i = 20; i < 25; i++) {
System.out.println(i);
Bucket bucket = buckets.get(i);
assertThat(bucket, notNullValue());
assertThat((long) bucket.getKey(), equalTo((long) i - 10));
assertThat(bucket.getDocCount(), equalTo(0l));
Avg avgAgg = bucket.getAggregations().get("avg");
assertThat(avgAgg, nullValue());
SimpleValue movAvgAgg = bucket.getAggregations().get("movavg_values");
assertThat(movAvgAgg, notNullValue());
assertThat(movAvgAgg.value(), equalTo(10d));
}
}
@Test
public void testSizeZeroWindow() {
try {