diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc index f80c135f882..0a04cc950e3 100644 --- a/docs/CHANGELOG.asciidoc +++ b/docs/CHANGELOG.asciidoc @@ -70,6 +70,8 @@ Fixed prerelease version of elasticsearch in the `deb` package to sort before GA Fail snapshot operations early when creating or deleting a snapshot on a repository that has been written to by an older Elasticsearch after writing to it with a newer Elasticsearch version. ({pull}30140[#30140]) +Fix NPE when CumulativeSum agg encounters null value/empty bucket ({pull}29641[#29641]) + //[float] //=== Regressions @@ -100,6 +102,7 @@ multi-argument versions. ({pull}29623[#29623]) Do not ignore request analysis/similarity settings on index resize operations when the source index already contains such settings ({pull}30216[#30216]) +Fix NPE when CumulativeSum agg encounters null value/empty bucket ({pull}29641[#29641]) //[float] //=== Regressions diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/pipeline/cumulativesum/CumulativeSumPipelineAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/pipeline/cumulativesum/CumulativeSumPipelineAggregator.java index 8a1b70fdd14..e1441132452 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/pipeline/cumulativesum/CumulativeSumPipelineAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/pipeline/cumulativesum/CumulativeSumPipelineAggregator.java @@ -79,11 +79,17 @@ public class CumulativeSumPipelineAggregator extends PipelineAggregator { double sum = 0; for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) { Double thisBucketValue = resolveBucketValue(histo, bucket, bucketsPaths()[0], GapPolicy.INSERT_ZEROS); - sum += thisBucketValue; - List aggs = StreamSupport.stream(bucket.getAggregations().spliterator(), false).map((p) -> { - return (InternalAggregation) p; - }).collect(Collectors.toList()); - aggs.add(new InternalSimpleValue(name(), sum, formatter, new ArrayList(), metaData())); + + // Only increment the sum if it's a finite value, otherwise "increment by zero" is correct + if (thisBucketValue != null && thisBucketValue.isInfinite() == false && thisBucketValue.isNaN() == false) { + sum += thisBucketValue; + } + + List aggs = StreamSupport + .stream(bucket.getAggregations().spliterator(), false) + .map((p) -> (InternalAggregation) p) + .collect(Collectors.toList()); + aggs.add(new InternalSimpleValue(name(), sum, formatter, new ArrayList<>(), metaData())); Bucket newBucket = factory.createBucket(factory.getKey(bucket), bucket.getDocCount(), new InternalAggregations(aggs)); newBuckets.add(newBucket); } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/CumulativeSumAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/CumulativeSumAggregatorTests.java new file mode 100644 index 00000000000..fa46921a941 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/CumulativeSumAggregatorTests.java @@ -0,0 +1,316 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.pipeline; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.elasticsearch.common.CheckedConsumer; +import org.elasticsearch.index.mapper.DateFieldMapper; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregatorTestCase; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder; +import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval; +import org.elasticsearch.search.aggregations.bucket.histogram.Histogram; +import org.elasticsearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.avg.AvgAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.avg.InternalAvg; +import org.elasticsearch.search.aggregations.metrics.sum.Sum; +import org.elasticsearch.search.aggregations.metrics.sum.SumAggregationBuilder; +import org.elasticsearch.search.aggregations.pipeline.cumulativesum.CumulativeSumPipelineAggregationBuilder; +import org.elasticsearch.search.aggregations.pipeline.derivative.DerivativePipelineAggregationBuilder; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.core.IsNull.notNullValue; + +public class CumulativeSumAggregatorTests extends AggregatorTestCase { + + private static final String HISTO_FIELD = "histo"; + private static final String VALUE_FIELD = "value_field"; + + private static final List datasetTimes = Arrays.asList( + "2017-01-01T01:07:45", + "2017-01-02T03:43:34", + "2017-01-03T04:11:00", + "2017-01-04T05:11:31", + "2017-01-05T08:24:05", + "2017-01-06T13:09:32", + "2017-01-07T13:47:43", + "2017-01-08T16:14:34", + "2017-01-09T17:09:50", + "2017-01-10T22:55:46"); + + private static final List datasetValues = Arrays.asList(1,2,3,4,5,6,7,8,9,10); + + public void testSimple() throws IOException { + Query query = new MatchAllDocsQuery(); + + DateHistogramAggregationBuilder aggBuilder = new DateHistogramAggregationBuilder("histo"); + aggBuilder.dateHistogramInterval(DateHistogramInterval.DAY).field(HISTO_FIELD); + aggBuilder.subAggregation(new AvgAggregationBuilder("the_avg").field(VALUE_FIELD)); + aggBuilder.subAggregation(new CumulativeSumPipelineAggregationBuilder("cusum", "the_avg")); + + executeTestCase(query, aggBuilder, histogram -> { + assertEquals(10, ((Histogram)histogram).getBuckets().size()); + List buckets = ((Histogram)histogram).getBuckets(); + double sum = 0.0; + for (Histogram.Bucket bucket : buckets) { + sum += ((InternalAvg) (bucket.getAggregations().get("the_avg"))).value(); + assertThat(((InternalSimpleValue) (bucket.getAggregations().get("cusum"))).value(), equalTo(sum)); + } + }); + } + + /** + * First value from a derivative is null, so this makes sure the cusum can handle that + */ + public void testDerivative() throws IOException { + Query query = new MatchAllDocsQuery(); + + DateHistogramAggregationBuilder aggBuilder = new DateHistogramAggregationBuilder("histo"); + aggBuilder.dateHistogramInterval(DateHistogramInterval.DAY).field(HISTO_FIELD); + aggBuilder.subAggregation(new AvgAggregationBuilder("the_avg").field(VALUE_FIELD)); + aggBuilder.subAggregation(new DerivativePipelineAggregationBuilder("the_deriv", "the_avg")); + aggBuilder.subAggregation(new CumulativeSumPipelineAggregationBuilder("cusum", "the_deriv")); + + executeTestCase(query, aggBuilder, histogram -> { + assertEquals(10, ((Histogram)histogram).getBuckets().size()); + List buckets = ((Histogram)histogram).getBuckets(); + double sum = 0.0; + for (int i = 0; i < buckets.size(); i++) { + if (i == 0) { + assertThat(((InternalSimpleValue)(buckets.get(i).getAggregations().get("cusum"))).value(), equalTo(0.0)); + } else { + sum += 1.0; + assertThat(((InternalSimpleValue)(buckets.get(i).getAggregations().get("cusum"))).value(), equalTo(sum)); + } + } + }); + } + + public void testDocCount() throws IOException { + Query query = new MatchAllDocsQuery(); + + int numDocs = randomIntBetween(6, 20); + int interval = randomIntBetween(2, 5); + + int minRandomValue = 0; + int maxRandomValue = 20; + + int numValueBuckets = ((maxRandomValue - minRandomValue) / interval) + 1; + long[] valueCounts = new long[numValueBuckets]; + + HistogramAggregationBuilder aggBuilder = new HistogramAggregationBuilder("histo") + .field(VALUE_FIELD) + .interval(interval) + .extendedBounds(minRandomValue, maxRandomValue); + aggBuilder.subAggregation(new CumulativeSumPipelineAggregationBuilder("cusum", "_count")); + + executeTestCase(query, aggBuilder, histogram -> { + List buckets = ((Histogram)histogram).getBuckets(); + + assertThat(buckets.size(), equalTo(numValueBuckets)); + + double sum = 0; + for (int i = 0; i < numValueBuckets; ++i) { + Histogram.Bucket bucket = buckets.get(i); + assertThat(bucket, notNullValue()); + assertThat(((Number) bucket.getKey()).longValue(), equalTo((long) i * interval)); + assertThat(bucket.getDocCount(), equalTo(valueCounts[i])); + sum += bucket.getDocCount(); + InternalSimpleValue cumulativeSumValue = bucket.getAggregations().get("cusum"); + assertThat(cumulativeSumValue, notNullValue()); + assertThat(cumulativeSumValue.getName(), equalTo("cusum")); + assertThat(cumulativeSumValue.value(), equalTo(sum)); + } + }, indexWriter -> { + Document document = new Document(); + + for (int i = 0; i < numDocs; i++) { + int fieldValue = randomIntBetween(minRandomValue, maxRandomValue); + document.add(new NumericDocValuesField(VALUE_FIELD, fieldValue)); + final int bucket = (fieldValue / interval); + valueCounts[bucket]++; + + indexWriter.addDocument(document); + document.clear(); + } + }); + } + + public void testMetric() throws IOException { + Query query = new MatchAllDocsQuery(); + + int numDocs = randomIntBetween(6, 20); + int interval = randomIntBetween(2, 5); + + int minRandomValue = 0; + int maxRandomValue = 20; + + int numValueBuckets = ((maxRandomValue - minRandomValue) / interval) + 1; + long[] valueCounts = new long[numValueBuckets]; + + HistogramAggregationBuilder aggBuilder = new HistogramAggregationBuilder("histo") + .field(VALUE_FIELD) + .interval(interval) + .extendedBounds(minRandomValue, maxRandomValue); + aggBuilder.subAggregation(new SumAggregationBuilder("sum").field(VALUE_FIELD)); + aggBuilder.subAggregation(new CumulativeSumPipelineAggregationBuilder("cusum", "sum")); + + executeTestCase(query, aggBuilder, histogram -> { + List buckets = ((Histogram)histogram).getBuckets(); + + assertThat(buckets.size(), equalTo(numValueBuckets)); + + double bucketSum = 0; + for (int i = 0; i < numValueBuckets; ++i) { + Histogram.Bucket bucket = buckets.get(i); + assertThat(bucket, notNullValue()); + assertThat(((Number) bucket.getKey()).longValue(), equalTo((long) i * interval)); + Sum sum = bucket.getAggregations().get("sum"); + assertThat(sum, notNullValue()); + bucketSum += sum.value(); + + InternalSimpleValue sumBucketValue = bucket.getAggregations().get("cusum"); + assertThat(sumBucketValue, notNullValue()); + assertThat(sumBucketValue.getName(), equalTo("cusum")); + assertThat(sumBucketValue.value(), equalTo(bucketSum)); + } + }, indexWriter -> { + Document document = new Document(); + + for (int i = 0; i < numDocs; i++) { + int fieldValue = randomIntBetween(minRandomValue, maxRandomValue); + document.add(new NumericDocValuesField(VALUE_FIELD, fieldValue)); + final int bucket = (fieldValue / interval); + valueCounts[bucket]++; + + indexWriter.addDocument(document); + document.clear(); + } + }); + } + + public void testNoBuckets() throws IOException { + int numDocs = randomIntBetween(6, 20); + int interval = randomIntBetween(2, 5); + + int minRandomValue = 0; + int maxRandomValue = 20; + + int numValueBuckets = ((maxRandomValue - minRandomValue) / interval) + 1; + long[] valueCounts = new long[numValueBuckets]; + + Query query = new MatchNoDocsQuery(); + + HistogramAggregationBuilder aggBuilder = new HistogramAggregationBuilder("histo") + .field(VALUE_FIELD) + .interval(interval); + aggBuilder.subAggregation(new SumAggregationBuilder("sum").field(VALUE_FIELD)); + aggBuilder.subAggregation(new CumulativeSumPipelineAggregationBuilder("cusum", "sum")); + + executeTestCase(query, aggBuilder, histogram -> { + List buckets = ((Histogram)histogram).getBuckets(); + + assertThat(buckets.size(), equalTo(0)); + + }, indexWriter -> { + Document document = new Document(); + + for (int i = 0; i < numDocs; i++) { + int fieldValue = randomIntBetween(minRandomValue, maxRandomValue); + document.add(new NumericDocValuesField(VALUE_FIELD, fieldValue)); + final int bucket = (fieldValue / interval); + valueCounts[bucket]++; + + indexWriter.addDocument(document); + document.clear(); + } + }); + } + + @SuppressWarnings("unchecked") + private void executeTestCase(Query query, AggregationBuilder aggBuilder, Consumer verify) throws IOException { + executeTestCase(query, aggBuilder, verify, indexWriter -> { + Document document = new Document(); + int counter = 0; + for (String date : datasetTimes) { + if (frequently()) { + indexWriter.commit(); + } + + long instant = asLong(date); + document.add(new SortedNumericDocValuesField(HISTO_FIELD, instant)); + document.add(new NumericDocValuesField(VALUE_FIELD, datasetValues.get(counter))); + indexWriter.addDocument(document); + document.clear(); + counter += 1; + } + }); + } + + @SuppressWarnings("unchecked") + private void executeTestCase(Query query, AggregationBuilder aggBuilder, Consumer verify, + CheckedConsumer setup) throws IOException { + + try (Directory directory = newDirectory()) { + try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { + setup.accept(indexWriter); + } + + try (IndexReader indexReader = DirectoryReader.open(directory)) { + IndexSearcher indexSearcher = newSearcher(indexReader, true, true); + + DateFieldMapper.Builder builder = new DateFieldMapper.Builder("_name"); + DateFieldMapper.DateFieldType fieldType = builder.fieldType(); + fieldType.setHasDocValues(true); + fieldType.setName(HISTO_FIELD); + + MappedFieldType valueFieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG); + valueFieldType.setHasDocValues(true); + valueFieldType.setName("value_field"); + + InternalAggregation histogram; + histogram = searchAndReduce(indexSearcher, query, aggBuilder, new MappedFieldType[]{fieldType, valueFieldType}); + verify.accept(histogram); + } + } + } + + private static long asLong(String dateTime) { + return DateFieldMapper.DEFAULT_DATE_TIME_FORMATTER.parser().parseDateTime(dateTime).getMillis(); + } +} diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/CumulativeSumIT.java b/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/CumulativeSumIT.java deleted file mode 100644 index 6a748bd3c84..00000000000 --- a/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/CumulativeSumIT.java +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.search.aggregations.pipeline; - -import org.elasticsearch.action.index.IndexRequestBuilder; -import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.search.aggregations.bucket.histogram.ExtendedBounds; -import org.elasticsearch.search.aggregations.bucket.histogram.Histogram; -import org.elasticsearch.search.aggregations.bucket.histogram.Histogram.Bucket; -import org.elasticsearch.search.aggregations.metrics.sum.Sum; -import org.elasticsearch.test.ESIntegTestCase; - -import java.util.ArrayList; -import java.util.List; - -import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; -import static org.elasticsearch.index.query.QueryBuilders.rangeQuery; -import static org.elasticsearch.search.aggregations.AggregationBuilders.histogram; -import static org.elasticsearch.search.aggregations.AggregationBuilders.sum; -import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorBuilders.cumulativeSum; -import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; -import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.core.IsNull.notNullValue; - -@ESIntegTestCase.SuiteScopeTestCase -public class CumulativeSumIT extends ESIntegTestCase { - - private static final String SINGLE_VALUED_FIELD_NAME = "l_value"; - - static int numDocs; - static int interval; - static int minRandomValue; - static int maxRandomValue; - static int numValueBuckets; - static long[] valueCounts; - - @Override - public void setupSuiteScopeCluster() throws Exception { - createIndex("idx"); - createIndex("idx_unmapped"); - - numDocs = randomIntBetween(6, 20); - interval = randomIntBetween(2, 5); - - minRandomValue = 0; - maxRandomValue = 20; - - numValueBuckets = ((maxRandomValue - minRandomValue) / interval) + 1; - valueCounts = new long[numValueBuckets]; - - List builders = new ArrayList<>(); - - for (int i = 0; i < numDocs; i++) { - int fieldValue = randomIntBetween(minRandomValue, maxRandomValue); - builders.add(client().prepareIndex("idx", "type").setSource( - jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, fieldValue).field("tag", "tag" + (i % interval)) - .endObject())); - final int bucket = (fieldValue / interval); // + (fieldValue < 0 ? -1 : 0) - (minRandomValue / interval - 1); - valueCounts[bucket]++; - } - - assertAcked(prepareCreate("empty_bucket_idx").addMapping("type", SINGLE_VALUED_FIELD_NAME, "type=integer")); - for (int i = 0; i < 2; i++) { - builders.add(client().prepareIndex("empty_bucket_idx", "type", "" + i).setSource( - jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, i * 2).endObject())); - } - indexRandom(true, builders); - ensureSearchable(); - } - - public void testDocCount() throws Exception { - SearchResponse response = client().prepareSearch("idx") - .addAggregation(histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval) - .extendedBounds(minRandomValue, maxRandomValue) - .subAggregation(cumulativeSum("cumulative_sum", "_count"))).execute().actionGet(); - - assertSearchResponse(response); - - Histogram histo = response.getAggregations().get("histo"); - assertThat(histo, notNullValue()); - assertThat(histo.getName(), equalTo("histo")); - List buckets = histo.getBuckets(); - assertThat(buckets.size(), equalTo(numValueBuckets)); - - double sum = 0; - for (int i = 0; i < numValueBuckets; ++i) { - Histogram.Bucket bucket = buckets.get(i); - assertThat(bucket, notNullValue()); - assertThat(((Number) bucket.getKey()).longValue(), equalTo((long) i * interval)); - assertThat(bucket.getDocCount(), equalTo(valueCounts[i])); - sum += bucket.getDocCount(); - InternalSimpleValue cumulativeSumValue = bucket.getAggregations().get("cumulative_sum"); - assertThat(cumulativeSumValue, notNullValue()); - assertThat(cumulativeSumValue.getName(), equalTo("cumulative_sum")); - assertThat(cumulativeSumValue.value(), equalTo(sum)); - } - - } - - public void testMetric() throws Exception { - SearchResponse response = client() - .prepareSearch("idx") - .addAggregation( - histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval) - .extendedBounds(minRandomValue, maxRandomValue) - .subAggregation(sum("sum").field(SINGLE_VALUED_FIELD_NAME)) - .subAggregation(cumulativeSum("cumulative_sum", "sum"))).execute().actionGet(); - - assertSearchResponse(response); - - Histogram histo = response.getAggregations().get("histo"); - assertThat(histo, notNullValue()); - assertThat(histo.getName(), equalTo("histo")); - List buckets = histo.getBuckets(); - - double bucketSum = 0; - for (int i = 0; i < buckets.size(); ++i) { - Bucket bucket = buckets.get(i); - assertThat(bucket, notNullValue()); - assertThat(((Number) bucket.getKey()).longValue(), equalTo((long) i * interval)); - Sum sum = bucket.getAggregations().get("sum"); - assertThat(sum, notNullValue()); - bucketSum += sum.value(); - - InternalSimpleValue sumBucketValue = bucket.getAggregations().get("cumulative_sum"); - assertThat(sumBucketValue, notNullValue()); - assertThat(sumBucketValue.getName(), equalTo("cumulative_sum")); - assertThat(sumBucketValue.value(), equalTo(bucketSum)); - } - } - - public void testNoBuckets() throws Exception { - SearchResponse response = client() - .prepareSearch("idx") - .setQuery(rangeQuery(SINGLE_VALUED_FIELD_NAME).lt(minRandomValue)) - .addAggregation( - histogram("histo").field(SINGLE_VALUED_FIELD_NAME).interval(interval) - .subAggregation(sum("sum").field(SINGLE_VALUED_FIELD_NAME)) - .subAggregation(cumulativeSum("cumulative_sum", "sum"))).execute().actionGet(); - - assertSearchResponse(response); - - Histogram histo = response.getAggregations().get("histo"); - assertThat(histo, notNullValue()); - assertThat(histo.getName(), equalTo("histo")); - List buckets = histo.getBuckets(); - assertThat(buckets.size(), equalTo(0)); - } -} diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java index 010eb1d7cdc..73ac501ec1d 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java @@ -61,6 +61,8 @@ import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.indices.fielddata.cache.IndicesFieldDataCache; import org.elasticsearch.mock.orig.Mockito; +import org.elasticsearch.search.aggregations.MultiBucketConsumerService.MultiBucketConsumer; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.fetch.FetchPhase; import org.elasticsearch.search.fetch.subphase.DocValueFieldsFetchSubPhase; import org.elasticsearch.search.fetch.subphase.FetchSourceSubPhase; @@ -79,6 +81,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import static org.elasticsearch.test.InternalAggregationTestCase.DEFAULT_MAX_BUCKETS; import static org.mockito.Matchers.anyObject; import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.doAnswer; @@ -369,6 +372,11 @@ public abstract class AggregatorTestCase extends ESTestCase { @SuppressWarnings("unchecked") A internalAgg = (A) aggs.get(0).doReduce(aggs, context); + if (internalAgg.pipelineAggregators().size() > 0) { + for (PipelineAggregator pipelineAggregator : internalAgg.pipelineAggregators()) { + internalAgg = (A) pipelineAggregator.reduce(internalAgg, context); + } + } InternalAggregationTestCase.assertMultiBucketConsumer(internalAgg, reduceBucketConsumer); return internalAgg; }