diff --git a/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/MovAvgPipelineAggregator.java b/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/MovAvgPipelineAggregator.java index c6e3d943bb8..cc1e6682e70 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/MovAvgPipelineAggregator.java +++ b/src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/MovAvgPipelineAggregator.java @@ -111,7 +111,6 @@ public class MovAvgPipelineAggregator extends PipelineAggregator { EvictingQueue values = EvictingQueue.create(this.window); long lastKey = 0; - long interval = Long.MAX_VALUE; Object currentKey; for (InternalHistogram.Bucket bucket : buckets) { @@ -135,10 +134,8 @@ public class MovAvgPipelineAggregator extends PipelineAggregator { if (predict > 0) { if (currentKey instanceof Number) { - interval = Math.min(interval, ((Number) bucket.getKey()).longValue() - lastKey); lastKey = ((Number) bucket.getKey()).longValue(); } else if (currentKey instanceof DateTime) { - interval = Math.min(interval, ((DateTime) bucket.getKey()).getMillis() - lastKey); lastKey = ((DateTime) bucket.getKey()).getMillis(); } else { throw new AggregationExecutionException("Expected key of type Number or DateTime but got [" + currentKey + "]"); @@ -147,7 +144,6 @@ public class MovAvgPipelineAggregator extends PipelineAggregator { } - if (buckets.size() > 0 && predict > 0) { boolean keyed; @@ -159,9 +155,11 @@ public class MovAvgPipelineAggregator extends PipelineAggregator { for (int i = 0; i < predictions.length; i++) { List aggs = new ArrayList<>(); aggs.add(new InternalSimpleValue(name(), predictions[i], formatter, new ArrayList(), metaData())); - InternalHistogram.Bucket newBucket = factory.createBucket(lastKey + (interval * (i + 1)), 0, new InternalAggregations( + long newKey = histo.getRounding().nextRoundingValue(lastKey); + InternalHistogram.Bucket newBucket = factory.createBucket(newKey, 0, new InternalAggregations( aggs), keyed, formatter); newBuckets.add(newBucket); + lastKey = newKey; } } diff --git a/src/test/java/org/elasticsearch/search/aggregations/pipeline/moving/avg/MovAvgTests.java b/src/test/java/org/elasticsearch/search/aggregations/pipeline/moving/avg/MovAvgTests.java index 38da141ad5c..0e0eb239ce0 100644 --- a/src/test/java/org/elasticsearch/search/aggregations/pipeline/moving/avg/MovAvgTests.java +++ b/src/test/java/org/elasticsearch/search/aggregations/pipeline/moving/avg/MovAvgTests.java @@ -22,7 +22,6 @@ package org.elasticsearch.search.aggregations.pipeline.moving.avg; import com.google.common.collect.EvictingQueue; -import org.apache.lucene.util.LuceneTestCase; import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.search.SearchPhaseExecutionException; import org.elasticsearch.action.search.SearchResponse; @@ -32,6 +31,7 @@ import org.elasticsearch.search.aggregations.bucket.histogram.Histogram; import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram; import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram.Bucket; import org.elasticsearch.search.aggregations.metrics.ValuesSourceMetricsAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.avg.Avg; import org.elasticsearch.search.aggregations.pipeline.BucketHelpers; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregationHelperTests; import org.elasticsearch.search.aggregations.pipeline.SimpleValue; @@ -51,7 +51,6 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorBuilders.movingAvg; import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; import static org.elasticsearch.search.aggregations.AggregationBuilders.avg; import static org.elasticsearch.search.aggregations.AggregationBuilders.filter; @@ -59,8 +58,12 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.histogra import static org.elasticsearch.search.aggregations.AggregationBuilders.max; import static org.elasticsearch.search.aggregations.AggregationBuilders.min; import static org.elasticsearch.search.aggregations.AggregationBuilders.range; +import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorBuilders.movingAvg; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse; -import static org.hamcrest.Matchers.*; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.core.IsNull.notNullValue; import static org.hamcrest.core.IsNull.nullValue; @@ -154,6 +157,11 @@ public class MovAvgTests extends ElasticsearchIntegrationTest { .field(INTERVAL_FIELD, 49) .field(GAP_FIELD, 1).endObject())); + for (int i = -10; i < 10; i++) { + builders.add(client().prepareIndex("neg_idx", "type").setSource( + jsonBuilder().startObject().field(INTERVAL_FIELD, i).field(VALUE_FIELD, 10).endObject())); + } + indexRandom(true, builders); ensureSearchable(); } @@ -514,6 +522,56 @@ public class MovAvgTests extends ElasticsearchIntegrationTest { } } + @Test + public void testPredictNegativeKeysAtStart() { + + SearchResponse response = client() + .prepareSearch("neg_idx") + .setTypes("type") + .addAggregation( + histogram("histo") + .field(INTERVAL_FIELD) + .interval(1) + .subAggregation(avg("avg").field(VALUE_FIELD)) + .subAggregation( + movingAvg("movavg_values").window(windowSize).modelBuilder(new SimpleModel.SimpleModelBuilder()) + .gapPolicy(gapPolicy).predict(5).setBucketsPaths("avg"))).execute().actionGet(); + + assertSearchResponse(response); + + InternalHistogram histo = response.getAggregations().get("histo"); + assertThat(histo, notNullValue()); + assertThat(histo.getName(), equalTo("histo")); + List buckets = histo.getBuckets(); + assertThat("Size of buckets array is not correct.", buckets.size(), equalTo(25)); + + for (int i = 0; i < 20; i++) { + Bucket bucket = buckets.get(i); + assertThat(bucket, notNullValue()); + assertThat((long) bucket.getKey(), equalTo((long) i - 10)); + assertThat(bucket.getDocCount(), equalTo(1l)); + Avg avgAgg = bucket.getAggregations().get("avg"); + assertThat(avgAgg, notNullValue()); + assertThat(avgAgg.value(), equalTo(10d)); + SimpleValue movAvgAgg = bucket.getAggregations().get("movavg_values"); + assertThat(movAvgAgg, notNullValue()); + assertThat(movAvgAgg.value(), equalTo(10d)); + } + + for (int i = 20; i < 25; i++) { + System.out.println(i); + Bucket bucket = buckets.get(i); + assertThat(bucket, notNullValue()); + assertThat((long) bucket.getKey(), equalTo((long) i - 10)); + assertThat(bucket.getDocCount(), equalTo(0l)); + Avg avgAgg = bucket.getAggregations().get("avg"); + assertThat(avgAgg, nullValue()); + SimpleValue movAvgAgg = bucket.getAggregations().get("movavg_values"); + assertThat(movAvgAgg, notNullValue()); + assertThat(movAvgAgg.value(), equalTo(10d)); + } + } + @Test public void testSizeZeroWindow() { try {