Aggregations: Fixed Moving Average prediction to calculate the correct keys
The Moving average predict code generated incorrect keys if the key for the first bucket of the histogram was < 0. This fix makes the moving average use the rounding class from the histogram to generate the keys for the new buckets. Closes #11369
This commit is contained in:
parent
fc224a0de8
commit
7fbd86aa97
|
@ -111,7 +111,6 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
|
||||||
EvictingQueue<Double> values = EvictingQueue.create(this.window);
|
EvictingQueue<Double> values = EvictingQueue.create(this.window);
|
||||||
|
|
||||||
long lastKey = 0;
|
long lastKey = 0;
|
||||||
long interval = Long.MAX_VALUE;
|
|
||||||
Object currentKey;
|
Object currentKey;
|
||||||
|
|
||||||
for (InternalHistogram.Bucket bucket : buckets) {
|
for (InternalHistogram.Bucket bucket : buckets) {
|
||||||
|
@ -135,10 +134,8 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
|
||||||
|
|
||||||
if (predict > 0) {
|
if (predict > 0) {
|
||||||
if (currentKey instanceof Number) {
|
if (currentKey instanceof Number) {
|
||||||
interval = Math.min(interval, ((Number) bucket.getKey()).longValue() - lastKey);
|
|
||||||
lastKey = ((Number) bucket.getKey()).longValue();
|
lastKey = ((Number) bucket.getKey()).longValue();
|
||||||
} else if (currentKey instanceof DateTime) {
|
} else if (currentKey instanceof DateTime) {
|
||||||
interval = Math.min(interval, ((DateTime) bucket.getKey()).getMillis() - lastKey);
|
|
||||||
lastKey = ((DateTime) bucket.getKey()).getMillis();
|
lastKey = ((DateTime) bucket.getKey()).getMillis();
|
||||||
} else {
|
} else {
|
||||||
throw new AggregationExecutionException("Expected key of type Number or DateTime but got [" + currentKey + "]");
|
throw new AggregationExecutionException("Expected key of type Number or DateTime but got [" + currentKey + "]");
|
||||||
|
@ -147,7 +144,6 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if (buckets.size() > 0 && predict > 0) {
|
if (buckets.size() > 0 && predict > 0) {
|
||||||
|
|
||||||
boolean keyed;
|
boolean keyed;
|
||||||
|
@ -159,9 +155,11 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
|
||||||
for (int i = 0; i < predictions.length; i++) {
|
for (int i = 0; i < predictions.length; i++) {
|
||||||
List<InternalAggregation> aggs = new ArrayList<>();
|
List<InternalAggregation> aggs = new ArrayList<>();
|
||||||
aggs.add(new InternalSimpleValue(name(), predictions[i], formatter, new ArrayList<PipelineAggregator>(), metaData()));
|
aggs.add(new InternalSimpleValue(name(), predictions[i], formatter, new ArrayList<PipelineAggregator>(), metaData()));
|
||||||
InternalHistogram.Bucket newBucket = factory.createBucket(lastKey + (interval * (i + 1)), 0, new InternalAggregations(
|
long newKey = histo.getRounding().nextRoundingValue(lastKey);
|
||||||
|
InternalHistogram.Bucket newBucket = factory.createBucket(newKey, 0, new InternalAggregations(
|
||||||
aggs), keyed, formatter);
|
aggs), keyed, formatter);
|
||||||
newBuckets.add(newBucket);
|
newBuckets.add(newBucket);
|
||||||
|
lastKey = newKey;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,6 @@ package org.elasticsearch.search.aggregations.pipeline.moving.avg;
|
||||||
|
|
||||||
import com.google.common.collect.EvictingQueue;
|
import com.google.common.collect.EvictingQueue;
|
||||||
|
|
||||||
import org.apache.lucene.util.LuceneTestCase;
|
|
||||||
import org.elasticsearch.action.index.IndexRequestBuilder;
|
import org.elasticsearch.action.index.IndexRequestBuilder;
|
||||||
import org.elasticsearch.action.search.SearchPhaseExecutionException;
|
import org.elasticsearch.action.search.SearchPhaseExecutionException;
|
||||||
import org.elasticsearch.action.search.SearchResponse;
|
import org.elasticsearch.action.search.SearchResponse;
|
||||||
|
@ -32,6 +31,7 @@ import org.elasticsearch.search.aggregations.bucket.histogram.Histogram;
|
||||||
import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram;
|
import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram;
|
||||||
import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram.Bucket;
|
import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram.Bucket;
|
||||||
import org.elasticsearch.search.aggregations.metrics.ValuesSourceMetricsAggregationBuilder;
|
import org.elasticsearch.search.aggregations.metrics.ValuesSourceMetricsAggregationBuilder;
|
||||||
|
import org.elasticsearch.search.aggregations.metrics.avg.Avg;
|
||||||
import org.elasticsearch.search.aggregations.pipeline.BucketHelpers;
|
import org.elasticsearch.search.aggregations.pipeline.BucketHelpers;
|
||||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregationHelperTests;
|
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregationHelperTests;
|
||||||
import org.elasticsearch.search.aggregations.pipeline.SimpleValue;
|
import org.elasticsearch.search.aggregations.pipeline.SimpleValue;
|
||||||
|
@ -51,7 +51,6 @@ import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorBuilders.movingAvg;
|
|
||||||
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
|
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
|
||||||
import static org.elasticsearch.search.aggregations.AggregationBuilders.avg;
|
import static org.elasticsearch.search.aggregations.AggregationBuilders.avg;
|
||||||
import static org.elasticsearch.search.aggregations.AggregationBuilders.filter;
|
import static org.elasticsearch.search.aggregations.AggregationBuilders.filter;
|
||||||
|
@ -59,8 +58,12 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.histogra
|
||||||
import static org.elasticsearch.search.aggregations.AggregationBuilders.max;
|
import static org.elasticsearch.search.aggregations.AggregationBuilders.max;
|
||||||
import static org.elasticsearch.search.aggregations.AggregationBuilders.min;
|
import static org.elasticsearch.search.aggregations.AggregationBuilders.min;
|
||||||
import static org.elasticsearch.search.aggregations.AggregationBuilders.range;
|
import static org.elasticsearch.search.aggregations.AggregationBuilders.range;
|
||||||
|
import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorBuilders.movingAvg;
|
||||||
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
|
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
|
||||||
import static org.hamcrest.Matchers.*;
|
import static org.hamcrest.Matchers.closeTo;
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
|
||||||
|
import static org.hamcrest.Matchers.lessThanOrEqualTo;
|
||||||
import static org.hamcrest.core.IsNull.notNullValue;
|
import static org.hamcrest.core.IsNull.notNullValue;
|
||||||
import static org.hamcrest.core.IsNull.nullValue;
|
import static org.hamcrest.core.IsNull.nullValue;
|
||||||
|
|
||||||
|
@ -154,6 +157,11 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
|
||||||
.field(INTERVAL_FIELD, 49)
|
.field(INTERVAL_FIELD, 49)
|
||||||
.field(GAP_FIELD, 1).endObject()));
|
.field(GAP_FIELD, 1).endObject()));
|
||||||
|
|
||||||
|
for (int i = -10; i < 10; i++) {
|
||||||
|
builders.add(client().prepareIndex("neg_idx", "type").setSource(
|
||||||
|
jsonBuilder().startObject().field(INTERVAL_FIELD, i).field(VALUE_FIELD, 10).endObject()));
|
||||||
|
}
|
||||||
|
|
||||||
indexRandom(true, builders);
|
indexRandom(true, builders);
|
||||||
ensureSearchable();
|
ensureSearchable();
|
||||||
}
|
}
|
||||||
|
@ -514,6 +522,56 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPredictNegativeKeysAtStart() {
|
||||||
|
|
||||||
|
SearchResponse response = client()
|
||||||
|
.prepareSearch("neg_idx")
|
||||||
|
.setTypes("type")
|
||||||
|
.addAggregation(
|
||||||
|
histogram("histo")
|
||||||
|
.field(INTERVAL_FIELD)
|
||||||
|
.interval(1)
|
||||||
|
.subAggregation(avg("avg").field(VALUE_FIELD))
|
||||||
|
.subAggregation(
|
||||||
|
movingAvg("movavg_values").window(windowSize).modelBuilder(new SimpleModel.SimpleModelBuilder())
|
||||||
|
.gapPolicy(gapPolicy).predict(5).setBucketsPaths("avg"))).execute().actionGet();
|
||||||
|
|
||||||
|
assertSearchResponse(response);
|
||||||
|
|
||||||
|
InternalHistogram<Bucket> histo = response.getAggregations().get("histo");
|
||||||
|
assertThat(histo, notNullValue());
|
||||||
|
assertThat(histo.getName(), equalTo("histo"));
|
||||||
|
List<? extends Bucket> buckets = histo.getBuckets();
|
||||||
|
assertThat("Size of buckets array is not correct.", buckets.size(), equalTo(25));
|
||||||
|
|
||||||
|
for (int i = 0; i < 20; i++) {
|
||||||
|
Bucket bucket = buckets.get(i);
|
||||||
|
assertThat(bucket, notNullValue());
|
||||||
|
assertThat((long) bucket.getKey(), equalTo((long) i - 10));
|
||||||
|
assertThat(bucket.getDocCount(), equalTo(1l));
|
||||||
|
Avg avgAgg = bucket.getAggregations().get("avg");
|
||||||
|
assertThat(avgAgg, notNullValue());
|
||||||
|
assertThat(avgAgg.value(), equalTo(10d));
|
||||||
|
SimpleValue movAvgAgg = bucket.getAggregations().get("movavg_values");
|
||||||
|
assertThat(movAvgAgg, notNullValue());
|
||||||
|
assertThat(movAvgAgg.value(), equalTo(10d));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 20; i < 25; i++) {
|
||||||
|
System.out.println(i);
|
||||||
|
Bucket bucket = buckets.get(i);
|
||||||
|
assertThat(bucket, notNullValue());
|
||||||
|
assertThat((long) bucket.getKey(), equalTo((long) i - 10));
|
||||||
|
assertThat(bucket.getDocCount(), equalTo(0l));
|
||||||
|
Avg avgAgg = bucket.getAggregations().get("avg");
|
||||||
|
assertThat(avgAgg, nullValue());
|
||||||
|
SimpleValue movAvgAgg = bucket.getAggregations().get("movavg_values");
|
||||||
|
assertThat(movAvgAgg, notNullValue());
|
||||||
|
assertThat(movAvgAgg.value(), equalTo(10d));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSizeZeroWindow() {
|
public void testSizeZeroWindow() {
|
||||||
try {
|
try {
|
||||||
|
|
Loading…
Reference in New Issue