From 3df1c76f9b88b025baba5bb495d931f0417c8530 Mon Sep 17 00:00:00 2001 From: Zachary Tong Date: Mon, 5 Aug 2019 12:15:42 -0400 Subject: [PATCH] Allow pipeline aggs to select specific buckets from multi-bucket aggs (#44179) This adjusts the `buckets_path` parser so that pipeline aggs can select specific buckets (via their bucket keys) instead of fetching the entire set of buckets. This is useful for bucket_script in particular, which might want specific buckets for calculations. It's possible to workaround this with `filter` aggs, but the workaround is hacky and probably less performant. - Adjusts documentation - Adds a barebones AggregatorTestCase for bucket_script - Tweaks AggTestCase to use getMockScriptService() for reductions and pipelines. Previously pipelines could just pass in a script service for testing, but this didnt work for regular aggs. The new getMockScriptService() method fixes that issue, but needs to be used for pipelines too. This had a knock-on effect of touching MovFn, AvgBucket and ScriptedMetric --- docs/reference/aggregations/pipeline.asciidoc | 53 ++++- .../test/painless/100_terms_agg.yml | 38 ++++ .../InternalMultiBucketAggregation.java | 35 +++- .../InternalMultiBucketAggregationTests.java | 183 ++++++++++++++++++ .../ScriptedMetricAggregatorTests.java | 13 +- .../pipeline/AvgBucketAggregatorTests.java | 4 +- .../pipeline/BucketScriptAggregatorTests.java | 122 ++++++++++++ .../aggregations/pipeline/MovFnUnitTests.java | 71 +++---- .../script/MockScriptEngine.java | 20 +- .../aggregations/AggregatorTestCase.java | 7 +- 10 files changed, 474 insertions(+), 72 deletions(-) create mode 100644 server/src/test/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregationTests.java create mode 100644 server/src/test/java/org/elasticsearch/search/aggregations/pipeline/BucketScriptAggregatorTests.java diff --git a/docs/reference/aggregations/pipeline.asciidoc b/docs/reference/aggregations/pipeline.asciidoc index 81d711cc29c..361c01ca179 100644 --- a/docs/reference/aggregations/pipeline.asciidoc +++ b/docs/reference/aggregations/pipeline.asciidoc @@ -35,11 +35,12 @@ parameter, which follows a specific format: // https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_Form [source,ebnf] -------------------------------------------------- -AGG_SEPARATOR = '>' ; -METRIC_SEPARATOR = '.' ; +AGG_SEPARATOR = `>` ; +METRIC_SEPARATOR = `.` ; AGG_NAME = ; METRIC = ; -PATH = [ , ]* [ , ] ; +MULTIBUCKET_KEY = `[]` +PATH = ? (, )* ( , ) ; -------------------------------------------------- For example, the path `"my_bucket>my_stats.avg"` will path to the `avg` value in the `"my_stats"` metric, which is @@ -111,6 +112,52 @@ POST /_search <1> `buckets_path` instructs this max_bucket aggregation that we want the maximum value of the `sales` aggregation in the `sales_per_month` date histogram. +If a Sibling pipeline agg references a multi-bucket aggregation, such as a `terms` agg, it also has the option to +select specific keys from the multi-bucket. For example, a `bucket_script` could select two specific buckets (via +their bucket keys) to perform the calculation: + +[source,js] +-------------------------------------------------- +POST /_search +{ + "aggs" : { + "sales_per_month" : { + "date_histogram" : { + "field" : "date", + "calendar_interval" : "month" + }, + "aggs": { + "sale_type": { + "terms": { + "field": "type" + }, + "aggs": { + "sales": { + "sum": { + "field": "price" + } + } + } + }, + "hat_vs_bag_ratio": { + "bucket_script": { + "buckets_path": { + "hats": "sale_type['hat']>sales", <1> + "bags": "sale_type['bag']>sales" <1> + }, + "script": "params.hats / params.bags" + } + } + } + } + } +} +-------------------------------------------------- +// CONSOLE +// TEST[setup:sales] +<1> `buckets_path` selects the hats and bags buckets (via `['hat']`/`['bag']``) to use in the script specifically, +instead of fetching all the buckets from `sale_type` aggregation + [float] === Special Paths diff --git a/modules/lang-painless/src/test/resources/rest-api-spec/test/painless/100_terms_agg.yml b/modules/lang-painless/src/test/resources/rest-api-spec/test/painless/100_terms_agg.yml index 774a5dd59b0..000e1af694d 100644 --- a/modules/lang-painless/src/test/resources/rest-api-spec/test/painless/100_terms_agg.yml +++ b/modules/lang-painless/src/test/resources/rest-api-spec/test/painless/100_terms_agg.yml @@ -102,3 +102,41 @@ setup: - is_false: aggregations.double_terms.buckets.1.key_as_string - match: { aggregations.double_terms.buckets.1.doc_count: 1 } +--- +"Bucket script with keys": + + - do: + search: + rest_total_hits_as_int: true + body: + size: 0 + aggs: + placeholder: + filters: + filters: + - match_all: {} + aggs: + str_terms: + terms: + field: "str" + aggs: + the_avg: + avg: + field: "number" + the_bucket_script: + bucket_script: + buckets_path: + foo: "str_terms['bcd']>the_avg.value" + script: "params.foo" + + - match: { hits.total: 3 } + + - length: { aggregations.placeholder.buckets.0.str_terms.buckets: 2 } + - match: { aggregations.placeholder.buckets.0.str_terms.buckets.0.key: "abc" } + - is_false: aggregations.placeholder.buckets.0.str_terms.buckets.0.key_as_string + - match: { aggregations.placeholder.buckets.0.str_terms.buckets.0.doc_count: 2 } + - match: { aggregations.placeholder.buckets.0.str_terms.buckets.1.key: "bcd" } + - is_false: aggregations.placeholder.buckets.0.str_terms.buckets.1.key_as_string + - match: { aggregations.placeholder.buckets.0.str_terms.buckets.1.doc_count: 1 } + - match: { aggregations.placeholder.buckets.0.the_bucket_script.value: 2.0 } + diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregation.java b/server/src/main/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregation.java index 00a5271f7f5..e11899fff33 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregation.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregation.java @@ -73,16 +73,33 @@ public abstract class InternalMultiBucketAggregation path) { if (path.isEmpty()) { return this; - } else if (path.get(0).equals("_bucket_count")) { - return getBuckets().size(); - } else { - List buckets = getBuckets(); - Object[] propertyArray = new Object[buckets.size()]; - for (int i = 0; i < buckets.size(); i++) { - propertyArray[i] = buckets.get(i).getProperty(getName(), path); - } - return propertyArray; } + return resolvePropertyFromPath(path, getBuckets(), getName()); + } + + static Object resolvePropertyFromPath(List path, List buckets, String name) { + String aggName = path.get(0); + if (aggName.equals("_bucket_count")) { + return buckets.size(); + } + + // This is a bucket key, look through our buckets and see if we can find a match + if (aggName.startsWith("'") && aggName.endsWith("'")) { + for (InternalBucket bucket : buckets) { + if (bucket.getKeyAsString().equals(aggName.substring(1, aggName.length() - 1))) { + return bucket.getProperty(name, path.subList(1, path.size())); + } + } + // No key match, time to give up + throw new InvalidAggregationPathException("Cannot find an key [" + aggName + "] in [" + name + "]"); + } + + Object[] propertyArray = new Object[buckets.size()]; + for (int i = 0; i < buckets.size(); i++) { + propertyArray[i] = buckets.get(i).getProperty(name, path); + } + return propertyArray; + } /** diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregationTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregationTests.java new file mode 100644 index 00000000000..06d664efff5 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregationTests.java @@ -0,0 +1,183 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.bucket.terms.InternalTerms; +import org.elasticsearch.search.aggregations.bucket.terms.LongTerms; +import org.elasticsearch.search.aggregations.bucket.terms.StringTerms; +import org.elasticsearch.search.aggregations.metrics.InternalAvg; +import org.elasticsearch.search.aggregations.support.AggregationPath; +import org.elasticsearch.test.ESTestCase; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.elasticsearch.search.aggregations.InternalMultiBucketAggregation.resolvePropertyFromPath; +import static org.hamcrest.Matchers.equalTo; + +public class InternalMultiBucketAggregationTests extends ESTestCase { + + public void testResolveToAgg() { + AggregationPath path = AggregationPath.parse("the_avg"); + List buckets = new ArrayList<>(); + InternalAggregation agg = new InternalAvg("the_avg", 2, 1, + DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap()); + InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg)); + + LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW); + buckets.add(bucket); + + Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms"); + assertThat(value[0], equalTo(agg)); + } + + public void testResolveToAggValue() { + AggregationPath path = AggregationPath.parse("the_avg.value"); + List buckets = new ArrayList<>(); + InternalAggregation agg = new InternalAvg("the_avg", 2, 1, + DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap()); + InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg)); + + LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW); + buckets.add(bucket); + + Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms"); + assertThat(value[0], equalTo(2.0)); + } + + public void testResolveToNothing() { + AggregationPath path = AggregationPath.parse("foo.value"); + List buckets = new ArrayList<>(); + InternalAggregation agg = new InternalAvg("the_avg", 2, 1, + DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap()); + InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg)); + + LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW); + buckets.add(bucket); + + InvalidAggregationPathException e = expectThrows(InvalidAggregationPathException.class, + () -> resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms")); + assertThat(e.getMessage(), equalTo("Cannot find an aggregation named [foo] in [the_long_terms]")); + } + + public void testResolveToUnknown() { + AggregationPath path = AggregationPath.parse("the_avg.unknown"); + List buckets = new ArrayList<>(); + InternalAggregation agg = new InternalAvg("the_avg", 2, 1, + DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap()); + InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg)); + + LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW); + buckets.add(bucket); + + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms")); + assertThat(e.getMessage(), equalTo("path not supported for [the_avg]: [unknown]")); + } + + public void testResolveToBucketCount() { + AggregationPath path = AggregationPath.parse("_bucket_count"); + List buckets = new ArrayList<>(); + InternalAggregation agg = new InternalAvg("the_avg", 2, 1, + DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap()); + InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg)); + + LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW); + buckets.add(bucket); + + Object value = resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms"); + assertThat(value, equalTo(1)); + } + + public void testResolveToCount() { + AggregationPath path = AggregationPath.parse("_count"); + List buckets = new ArrayList<>(); + InternalAggregation agg = new InternalAvg("the_avg", 2, 1, + DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap()); + InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg)); + + LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW); + buckets.add(bucket); + + Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms"); + assertThat(value[0], equalTo(1L)); + } + + public void testResolveToKey() { + AggregationPath path = AggregationPath.parse("_key"); + List buckets = new ArrayList<>(); + InternalAggregation agg = new InternalAvg("the_avg", 2, 1, + DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap()); + InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg)); + + LongTerms.Bucket bucket = new LongTerms.Bucket(19, 1, internalAggregations, false, 0, DocValueFormat.RAW); + buckets.add(bucket); + + Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms"); + assertThat(value[0], equalTo(19L)); + } + + public void testResolveToSpecificBucket() { + AggregationPath path = AggregationPath.parse("string_terms['foo']>the_avg.value"); + + List buckets = new ArrayList<>(); + InternalAggregation agg = new InternalAvg("the_avg", 2, 1, + DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap()); + InternalAggregations internalStringAggs = new InternalAggregations(Collections.singletonList(agg)); + List stringBuckets = Collections.singletonList(new StringTerms.Bucket( + new BytesRef("foo".getBytes(StandardCharsets.UTF_8), 0, "foo".getBytes(StandardCharsets.UTF_8).length), 1, + internalStringAggs, false, 0, DocValueFormat.RAW)); + + InternalTerms termsAgg = new StringTerms("string_terms", BucketOrder.count(false), 1, 0, Collections.emptyList(), + Collections.emptyMap(), DocValueFormat.RAW, 1, false, 0, stringBuckets, 0); + InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(termsAgg)); + LongTerms.Bucket bucket = new LongTerms.Bucket(19, 1, internalAggregations, false, 0, DocValueFormat.RAW); + buckets.add(bucket); + + Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms"); + assertThat(value[0], equalTo(2.0)); + } + + public void testResolveToMissingSpecificBucket() { + AggregationPath path = AggregationPath.parse("string_terms['bar']>the_avg.value"); + + List buckets = new ArrayList<>(); + InternalAggregation agg = new InternalAvg("the_avg", 2, 1, + DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap()); + InternalAggregations internalStringAggs = new InternalAggregations(Collections.singletonList(agg)); + List stringBuckets = Collections.singletonList(new StringTerms.Bucket( + new BytesRef("foo".getBytes(StandardCharsets.UTF_8), 0, "foo".getBytes(StandardCharsets.UTF_8).length), 1, + internalStringAggs, false, 0, DocValueFormat.RAW)); + + InternalTerms termsAgg = new StringTerms("string_terms", BucketOrder.count(false), 1, 0, Collections.emptyList(), + Collections.emptyMap(), DocValueFormat.RAW, 1, false, 0, stringBuckets, 0); + InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(termsAgg)); + LongTerms.Bucket bucket = new LongTerms.Bucket(19, 1, internalAggregations, false, 0, DocValueFormat.RAW); + buckets.add(bucket); + + InvalidAggregationPathException e = expectThrows(InvalidAggregationPathException.class, + () -> resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms")); + assertThat(e.getMessage(), equalTo("Cannot find an key ['bar'] in [string_terms]")); + } +} diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java index 54b1a33db06..ea65d218fc6 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java @@ -173,6 +173,17 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { }); } + @Override + protected ScriptService getMockScriptService() { + MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME, + SCRIPTS, + Collections.emptyMap()); + Map engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine); + + return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS); + } + + @SuppressWarnings("unchecked") public void testNoDocs() throws IOException { try (Directory directory = newDirectory()) { @@ -311,7 +322,7 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { .initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS) .combineScript(COMBINE_SCRIPT_PARAMS).reduceScript(REDUCE_SCRIPT_PARAMS); ScriptedMetric scriptedMetric = searchAndReduce( - newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder, 0, scriptService); + newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder, 0); // The result value depends on the script params. assertEquals(4803, scriptedMetric.aggregation()); diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/AvgBucketAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/AvgBucketAggregatorTests.java index 4f312a71a83..afea0f13bd7 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/AvgBucketAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/AvgBucketAggregatorTests.java @@ -120,9 +120,9 @@ public class AvgBucketAggregatorTests extends AggregatorTestCase { valueFieldType.setName(VALUE_FIELD); valueFieldType.setHasDocValues(true); - avgResult = searchAndReduce(indexSearcher, query, avgBuilder, 10000, null, + avgResult = searchAndReduce(indexSearcher, query, avgBuilder, 10000, new MappedFieldType[]{fieldType, valueFieldType}); - histogramResult = searchAndReduce(indexSearcher, query, histo, 10000, null, + histogramResult = searchAndReduce(indexSearcher, query, histo, 10000, new MappedFieldType[]{fieldType, valueFieldType}); } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/BucketScriptAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/BucketScriptAggregatorTests.java new file mode 100644 index 00000000000..7feeecedd99 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/BucketScriptAggregatorTests.java @@ -0,0 +1,122 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.pipeline; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.document.SortedSetDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.CheckedConsumer; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.mapper.KeywordFieldMapper; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.script.MockScriptEngine; +import org.elasticsearch.script.Script; +import org.elasticsearch.script.ScriptEngine; +import org.elasticsearch.script.ScriptModule; +import org.elasticsearch.script.ScriptService; +import org.elasticsearch.script.ScriptType; +import org.elasticsearch.search.aggregations.AggregatorTestCase; +import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregationBuilder; +import org.elasticsearch.search.aggregations.bucket.filter.InternalFilters; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder; +import org.elasticsearch.search.aggregations.support.ValueType; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.function.Consumer; + +import static org.hamcrest.CoreMatchers.equalTo; + +public class BucketScriptAggregatorTests extends AggregatorTestCase { + private final String SCRIPT_NAME = "script_name"; + + @Override + protected ScriptService getMockScriptService() { + MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME, + Collections.singletonMap(SCRIPT_NAME, script -> script.get("the_avg")), + Collections.emptyMap()); + Map engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine); + + return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS); + } + + public void testScript() throws IOException { + MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER); + fieldType.setName("number_field"); + fieldType.setHasDocValues(true); + MappedFieldType fieldType1 = new KeywordFieldMapper.KeywordFieldType(); + fieldType1.setName("the_field"); + fieldType1.setHasDocValues(true); + + FiltersAggregationBuilder filters = new FiltersAggregationBuilder("placeholder", new MatchAllQueryBuilder()) + .subAggregation(new TermsAggregationBuilder("the_terms", ValueType.STRING).field("the_field") + .subAggregation(new AvgAggregationBuilder("the_avg").field("number_field"))) + .subAggregation(new BucketScriptPipelineAggregationBuilder("bucket_script", + Collections.singletonMap("the_avg", "the_terms['test1']>the_avg.value"), + new Script(ScriptType.INLINE, MockScriptEngine.NAME, SCRIPT_NAME, Collections.emptyMap()))); + + + testCase(filters, new MatchAllDocsQuery(), iw -> { + Document doc = new Document(); + doc.add(new SortedSetDocValuesField("the_field", new BytesRef("test1"))); + doc.add(new SortedNumericDocValuesField("number_field", 19)); + iw.addDocument(doc); + + doc = new Document(); + doc.add(new SortedSetDocValuesField("the_field", new BytesRef("test2"))); + doc.add(new SortedNumericDocValuesField("number_field", 55)); + iw.addDocument(doc); + }, f -> { + assertThat(((InternalSimpleValue)(f.getBuckets().get(0).getAggregations().get("bucket_script"))).value, + equalTo(19.0)); + }, fieldType, fieldType1); + } + + private void testCase(FiltersAggregationBuilder aggregationBuilder, Query query, + CheckedConsumer buildIndex, + Consumer verify, MappedFieldType... fieldType) throws IOException { + + try (Directory directory = newDirectory()) { + RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory); + buildIndex.accept(indexWriter); + indexWriter.close(); + + try (IndexReader indexReader = DirectoryReader.open(directory)) { + IndexSearcher indexSearcher = newSearcher(indexReader, true, true); + + InternalFilters filters; + filters = searchAndReduce(indexSearcher, query, aggregationBuilder, fieldType); + verify.accept(filters); + } + } + } +} diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnUnitTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnUnitTests.java index 122a88c9b1e..ac1afee1c79 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnUnitTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnUnitTests.java @@ -30,12 +30,17 @@ import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.store.Directory; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.time.DateFormatters; import org.elasticsearch.index.mapper.DateFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.script.MockScriptEngine; import org.elasticsearch.script.Script; +import org.elasticsearch.script.ScriptEngine; +import org.elasticsearch.script.ScriptModule; import org.elasticsearch.script.ScriptService; +import org.elasticsearch.script.ScriptType; import org.elasticsearch.search.aggregations.AggregatorTestCase; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.TestAggregatorFactory; @@ -56,8 +61,6 @@ import java.util.function.Consumer; import java.util.stream.Collectors; import static org.hamcrest.Matchers.equalTo; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class MovFnUnitTests extends AggregatorTestCase { @@ -79,31 +82,35 @@ public class MovFnUnitTests extends AggregatorTestCase { private static final List datasetValues = Arrays.asList(1,2,3,4,5,6,7,8,9,10); + @Override + protected ScriptService getMockScriptService() { + MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME, + Collections.singletonMap("test", script -> MovingFunctions.max((double[]) script.get("_values"))), + Collections.emptyMap()); + Map engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine); + + return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS); + } + public void testMatchAllDocs() throws IOException { - check(0, Arrays.asList(Double.NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)); + check(0, 3, Arrays.asList(Double.NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)); } public void testShift() throws IOException { - check(1, Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)); - check(5, Arrays.asList(5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 10.0, 10.0, Double.NaN, Double.NaN)); - check(-5, Arrays.asList(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, 1.0, 2.0, 3.0, 4.0)); + check(1, 3, Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)); + check(5, 3, Arrays.asList(5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 10.0, 10.0, Double.NaN, Double.NaN)); + check(-5, 3, Arrays.asList(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, 1.0, 2.0, 3.0, 4.0)); } public void testWideWindow() throws IOException { - Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap()); - MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 100); - builder.setShift(50); - check(builder, script, Arrays.asList(10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0)); + check(50, 100,Arrays.asList(10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0)); } - private void check(int shift, List expected) throws IOException { - Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap()); - MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 3); + private void check(int shift, int window, List expected) throws IOException { + Script script = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "test", Collections.emptyMap()); + MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, window); builder.setShift(shift); - check(builder, script, expected); - } - private void check(MovFnPipelineAggregationBuilder builder, Script script, List expected) throws IOException { Query query = new MatchAllDocsQuery(); DateHistogramAggregationBuilder aggBuilder = new DateHistogramAggregationBuilder("histo"); aggBuilder.calendarInterval(DateHistogramInterval.DAY).field(DATE_FIELD); @@ -111,19 +118,17 @@ public class MovFnUnitTests extends AggregatorTestCase { aggBuilder.subAggregation(builder); executeTestCase(query, aggBuilder, histogram -> { - List buckets = histogram.getBuckets(); - List actual = buckets.stream() - .map(bucket -> ((InternalSimpleValue) (bucket.getAggregations().get("mov_fn"))).value()) - .collect(Collectors.toList()); - assertThat(actual, equalTo(expected)); - }, 1000, script); + List buckets = histogram.getBuckets(); + List actual = buckets.stream() + .map(bucket -> ((InternalSimpleValue) (bucket.getAggregations().get("mov_fn"))).value()) + .collect(Collectors.toList()); + assertThat(actual, equalTo(expected)); + }); } - private void executeTestCase(Query query, DateHistogramAggregationBuilder aggBuilder, - Consumer verify, - int maxBucket, Script script) throws IOException { + Consumer verify) throws IOException { try (Directory directory = newDirectory()) { try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { @@ -144,20 +149,6 @@ public class MovFnUnitTests extends AggregatorTestCase { } } - ScriptService scriptService = mock(ScriptService.class); - MovingFunctionScript.Factory factory = mock(MovingFunctionScript.Factory.class); - when(scriptService.compile(script, MovingFunctionScript.CONTEXT)).thenReturn(factory); - - MovingFunctionScript scriptInstance = new MovingFunctionScript() { - @Override - public double execute(Map params, double[] values) { - assertNotNull(values); - return MovingFunctions.max(values); - } - }; - - when(factory.newInstance()).thenReturn(scriptInstance); - try (IndexReader indexReader = DirectoryReader.open(directory)) { IndexSearcher indexSearcher = newSearcher(indexReader, true, true); @@ -171,7 +162,7 @@ public class MovFnUnitTests extends AggregatorTestCase { valueFieldType.setName("value_field"); InternalDateHistogram histogram; - histogram = searchAndReduce(indexSearcher, query, aggBuilder, maxBucket, scriptService, + histogram = searchAndReduce(indexSearcher, query, aggBuilder, 1000, new MappedFieldType[]{fieldType, valueFieldType}); verify.accept(histogram); } diff --git a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java index 8044655b44e..bdbf092b897 100644 --- a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java +++ b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java @@ -27,7 +27,6 @@ import org.elasticsearch.index.similarity.ScriptedSimilarity.Field; import org.elasticsearch.index.similarity.ScriptedSimilarity.Query; import org.elasticsearch.index.similarity.ScriptedSimilarity.Term; import org.elasticsearch.search.aggregations.pipeline.MovingFunctionScript; -import org.elasticsearch.search.aggregations.pipeline.MovingFunctions; import org.elasticsearch.search.lookup.LeafSearchLookup; import org.elasticsearch.search.lookup.SearchLookup; @@ -271,7 +270,13 @@ public class MockScriptEngine implements ScriptEngine { SimilarityWeightScript.Factory factory = mockCompiled::createSimilarityWeightScript; return context.factoryClazz.cast(factory); } else if (context.instanceClazz.equals(MovingFunctionScript.class)) { - MovingFunctionScript.Factory factory = mockCompiled::createMovingFunctionScript; + MovingFunctionScript.Factory factory = () -> new MovingFunctionScript() { + @Override + public double execute(Map params1, double[] values) { + params1.put("_values", values); + return (double) script.apply(params1); + } + }; return context.factoryClazz.cast(factory); } else if (context.instanceClazz.equals(ScoreScript.class)) { ScoreScript.Factory factory = new MockScoreScript(script); @@ -335,10 +340,6 @@ public class MockScriptEngine implements ScriptEngine { return new MockSimilarityWeightScript(script != null ? script : ctx -> 42d); } - public MovingFunctionScript createMovingFunctionScript() { - return new MockMovingFunctionScript(); - } - public ScriptedMetricAggContexts.InitScript createMetricAggInitScript(Map params, Map state) { return new MockMetricAggInitScript(params, state, script != null ? script : ctx -> 42d); } @@ -544,13 +545,6 @@ public class MockScriptEngine implements ScriptEngine { return new Script(ScriptType.INLINE, "mock", script, emptyMap()); } - public class MockMovingFunctionScript extends MovingFunctionScript { - @Override - public double execute(Map params, double[] values) { - return MovingFunctions.unweightedAvg(values); - } - } - public class MockScoreScript implements ScoreScript.Factory { private final Function, Object> script; diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java index 2a9cd664a3b..6afa1b32182 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java @@ -351,7 +351,7 @@ public abstract class AggregatorTestCase extends ESTestCase { Query query, AggregationBuilder builder, MappedFieldType... fieldTypes) throws IOException { - return searchAndReduce(searcher, query, builder, DEFAULT_MAX_BUCKETS, null, fieldTypes); + return searchAndReduce(searcher, query, builder, DEFAULT_MAX_BUCKETS, fieldTypes); } /** @@ -363,7 +363,6 @@ public abstract class AggregatorTestCase extends ESTestCase { Query query, AggregationBuilder builder, int maxBucket, - ScriptService scriptService, MappedFieldType... fieldTypes) throws IOException { final IndexReaderContext ctx = searcher.getTopReaderContext(); @@ -408,7 +407,7 @@ public abstract class AggregatorTestCase extends ESTestCase { List toReduce = aggs.subList(0, r); MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumer(maxBucket); InternalAggregation.ReduceContext context = - new InternalAggregation.ReduceContext(root.context().bigArrays(), null, + new InternalAggregation.ReduceContext(root.context().bigArrays(), getMockScriptService(), reduceBucketConsumer, false); A reduced = (A) aggs.get(0).doReduce(toReduce, context); doAssertReducedMultiBucketConsumer(reduced, reduceBucketConsumer); @@ -418,7 +417,7 @@ public abstract class AggregatorTestCase extends ESTestCase { // now do the final reduce MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumer(maxBucket); InternalAggregation.ReduceContext context = - new InternalAggregation.ReduceContext(root.context().bigArrays(), scriptService, reduceBucketConsumer, true); + new InternalAggregation.ReduceContext(root.context().bigArrays(), getMockScriptService(), reduceBucketConsumer, true); @SuppressWarnings("unchecked") A internalAgg = (A) aggs.get(0).doReduce(aggs, context);