Allow pipeline aggs to select specific buckets from multi-bucket aggs (#44179)

This adjusts the `buckets_path` parser so that pipeline aggs can
select specific buckets (via their bucket keys) instead of fetching
the entire set of buckets. This is useful for `bucket_script` in
particular, which might want specific buckets for its calculations.

It's possible to work around this with `filter` aggs, but the workaround
is hacky and likely less performant.
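
For illustration, here's a rough sketch of the old `filter`-agg workaround next
to the new key-based syntax, using the Java-client builders. The agg and field
names are borrowed from the docs example added in this commit; this is an
illustrative sketch, not code from this change:

import java.util.HashMap;
import java.util.Map;

import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.bucket.filter.FilterAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.BucketScriptPipelineAggregationBuilder;

public class BucketsPathSketch {
    public static void main(String[] args) {
        // Workaround: pin each interesting bucket down with its own filter agg,
        // so the pipeline agg has a single-bucket path to reference.
        FilterAggregationBuilder hats = AggregationBuilders
            .filter("hats", QueryBuilders.termQuery("type", "hat"))
            .subAggregation(AggregationBuilders.sum("sales").field("price"));
        FilterAggregationBuilder bags = AggregationBuilders
            .filter("bags", QueryBuilders.termQuery("type", "bag"))
            .subAggregation(AggregationBuilders.sum("sales").field("price"));

        // New syntax: address the buckets of a single terms agg directly by key.
        Map<String, String> bucketsPath = new HashMap<>();
        bucketsPath.put("hats", "sale_type['hat']>sales");
        bucketsPath.put("bags", "sale_type['bag']>sales");
        BucketScriptPipelineAggregationBuilder ratio = new BucketScriptPipelineAggregationBuilder(
            "hat_vs_bag_ratio", bucketsPath, new Script("params.hats / params.bags"));
        // (hats, bags, and ratio would be attached to a SearchSourceBuilder in a real search.)
    }
}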

- Adjusts documentation
- Adds a barebones AggregatorTestCase for bucket_script
- Tweaks AggregatorTestCase to use getMockScriptService() for reductions and
pipelines. Previously, pipelines could just pass in a script service
for testing, but this didn't work for regular aggs. The new
getMockScriptService() method fixes that issue, but needs to be used
for pipelines too. This had a knock-on effect of touching MovFn,
AvgBucket and ScriptedMetric
Zachary Tong, 2019-08-05 12:15:42 -04:00
parent e5079ac288
commit 3df1c76f9b
10 changed files with 474 additions and 72 deletions


@@ -35,11 +35,12 @@ parameter, which follows a specific format:
// https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_Form
[source,ebnf]
--------------------------------------------------
-AGG_SEPARATOR = '>' ;
-METRIC_SEPARATOR = '.' ;
+AGG_SEPARATOR = `>` ;
+METRIC_SEPARATOR = `.` ;
AGG_NAME = <the name of the aggregation> ;
METRIC = <the name of the metric (in case of multi-value metrics aggregation)> ;
-PATH = <AGG_NAME> [ <AGG_SEPARATOR>, <AGG_NAME> ]* [ <METRIC_SEPARATOR>, <METRIC> ] ;
+MULTIBUCKET_KEY = `[<KEY_NAME>]`
+PATH = <AGG_NAME><MULTIBUCKET_KEY>? (<AGG_SEPARATOR>, <AGG_NAME> )* ( <METRIC_SEPARATOR>, <METRIC> ) ;
--------------------------------------------------
For example, the path `"my_bucket>my_stats.avg"` will point to the `avg` value in the `"my_stats"` metric, which is
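With the new `MULTIBUCKET_KEY` element, a path can also descend into a single bucket of a multi-bucket aggregation
by key: `sale_type['hat']>sales`, for example, points to the `sales` value inside the `hat` bucket of the
`sale_type` terms aggregation.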
@@ -111,6 +112,52 @@ POST /_search
<1> `buckets_path` instructs this max_bucket aggregation that we want the maximum value of the `sales` aggregation in the
`sales_per_month` date histogram.
If a sibling pipeline agg references a multi-bucket aggregation, such as a `terms` agg, it also has the option to
select specific keys from the multi-bucket aggregation. For example, a `bucket_script` could select two specific buckets (via
their bucket keys) to perform the calculation:
[source,js]
--------------------------------------------------
POST /_search
{
"aggs" : {
"sales_per_month" : {
"date_histogram" : {
"field" : "date",
"calendar_interval" : "month"
},
"aggs": {
"sale_type": {
"terms": {
"field": "type"
},
"aggs": {
"sales": {
"sum": {
"field": "price"
}
}
}
},
"hat_vs_bag_ratio": {
"bucket_script": {
"buckets_path": {
"hats": "sale_type['hat']>sales", <1>
"bags": "sale_type['bag']>sales" <1>
},
"script": "params.hats / params.bags"
}
}
}
}
}
}
--------------------------------------------------
// CONSOLE
// TEST[setup:sales]
<1> `buckets_path` selects the hats and bags buckets (via `['hat']`/`['bag']`) to use in the script,
instead of fetching all the buckets from the `sale_type` aggregation
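Note that if the selected key cannot be found in the multi-bucket aggregation, the search fails with an
`InvalidAggregationPathException`.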
[float]
=== Special Paths


@@ -102,3 +102,41 @@ setup:
- is_false: aggregations.double_terms.buckets.1.key_as_string
- match: { aggregations.double_terms.buckets.1.doc_count: 1 }
---
"Bucket script with keys":
- do:
search:
rest_total_hits_as_int: true
body:
size: 0
aggs:
placeholder:
filters:
filters:
- match_all: {}
aggs:
str_terms:
terms:
field: "str"
aggs:
the_avg:
avg:
field: "number"
the_bucket_script:
bucket_script:
buckets_path:
foo: "str_terms['bcd']>the_avg.value"
script: "params.foo"
- match: { hits.total: 3 }
- length: { aggregations.placeholder.buckets.0.str_terms.buckets: 2 }
- match: { aggregations.placeholder.buckets.0.str_terms.buckets.0.key: "abc" }
- is_false: aggregations.placeholder.buckets.0.str_terms.buckets.0.key_as_string
- match: { aggregations.placeholder.buckets.0.str_terms.buckets.0.doc_count: 2 }
- match: { aggregations.placeholder.buckets.0.str_terms.buckets.1.key: "bcd" }
- is_false: aggregations.placeholder.buckets.0.str_terms.buckets.1.key_as_string
- match: { aggregations.placeholder.buckets.0.str_terms.buckets.1.doc_count: 1 }
- match: { aggregations.placeholder.buckets.0.the_bucket_script.value: 2.0 }


@@ -73,16 +73,33 @@ public abstract class InternalMultiBucketAggregation<A extends InternalMultiBuck
public Object getProperty(List<String> path) {
if (path.isEmpty()) {
return this;
-} else if (path.get(0).equals("_bucket_count")) {
-return getBuckets().size();
-} else {
-List<? extends InternalBucket> buckets = getBuckets();
-Object[] propertyArray = new Object[buckets.size()];
-for (int i = 0; i < buckets.size(); i++) {
-propertyArray[i] = buckets.get(i).getProperty(getName(), path);
-}
-return propertyArray;
-}
+return resolvePropertyFromPath(path, getBuckets(), getName());
}
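// Bucket keys arrive from AggregationPath as quoted path elements (e.g. ['foo']
// is parsed to the element 'foo'), hence the quote check below; see
// testResolveToSpecificBucket in the new test class.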
static Object resolvePropertyFromPath(List<String> path, List<? extends InternalBucket> buckets, String name) {
String aggName = path.get(0);
if (aggName.equals("_bucket_count")) {
return buckets.size();
}
// This is a bucket key, look through our buckets and see if we can find a match
if (aggName.startsWith("'") && aggName.endsWith("'")) {
for (InternalBucket bucket : buckets) {
if (bucket.getKeyAsString().equals(aggName.substring(1, aggName.length() - 1))) {
return bucket.getProperty(name, path.subList(1, path.size()));
}
}
// No key match, time to give up
throw new InvalidAggregationPathException("Cannot find a key [" + aggName + "] in [" + name + "]");
}
Object[] propertyArray = new Object[buckets.size()];
for (int i = 0; i < buckets.size(); i++) {
propertyArray[i] = buckets.get(i).getProperty(name, path);
}
return propertyArray;
}
/**


@@ -0,0 +1,183 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.bucket.terms.InternalTerms;
import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.aggregations.metrics.InternalAvg;
import org.elasticsearch.search.aggregations.support.AggregationPath;
import org.elasticsearch.test.ESTestCase;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import static org.elasticsearch.search.aggregations.InternalMultiBucketAggregation.resolvePropertyFromPath;
import static org.hamcrest.Matchers.equalTo;
public class InternalMultiBucketAggregationTests extends ESTestCase {
public void testResolveToAgg() {
AggregationPath path = AggregationPath.parse("the_avg");
List<LongTerms.Bucket> buckets = new ArrayList<>();
InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));
LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
buckets.add(bucket);
Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
assertThat(value[0], equalTo(agg));
}
public void testResolveToAggValue() {
AggregationPath path = AggregationPath.parse("the_avg.value");
List<LongTerms.Bucket> buckets = new ArrayList<>();
InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));
LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
buckets.add(bucket);
Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
assertThat(value[0], equalTo(2.0));
}
public void testResolveToNothing() {
AggregationPath path = AggregationPath.parse("foo.value");
List<LongTerms.Bucket> buckets = new ArrayList<>();
InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));
LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
buckets.add(bucket);
InvalidAggregationPathException e = expectThrows(InvalidAggregationPathException.class,
() -> resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms"));
assertThat(e.getMessage(), equalTo("Cannot find an aggregation named [foo] in [the_long_terms]"));
}
public void testResolveToUnknown() {
AggregationPath path = AggregationPath.parse("the_avg.unknown");
List<LongTerms.Bucket> buckets = new ArrayList<>();
InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));
LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
buckets.add(bucket);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
() -> resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms"));
assertThat(e.getMessage(), equalTo("path not supported for [the_avg]: [unknown]"));
}
public void testResolveToBucketCount() {
AggregationPath path = AggregationPath.parse("_bucket_count");
List<LongTerms.Bucket> buckets = new ArrayList<>();
InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));
LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
buckets.add(bucket);
Object value = resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
assertThat(value, equalTo(1));
}
public void testResolveToCount() {
AggregationPath path = AggregationPath.parse("_count");
List<LongTerms.Bucket> buckets = new ArrayList<>();
InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));
LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
buckets.add(bucket);
Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
assertThat(value[0], equalTo(1L));
}
public void testResolveToKey() {
AggregationPath path = AggregationPath.parse("_key");
List<LongTerms.Bucket> buckets = new ArrayList<>();
InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));
LongTerms.Bucket bucket = new LongTerms.Bucket(19, 1, internalAggregations, false, 0, DocValueFormat.RAW);
buckets.add(bucket);
Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
assertThat(value[0], equalTo(19L));
}
public void testResolveToSpecificBucket() {
AggregationPath path = AggregationPath.parse("string_terms['foo']>the_avg.value");
List<LongTerms.Bucket> buckets = new ArrayList<>();
InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
InternalAggregations internalStringAggs = new InternalAggregations(Collections.singletonList(agg));
List<StringTerms.Bucket> stringBuckets = Collections.singletonList(new StringTerms.Bucket(
new BytesRef("foo".getBytes(StandardCharsets.UTF_8), 0, "foo".getBytes(StandardCharsets.UTF_8).length), 1,
internalStringAggs, false, 0, DocValueFormat.RAW));
InternalTerms termsAgg = new StringTerms("string_terms", BucketOrder.count(false), 1, 0, Collections.emptyList(),
Collections.emptyMap(), DocValueFormat.RAW, 1, false, 0, stringBuckets, 0);
InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(termsAgg));
LongTerms.Bucket bucket = new LongTerms.Bucket(19, 1, internalAggregations, false, 0, DocValueFormat.RAW);
buckets.add(bucket);
Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
assertThat(value[0], equalTo(2.0));
}
public void testResolveToMissingSpecificBucket() {
AggregationPath path = AggregationPath.parse("string_terms['bar']>the_avg.value");
List<LongTerms.Bucket> buckets = new ArrayList<>();
InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
InternalAggregations internalStringAggs = new InternalAggregations(Collections.singletonList(agg));
List<StringTerms.Bucket> stringBuckets = Collections.singletonList(new StringTerms.Bucket(
new BytesRef("foo".getBytes(StandardCharsets.UTF_8), 0, "foo".getBytes(StandardCharsets.UTF_8).length), 1,
internalStringAggs, false, 0, DocValueFormat.RAW));
InternalTerms termsAgg = new StringTerms("string_terms", BucketOrder.count(false), 1, 0, Collections.emptyList(),
Collections.emptyMap(), DocValueFormat.RAW, 1, false, 0, stringBuckets, 0);
InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(termsAgg));
LongTerms.Bucket bucket = new LongTerms.Bucket(19, 1, internalAggregations, false, 0, DocValueFormat.RAW);
buckets.add(bucket);
InvalidAggregationPathException e = expectThrows(InvalidAggregationPathException.class,
() -> resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms"));
assertThat(e.getMessage(), equalTo("Cannot find a key ['bar'] in [string_terms]"));
}
}


@@ -173,6 +173,17 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
});
}
@Override
protected ScriptService getMockScriptService() {
MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME,
SCRIPTS,
Collections.emptyMap());
Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
}
@SuppressWarnings("unchecked")
public void testNoDocs() throws IOException {
try (Directory directory = newDirectory()) {
@@ -311,7 +322,7 @@
.initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS)
.combineScript(COMBINE_SCRIPT_PARAMS).reduceScript(REDUCE_SCRIPT_PARAMS);
ScriptedMetric scriptedMetric = searchAndReduce(
-newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder, 0, scriptService);
+newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder, 0);
// The result value depends on the script params.
assertEquals(4803, scriptedMetric.aggregation());


@@ -120,9 +120,9 @@ public class AvgBucketAggregatorTests extends AggregatorTestCase {
valueFieldType.setName(VALUE_FIELD);
valueFieldType.setHasDocValues(true);
-avgResult = searchAndReduce(indexSearcher, query, avgBuilder, 10000, null,
+avgResult = searchAndReduce(indexSearcher, query, avgBuilder, 10000,
new MappedFieldType[]{fieldType, valueFieldType});
-histogramResult = searchAndReduce(indexSearcher, query, histo, 10000, null,
+histogramResult = searchAndReduce(indexSearcher, query, histo, 10000,
new MappedFieldType[]{fieldType, valueFieldType});
}


@@ -0,0 +1,122 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.pipeline;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.CheckedConsumer;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.script.MockScriptEngine;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.ScriptModule;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.filter.InternalFilters;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValueType;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.function.Consumer;
import static org.hamcrest.CoreMatchers.equalTo;
public class BucketScriptAggregatorTests extends AggregatorTestCase {
private final String SCRIPT_NAME = "script_name";
@Override
protected ScriptService getMockScriptService() {
MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME,
Collections.singletonMap(SCRIPT_NAME, script -> script.get("the_avg")),
Collections.emptyMap());
Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
}
public void testScript() throws IOException {
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
fieldType.setName("number_field");
fieldType.setHasDocValues(true);
MappedFieldType fieldType1 = new KeywordFieldMapper.KeywordFieldType();
fieldType1.setName("the_field");
fieldType1.setHasDocValues(true);
FiltersAggregationBuilder filters = new FiltersAggregationBuilder("placeholder", new MatchAllQueryBuilder())
.subAggregation(new TermsAggregationBuilder("the_terms", ValueType.STRING).field("the_field")
.subAggregation(new AvgAggregationBuilder("the_avg").field("number_field")))
.subAggregation(new BucketScriptPipelineAggregationBuilder("bucket_script",
Collections.singletonMap("the_avg", "the_terms['test1']>the_avg.value"),
new Script(ScriptType.INLINE, MockScriptEngine.NAME, SCRIPT_NAME, Collections.emptyMap())));
testCase(filters, new MatchAllDocsQuery(), iw -> {
Document doc = new Document();
doc.add(new SortedSetDocValuesField("the_field", new BytesRef("test1")));
doc.add(new SortedNumericDocValuesField("number_field", 19));
iw.addDocument(doc);
doc = new Document();
doc.add(new SortedSetDocValuesField("the_field", new BytesRef("test2")));
doc.add(new SortedNumericDocValuesField("number_field", 55));
iw.addDocument(doc);
}, f -> {
assertThat(((InternalSimpleValue)(f.getBuckets().get(0).getAggregations().get("bucket_script"))).value,
equalTo(19.0));
}, fieldType, fieldType1);
}
private void testCase(FiltersAggregationBuilder aggregationBuilder, Query query,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalFilters> verify, MappedFieldType... fieldType) throws IOException {
try (Directory directory = newDirectory()) {
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
buildIndex.accept(indexWriter);
indexWriter.close();
try (IndexReader indexReader = DirectoryReader.open(directory)) {
IndexSearcher indexSearcher = newSearcher(indexReader, true, true);
InternalFilters filters;
filters = searchAndReduce(indexSearcher, query, aggregationBuilder, fieldType);
verify.accept(filters);
}
}
}
}


@@ -30,12 +30,17 @@ import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.time.DateFormatters;
import org.elasticsearch.index.mapper.DateFieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.script.MockScriptEngine;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.ScriptModule;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.TestAggregatorFactory;
@@ -56,8 +61,6 @@ import java.util.function.Consumer;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.equalTo;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
public class MovFnUnitTests extends AggregatorTestCase {
@@ -79,31 +82,35 @@ public class MovFnUnitTests extends AggregatorTestCase {
private static final List<Integer> datasetValues = Arrays.asList(1,2,3,4,5,6,7,8,9,10);
@Override
protected ScriptService getMockScriptService() {
MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME,
Collections.singletonMap("test", script -> MovingFunctions.max((double[]) script.get("_values"))),
Collections.emptyMap());
Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
}
public void testMatchAllDocs() throws IOException {
-check(0, Arrays.asList(Double.NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0));
+check(0, 3, Arrays.asList(Double.NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0));
}
public void testShift() throws IOException {
-check(1, Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0));
-check(5, Arrays.asList(5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 10.0, 10.0, Double.NaN, Double.NaN));
-check(-5, Arrays.asList(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, 1.0, 2.0, 3.0, 4.0));
+check(1, 3, Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0));
+check(5, 3, Arrays.asList(5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 10.0, 10.0, Double.NaN, Double.NaN));
+check(-5, 3, Arrays.asList(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, 1.0, 2.0, 3.0, 4.0));
}
public void testWideWindow() throws IOException {
-Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap());
-MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 100);
-builder.setShift(50);
-check(builder, script, Arrays.asList(10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0));
+check(50, 100, Arrays.asList(10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0));
}
-private void check(int shift, List<Double> expected) throws IOException {
-Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap());
-MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 3);
+private void check(int shift, int window, List<Double> expected) throws IOException {
+Script script = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "test", Collections.emptyMap());
+MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, window);
builder.setShift(shift);
-check(builder, script, expected);
-}
-private void check(MovFnPipelineAggregationBuilder builder, Script script, List<Double> expected) throws IOException {
Query query = new MatchAllDocsQuery();
DateHistogramAggregationBuilder aggBuilder = new DateHistogramAggregationBuilder("histo");
aggBuilder.calendarInterval(DateHistogramInterval.DAY).field(DATE_FIELD);
@@ -111,19 +118,17 @@
aggBuilder.subAggregation(builder);
executeTestCase(query, aggBuilder, histogram -> {
-List<? extends Histogram.Bucket> buckets = histogram.getBuckets();
-List<Double> actual = buckets.stream()
-.map(bucket -> ((InternalSimpleValue) (bucket.getAggregations().get("mov_fn"))).value())
-.collect(Collectors.toList());
-assertThat(actual, equalTo(expected));
-}, 1000, script);
+List<? extends Histogram.Bucket> buckets = histogram.getBuckets();
+List<Double> actual = buckets.stream()
+.map(bucket -> ((InternalSimpleValue) (bucket.getAggregations().get("mov_fn"))).value())
+.collect(Collectors.toList());
+assertThat(actual, equalTo(expected));
+});
}
private void executeTestCase(Query query,
DateHistogramAggregationBuilder aggBuilder,
-Consumer<Histogram> verify,
-int maxBucket, Script script) throws IOException {
+Consumer<Histogram> verify) throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
@@ -144,20 +149,6 @@
}
}
-ScriptService scriptService = mock(ScriptService.class);
-MovingFunctionScript.Factory factory = mock(MovingFunctionScript.Factory.class);
-when(scriptService.compile(script, MovingFunctionScript.CONTEXT)).thenReturn(factory);
-MovingFunctionScript scriptInstance = new MovingFunctionScript() {
-@Override
-public double execute(Map<String, Object> params, double[] values) {
-assertNotNull(values);
-return MovingFunctions.max(values);
-}
-};
-when(factory.newInstance()).thenReturn(scriptInstance);
try (IndexReader indexReader = DirectoryReader.open(directory)) {
IndexSearcher indexSearcher = newSearcher(indexReader, true, true);
@@ -171,7 +162,7 @@
valueFieldType.setName("value_field");
InternalDateHistogram histogram;
-histogram = searchAndReduce(indexSearcher, query, aggBuilder, maxBucket, scriptService,
+histogram = searchAndReduce(indexSearcher, query, aggBuilder, 1000,
new MappedFieldType[]{fieldType, valueFieldType});
verify.accept(histogram);
}


@@ -27,7 +27,6 @@ import org.elasticsearch.index.similarity.ScriptedSimilarity.Field;
import org.elasticsearch.index.similarity.ScriptedSimilarity.Query;
import org.elasticsearch.index.similarity.ScriptedSimilarity.Term;
import org.elasticsearch.search.aggregations.pipeline.MovingFunctionScript;
-import org.elasticsearch.search.aggregations.pipeline.MovingFunctions;
import org.elasticsearch.search.lookup.LeafSearchLookup;
import org.elasticsearch.search.lookup.SearchLookup;
@@ -271,7 +270,13 @@ public class MockScriptEngine implements ScriptEngine {
SimilarityWeightScript.Factory factory = mockCompiled::createSimilarityWeightScript;
return context.factoryClazz.cast(factory);
} else if (context.instanceClazz.equals(MovingFunctionScript.class)) {
-MovingFunctionScript.Factory factory = mockCompiled::createMovingFunctionScript;
+MovingFunctionScript.Factory factory = () -> new MovingFunctionScript() {
+@Override
+public double execute(Map<String, Object> params1, double[] values) {
+params1.put("_values", values);
+return (double) script.apply(params1);
+}
+};
return context.factoryClazz.cast(factory);
} else if (context.instanceClazz.equals(ScoreScript.class)) {
ScoreScript.Factory factory = new MockScoreScript(script);
@@ -335,10 +340,6 @@
return new MockSimilarityWeightScript(script != null ? script : ctx -> 42d);
}
-public MovingFunctionScript createMovingFunctionScript() {
-return new MockMovingFunctionScript();
-}
public ScriptedMetricAggContexts.InitScript createMetricAggInitScript(Map<String, Object> params, Map<String, Object> state) {
return new MockMetricAggInitScript(params, state, script != null ? script : ctx -> 42d);
}
@@ -544,13 +545,6 @@
return new Script(ScriptType.INLINE, "mock", script, emptyMap());
}
-public class MockMovingFunctionScript extends MovingFunctionScript {
-@Override
-public double execute(Map<String, Object> params, double[] values) {
-return MovingFunctions.unweightedAvg(values);
-}
-}
public class MockScoreScript implements ScoreScript.Factory {
private final Function<Map<String, Object>, Object> script;


@@ -351,7 +351,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
Query query,
AggregationBuilder builder,
MappedFieldType... fieldTypes) throws IOException {
-return searchAndReduce(searcher, query, builder, DEFAULT_MAX_BUCKETS, null, fieldTypes);
+return searchAndReduce(searcher, query, builder, DEFAULT_MAX_BUCKETS, fieldTypes);
}
/**
@@ -363,7 +363,6 @@
Query query,
AggregationBuilder builder,
int maxBucket,
-ScriptService scriptService,
MappedFieldType... fieldTypes) throws IOException {
final IndexReaderContext ctx = searcher.getTopReaderContext();
@@ -408,7 +407,7 @@
List<InternalAggregation> toReduce = aggs.subList(0, r);
MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumer(maxBucket);
InternalAggregation.ReduceContext context =
-new InternalAggregation.ReduceContext(root.context().bigArrays(), null,
+new InternalAggregation.ReduceContext(root.context().bigArrays(), getMockScriptService(),
reduceBucketConsumer, false);
A reduced = (A) aggs.get(0).doReduce(toReduce, context);
doAssertReducedMultiBucketConsumer(reduced, reduceBucketConsumer);
@@ -418,7 +417,7 @@
// now do the final reduce
MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumer(maxBucket);
InternalAggregation.ReduceContext context =
-new InternalAggregation.ReduceContext(root.context().bigArrays(), scriptService, reduceBucketConsumer, true);
+new InternalAggregation.ReduceContext(root.context().bigArrays(), getMockScriptService(), reduceBucketConsumer, true);
@SuppressWarnings("unchecked")
A internalAgg = (A) aggs.get(0).doReduce(aggs, context);