Allow pipeline aggs to select specific buckets from multi-bucket aggs (#44179)

This adjusts the `buckets_path` parser so that pipeline aggs can select specific buckets (via their bucket keys) instead of fetching the entire set of buckets. This is useful for `bucket_script` in particular, which might want specific buckets for calculations. It's possible to work around this with `filter` aggs, but the workaround is hacky and probably less performant.

- Adjusts documentation
- Adds a barebones AggregatorTestCase for bucket_script
- Tweaks AggregatorTestCase to use getMockScriptService() for reductions and pipelines. Previously, pipelines could just pass in a script service for testing, but this didn't work for regular aggs. The new getMockScriptService() method fixes that issue, but needs to be used for pipelines too. This had a knock-on effect of touching MovFn, AvgBucket and ScriptedMetric.
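For a quick sense of the feature from the Java side, here is a minimal sketch — not part of the diff below, but mirroring the builder usage in the new BucketScriptAggregatorTests — of a `bucket_script` whose `buckets_path` entries each select a single bucket by key. The aggregation names are the ones from the documentation example:

    import java.util.HashMap;
    import java.util.Map;
    import org.elasticsearch.script.Script;
    import org.elasticsearch.search.aggregations.pipeline.BucketScriptPipelineAggregationBuilder;

    // Each path selects one bucket of the `sale_type` terms agg by its key,
    // so the script sees two scalar params instead of two arrays of buckets.
    Map<String, String> bucketsPath = new HashMap<>();
    bucketsPath.put("hats", "sale_type['hat']>sales");
    bucketsPath.put("bags", "sale_type['bag']>sales");
    BucketScriptPipelineAggregationBuilder ratio = new BucketScriptPipelineAggregationBuilder(
        "hat_vs_bag_ratio", bucketsPath, new Script("params.hats / params.bags"));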
parent e5079ac288
commit 3df1c76f9b
@@ -35,11 +35,12 @@ parameter, which follows a specific format:

 // https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_Form
 [source,ebnf]
 --------------------------------------------------
-AGG_SEPARATOR       =  '>' ;
-METRIC_SEPARATOR    =  '.' ;
+AGG_SEPARATOR       =  `>` ;
+METRIC_SEPARATOR    =  `.` ;
 AGG_NAME            =  <the name of the aggregation> ;
 METRIC              =  <the name of the metric (in case of multi-value metrics aggregation)> ;
-PATH                =  <AGG_NAME> [ <AGG_SEPARATOR>, <AGG_NAME> ]* [ <METRIC_SEPARATOR>, <METRIC> ] ;
+MULTIBUCKET_KEY     =  `[<KEY_NAME>]`
+PATH                =  <AGG_NAME><MULTIBUCKET_KEY>? (<AGG_SEPARATOR>, <AGG_NAME> )* ( <METRIC_SEPARATOR>, <METRIC> ) ;
 --------------------------------------------------

 For example, the path `"my_bucket>my_stats.avg"` will path to the `avg` value in the `"my_stats"` metric, which is
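In other words, the new production lets an optional key selector follow the first aggregation name. A few paths the extended grammar accepts (the names here are illustrative):

    my_terms>my_stats.avg          <- unchanged: no key selector
    my_terms['foo']>my_stats.avg   <- select the bucket with key "foo", then descend into it
    sale_type['hat']>sales         <- as used in the bucket_script example below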
@@ -111,6 +112,52 @@ POST /_search

 <1> `buckets_path` instructs this max_bucket aggregation that we want the maximum value of the `sales` aggregation in the
 `sales_per_month` date histogram.

+If a Sibling pipeline agg references a multi-bucket aggregation, such as a `terms` agg, it also has the option to
+select specific keys from the multi-bucket. For example, a `bucket_script` could select two specific buckets (via
+their bucket keys) to perform the calculation:
+
+[source,js]
+--------------------------------------------------
+POST /_search
+{
+    "aggs" : {
+        "sales_per_month" : {
+            "date_histogram" : {
+                "field" : "date",
+                "calendar_interval" : "month"
+            },
+            "aggs": {
+                "sale_type": {
+                    "terms": {
+                        "field": "type"
+                    },
+                    "aggs": {
+                        "sales": {
+                            "sum": {
+                                "field": "price"
+                            }
+                        }
+                    }
+                },
+                "hat_vs_bag_ratio": {
+                    "bucket_script": {
+                        "buckets_path": {
+                            "hats": "sale_type['hat']>sales", <1>
+                            "bags": "sale_type['bag']>sales" <1>
+                        },
+                        "script": "params.hats / params.bags"
+                    }
+                }
+            }
+        }
+    }
+}
+--------------------------------------------------
+// CONSOLE
+// TEST[setup:sales]
+<1> `buckets_path` selects the hats and bags buckets (via `['hat']`/`['bag']`) to use in the script specifically,
+instead of fetching all the buckets from the `sale_type` aggregation.

 [float]
 === Special Paths
@@ -102,3 +102,41 @@ setup:
   - is_false: aggregations.double_terms.buckets.1.key_as_string
   - match: { aggregations.double_terms.buckets.1.doc_count: 1 }

+---
+"Bucket script with keys":
+
+  - do:
+      search:
+        rest_total_hits_as_int: true
+        body:
+          size: 0
+          aggs:
+            placeholder:
+              filters:
+                filters:
+                  - match_all: {}
+              aggs:
+                str_terms:
+                  terms:
+                    field: "str"
+                  aggs:
+                    the_avg:
+                      avg:
+                        field: "number"
+                the_bucket_script:
+                  bucket_script:
+                    buckets_path:
+                      foo: "str_terms['bcd']>the_avg.value"
+                    script: "params.foo"
+
+  - match: { hits.total: 3 }
+
+  - length: { aggregations.placeholder.buckets.0.str_terms.buckets: 2 }
+  - match: { aggregations.placeholder.buckets.0.str_terms.buckets.0.key: "abc" }
+  - is_false: aggregations.placeholder.buckets.0.str_terms.buckets.0.key_as_string
+  - match: { aggregations.placeholder.buckets.0.str_terms.buckets.0.doc_count: 2 }
+  - match: { aggregations.placeholder.buckets.0.str_terms.buckets.1.key: "bcd" }
+  - is_false: aggregations.placeholder.buckets.0.str_terms.buckets.1.key_as_string
+  - match: { aggregations.placeholder.buckets.0.str_terms.buckets.1.doc_count: 1 }
+  - match: { aggregations.placeholder.buckets.0.the_bucket_script.value: 2.0 }
@@ -73,16 +73,33 @@ public abstract class InternalMultiBucketAggregation<A extends InternalMultiBuck
     public Object getProperty(List<String> path) {
         if (path.isEmpty()) {
             return this;
-        } else if (path.get(0).equals("_bucket_count")) {
-            return getBuckets().size();
-        } else {
-            List<? extends InternalBucket> buckets = getBuckets();
-            Object[] propertyArray = new Object[buckets.size()];
-            for (int i = 0; i < buckets.size(); i++) {
-                propertyArray[i] = buckets.get(i).getProperty(getName(), path);
-            }
-            return propertyArray;
         }
+        return resolvePropertyFromPath(path, getBuckets(), getName());
     }

+    static Object resolvePropertyFromPath(List<String> path, List<? extends InternalBucket> buckets, String name) {
+        String aggName = path.get(0);
+        if (aggName.equals("_bucket_count")) {
+            return buckets.size();
+        }
+
+        // This is a bucket key, look through our buckets and see if we can find a match
+        if (aggName.startsWith("'") && aggName.endsWith("'")) {
+            for (InternalBucket bucket : buckets) {
+                if (bucket.getKeyAsString().equals(aggName.substring(1, aggName.length() - 1))) {
+                    return bucket.getProperty(name, path.subList(1, path.size()));
+                }
+            }
+            // No key match, time to give up
+            throw new InvalidAggregationPathException("Cannot find an key [" + aggName + "] in [" + name + "]");
+        }
+
+        Object[] propertyArray = new Object[buckets.size()];
+        for (int i = 0; i < buckets.size(); i++) {
+            propertyArray[i] = buckets.get(i).getProperty(name, path);
+        }
+        return propertyArray;
+
+    }

     /**
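To see how a key selector flows through this method, here is a sketch of the resolution for the path used in the new tests below — assuming, as those tests exercise, that AggregationPath.parse rewrites a `['key']` selector into a quoted path element:

    // "str_terms['bcd']>the_avg.value"  parses to  ["str_terms", "'bcd'", "the_avg", "value"]
    //
    // Resolution then proceeds element by element:
    //   "str_terms"          -> not "_bucket_count" and not quoted: fan out over every bucket
    //                           via bucket.getProperty(...), producing an Object[] as before
    //   "'bcd'"              -> quoted: scan the str_terms buckets for one whose
    //                           getKeyAsString().equals("bcd"), and recurse into that single
    //                           bucket (or throw InvalidAggregationPathException on no match)
    //   "the_avg" / "value"  -> resolved by the selected bucket as before, yielding one double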
@@ -0,0 +1,183 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.elasticsearch.search.aggregations;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.bucket.terms.InternalTerms;
import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.aggregations.metrics.InternalAvg;
import org.elasticsearch.search.aggregations.support.AggregationPath;
import org.elasticsearch.test.ESTestCase;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static org.elasticsearch.search.aggregations.InternalMultiBucketAggregation.resolvePropertyFromPath;
import static org.hamcrest.Matchers.equalTo;

public class InternalMultiBucketAggregationTests extends ESTestCase {

    public void testResolveToAgg() {
        AggregationPath path = AggregationPath.parse("the_avg");
        List<LongTerms.Bucket> buckets = new ArrayList<>();
        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));

        LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
        buckets.add(bucket);

        Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
        assertThat(value[0], equalTo(agg));
    }

    public void testResolveToAggValue() {
        AggregationPath path = AggregationPath.parse("the_avg.value");
        List<LongTerms.Bucket> buckets = new ArrayList<>();
        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));

        LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
        buckets.add(bucket);

        Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
        assertThat(value[0], equalTo(2.0));
    }

    public void testResolveToNothing() {
        AggregationPath path = AggregationPath.parse("foo.value");
        List<LongTerms.Bucket> buckets = new ArrayList<>();
        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));

        LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
        buckets.add(bucket);

        InvalidAggregationPathException e = expectThrows(InvalidAggregationPathException.class,
            () -> resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms"));
        assertThat(e.getMessage(), equalTo("Cannot find an aggregation named [foo] in [the_long_terms]"));
    }

    public void testResolveToUnknown() {
        AggregationPath path = AggregationPath.parse("the_avg.unknown");
        List<LongTerms.Bucket> buckets = new ArrayList<>();
        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));

        LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
        buckets.add(bucket);

        IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
            () -> resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms"));
        assertThat(e.getMessage(), equalTo("path not supported for [the_avg]: [unknown]"));
    }

    public void testResolveToBucketCount() {
        AggregationPath path = AggregationPath.parse("_bucket_count");
        List<LongTerms.Bucket> buckets = new ArrayList<>();
        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));

        LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
        buckets.add(bucket);

        Object value = resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
        assertThat(value, equalTo(1));
    }

    public void testResolveToCount() {
        AggregationPath path = AggregationPath.parse("_count");
        List<LongTerms.Bucket> buckets = new ArrayList<>();
        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));

        LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
        buckets.add(bucket);

        Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
        assertThat(value[0], equalTo(1L));
    }

    public void testResolveToKey() {
        AggregationPath path = AggregationPath.parse("_key");
        List<LongTerms.Bucket> buckets = new ArrayList<>();
        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));

        LongTerms.Bucket bucket = new LongTerms.Bucket(19, 1, internalAggregations, false, 0, DocValueFormat.RAW);
        buckets.add(bucket);

        Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
        assertThat(value[0], equalTo(19L));
    }

    public void testResolveToSpecificBucket() {
        AggregationPath path = AggregationPath.parse("string_terms['foo']>the_avg.value");

        List<LongTerms.Bucket> buckets = new ArrayList<>();
        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
        InternalAggregations internalStringAggs = new InternalAggregations(Collections.singletonList(agg));
        List<StringTerms.Bucket> stringBuckets = Collections.singletonList(new StringTerms.Bucket(
            new BytesRef("foo".getBytes(StandardCharsets.UTF_8), 0, "foo".getBytes(StandardCharsets.UTF_8).length), 1,
            internalStringAggs, false, 0, DocValueFormat.RAW));

        InternalTerms termsAgg = new StringTerms("string_terms", BucketOrder.count(false), 1, 0, Collections.emptyList(),
            Collections.emptyMap(), DocValueFormat.RAW, 1, false, 0, stringBuckets, 0);
        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(termsAgg));
        LongTerms.Bucket bucket = new LongTerms.Bucket(19, 1, internalAggregations, false, 0, DocValueFormat.RAW);
        buckets.add(bucket);

        Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
        assertThat(value[0], equalTo(2.0));
    }

    public void testResolveToMissingSpecificBucket() {
        AggregationPath path = AggregationPath.parse("string_terms['bar']>the_avg.value");

        List<LongTerms.Bucket> buckets = new ArrayList<>();
        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
        InternalAggregations internalStringAggs = new InternalAggregations(Collections.singletonList(agg));
        List<StringTerms.Bucket> stringBuckets = Collections.singletonList(new StringTerms.Bucket(
            new BytesRef("foo".getBytes(StandardCharsets.UTF_8), 0, "foo".getBytes(StandardCharsets.UTF_8).length), 1,
            internalStringAggs, false, 0, DocValueFormat.RAW));

        InternalTerms termsAgg = new StringTerms("string_terms", BucketOrder.count(false), 1, 0, Collections.emptyList(),
            Collections.emptyMap(), DocValueFormat.RAW, 1, false, 0, stringBuckets, 0);
        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(termsAgg));
        LongTerms.Bucket bucket = new LongTerms.Bucket(19, 1, internalAggregations, false, 0, DocValueFormat.RAW);
        buckets.add(bucket);

        InvalidAggregationPathException e = expectThrows(InvalidAggregationPathException.class,
            () -> resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms"));
        assertThat(e.getMessage(), equalTo("Cannot find an key ['bar'] in [string_terms]"));
    }
}
@@ -173,6 +173,17 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
         });
     }

+    @Override
+    protected ScriptService getMockScriptService() {
+        MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME,
+            SCRIPTS,
+            Collections.emptyMap());
+        Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
+
+        return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
+    }
+
+
     @SuppressWarnings("unchecked")
     public void testNoDocs() throws IOException {
         try (Directory directory = newDirectory()) {

@@ -311,7 +322,7 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
             .initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS)
             .combineScript(COMBINE_SCRIPT_PARAMS).reduceScript(REDUCE_SCRIPT_PARAMS);
         ScriptedMetric scriptedMetric = searchAndReduce(
-            newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder, 0, scriptService);
+            newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder, 0);

         // The result value depends on the script params.
         assertEquals(4803, scriptedMetric.aggregation());
@@ -120,9 +120,9 @@ public class AvgBucketAggregatorTests extends AggregatorTestCase {
         valueFieldType.setName(VALUE_FIELD);
         valueFieldType.setHasDocValues(true);

-        avgResult = searchAndReduce(indexSearcher, query, avgBuilder, 10000, null,
+        avgResult = searchAndReduce(indexSearcher, query, avgBuilder, 10000,
             new MappedFieldType[]{fieldType, valueFieldType});
-        histogramResult = searchAndReduce(indexSearcher, query, histo, 10000, null,
+        histogramResult = searchAndReduce(indexSearcher, query, histo, 10000,
             new MappedFieldType[]{fieldType, valueFieldType});
     }
@@ -0,0 +1,122 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.elasticsearch.search.aggregations.pipeline;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.CheckedConsumer;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.script.MockScriptEngine;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.ScriptModule;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.filter.InternalFilters;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValueType;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.function.Consumer;

import static org.hamcrest.CoreMatchers.equalTo;

public class BucketScriptAggregatorTests extends AggregatorTestCase {
    private final String SCRIPT_NAME = "script_name";

    @Override
    protected ScriptService getMockScriptService() {
        MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME,
            Collections.singletonMap(SCRIPT_NAME, script -> script.get("the_avg")),
            Collections.emptyMap());
        Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);

        return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
    }

    public void testScript() throws IOException {
        MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
        fieldType.setName("number_field");
        fieldType.setHasDocValues(true);
        MappedFieldType fieldType1 = new KeywordFieldMapper.KeywordFieldType();
        fieldType1.setName("the_field");
        fieldType1.setHasDocValues(true);

        FiltersAggregationBuilder filters = new FiltersAggregationBuilder("placeholder", new MatchAllQueryBuilder())
            .subAggregation(new TermsAggregationBuilder("the_terms", ValueType.STRING).field("the_field")
                .subAggregation(new AvgAggregationBuilder("the_avg").field("number_field")))
            .subAggregation(new BucketScriptPipelineAggregationBuilder("bucket_script",
                Collections.singletonMap("the_avg", "the_terms['test1']>the_avg.value"),
                new Script(ScriptType.INLINE, MockScriptEngine.NAME, SCRIPT_NAME, Collections.emptyMap())));


        testCase(filters, new MatchAllDocsQuery(), iw -> {
            Document doc = new Document();
            doc.add(new SortedSetDocValuesField("the_field", new BytesRef("test1")));
            doc.add(new SortedNumericDocValuesField("number_field", 19));
            iw.addDocument(doc);

            doc = new Document();
            doc.add(new SortedSetDocValuesField("the_field", new BytesRef("test2")));
            doc.add(new SortedNumericDocValuesField("number_field", 55));
            iw.addDocument(doc);
        }, f -> {
            assertThat(((InternalSimpleValue)(f.getBuckets().get(0).getAggregations().get("bucket_script"))).value,
                equalTo(19.0));
        }, fieldType, fieldType1);
    }

    private void testCase(FiltersAggregationBuilder aggregationBuilder, Query query,
                          CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
                          Consumer<InternalFilters> verify, MappedFieldType... fieldType) throws IOException {

        try (Directory directory = newDirectory()) {
            RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
            buildIndex.accept(indexWriter);
            indexWriter.close();

            try (IndexReader indexReader = DirectoryReader.open(directory)) {
                IndexSearcher indexSearcher = newSearcher(indexReader, true, true);

                InternalFilters filters;
                filters = searchAndReduce(indexSearcher, query, aggregationBuilder, fieldType);
                verify.accept(filters);
            }
        }
    }
}
@@ -30,12 +30,17 @@ import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.store.Directory;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.time.DateFormatters;
 import org.elasticsearch.index.mapper.DateFieldMapper;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.NumberFieldMapper;
+import org.elasticsearch.script.MockScriptEngine;
 import org.elasticsearch.script.Script;
+import org.elasticsearch.script.ScriptEngine;
+import org.elasticsearch.script.ScriptModule;
 import org.elasticsearch.script.ScriptService;
+import org.elasticsearch.script.ScriptType;
 import org.elasticsearch.search.aggregations.AggregatorTestCase;
 import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.TestAggregatorFactory;

@@ -56,8 +61,6 @@ import java.util.function.Consumer;
 import java.util.stream.Collectors;

 import static org.hamcrest.Matchers.equalTo;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;

 public class MovFnUnitTests extends AggregatorTestCase {

@@ -79,31 +82,35 @@ public class MovFnUnitTests extends AggregatorTestCase {

     private static final List<Integer> datasetValues = Arrays.asList(1,2,3,4,5,6,7,8,9,10);

+    @Override
+    protected ScriptService getMockScriptService() {
+        MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME,
+            Collections.singletonMap("test", script -> MovingFunctions.max((double[]) script.get("_values"))),
+            Collections.emptyMap());
+        Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
+
+        return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
+    }
+
     public void testMatchAllDocs() throws IOException {
-        check(0, Arrays.asList(Double.NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0));
+        check(0, 3, Arrays.asList(Double.NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0));
     }

     public void testShift() throws IOException {
-        check(1, Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0));
-        check(5, Arrays.asList(5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 10.0, 10.0, Double.NaN, Double.NaN));
-        check(-5, Arrays.asList(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, 1.0, 2.0, 3.0, 4.0));
+        check(1, 3, Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0));
+        check(5, 3, Arrays.asList(5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 10.0, 10.0, Double.NaN, Double.NaN));
+        check(-5, 3, Arrays.asList(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, 1.0, 2.0, 3.0, 4.0));
     }

     public void testWideWindow() throws IOException {
-        Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap());
-        MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 100);
-        builder.setShift(50);
-        check(builder, script, Arrays.asList(10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0));
+        check(50, 100, Arrays.asList(10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0));
     }

-    private void check(int shift, List<Double> expected) throws IOException {
-        Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap());
-        MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 3);
+    private void check(int shift, int window, List<Double> expected) throws IOException {
+        Script script = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "test", Collections.emptyMap());
+        MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, window);
         builder.setShift(shift);
         check(builder, script, expected);
     }

     private void check(MovFnPipelineAggregationBuilder builder, Script script, List<Double> expected) throws IOException {
         Query query = new MatchAllDocsQuery();
         DateHistogramAggregationBuilder aggBuilder = new DateHistogramAggregationBuilder("histo");
         aggBuilder.calendarInterval(DateHistogramInterval.DAY).field(DATE_FIELD);

@@ -111,19 +118,17 @@ public class MovFnUnitTests extends AggregatorTestCase {
         aggBuilder.subAggregation(builder);

         executeTestCase(query, aggBuilder, histogram -> {
-            List<? extends Histogram.Bucket> buckets = histogram.getBuckets();
-            List<Double> actual = buckets.stream()
-                .map(bucket -> ((InternalSimpleValue) (bucket.getAggregations().get("mov_fn"))).value())
-                .collect(Collectors.toList());
-            assertThat(actual, equalTo(expected));
-        }, 1000, script);
+            List<? extends Histogram.Bucket> buckets = histogram.getBuckets();
+            List<Double> actual = buckets.stream()
+                .map(bucket -> ((InternalSimpleValue) (bucket.getAggregations().get("mov_fn"))).value())
+                .collect(Collectors.toList());
+            assertThat(actual, equalTo(expected));
+        });
     }

     private void executeTestCase(Query query,
                                  DateHistogramAggregationBuilder aggBuilder,
-                                 Consumer<Histogram> verify,
-                                 int maxBucket, Script script) throws IOException {
+                                 Consumer<Histogram> verify) throws IOException {

         try (Directory directory = newDirectory()) {
             try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {

@@ -144,20 +149,6 @@ public class MovFnUnitTests extends AggregatorTestCase {
                 }
             }

-            ScriptService scriptService = mock(ScriptService.class);
-            MovingFunctionScript.Factory factory = mock(MovingFunctionScript.Factory.class);
-            when(scriptService.compile(script, MovingFunctionScript.CONTEXT)).thenReturn(factory);
-
-            MovingFunctionScript scriptInstance = new MovingFunctionScript() {
-                @Override
-                public double execute(Map<String, Object> params, double[] values) {
-                    assertNotNull(values);
-                    return MovingFunctions.max(values);
-                }
-            };
-
-            when(factory.newInstance()).thenReturn(scriptInstance);
-
             try (IndexReader indexReader = DirectoryReader.open(directory)) {
                 IndexSearcher indexSearcher = newSearcher(indexReader, true, true);

@@ -171,7 +162,7 @@ public class MovFnUnitTests extends AggregatorTestCase {
                 valueFieldType.setName("value_field");

                 InternalDateHistogram histogram;
-                histogram = searchAndReduce(indexSearcher, query, aggBuilder, maxBucket, scriptService,
+                histogram = searchAndReduce(indexSearcher, query, aggBuilder, 1000,
                     new MappedFieldType[]{fieldType, valueFieldType});
                 verify.accept(histogram);
             }
@@ -27,7 +27,6 @@ import org.elasticsearch.index.similarity.ScriptedSimilarity.Field;
 import org.elasticsearch.index.similarity.ScriptedSimilarity.Query;
 import org.elasticsearch.index.similarity.ScriptedSimilarity.Term;
 import org.elasticsearch.search.aggregations.pipeline.MovingFunctionScript;
-import org.elasticsearch.search.aggregations.pipeline.MovingFunctions;
 import org.elasticsearch.search.lookup.LeafSearchLookup;
 import org.elasticsearch.search.lookup.SearchLookup;

@@ -271,7 +270,13 @@ public class MockScriptEngine implements ScriptEngine {
             SimilarityWeightScript.Factory factory = mockCompiled::createSimilarityWeightScript;
             return context.factoryClazz.cast(factory);
         } else if (context.instanceClazz.equals(MovingFunctionScript.class)) {
-            MovingFunctionScript.Factory factory = mockCompiled::createMovingFunctionScript;
+            MovingFunctionScript.Factory factory = () -> new MovingFunctionScript() {
+                @Override
+                public double execute(Map<String, Object> params1, double[] values) {
+                    params1.put("_values", values);
+                    return (double) script.apply(params1);
+                }
+            };
             return context.factoryClazz.cast(factory);
         } else if (context.instanceClazz.equals(ScoreScript.class)) {
             ScoreScript.Factory factory = new MockScoreScript(script);
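The effect of this new factory is that a moving-function test no longer needs Mockito: it registers a plain function under a name and reads the window of values out of the params map via the `_values` key the factory injects. The MovFnUnitTests change above uses exactly this hook:

    // Usage example, taken from the MovFnUnitTests diff in this commit:
    MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME,
        Collections.singletonMap("test", script -> MovingFunctions.max((double[]) script.get("_values"))),
        Collections.emptyMap());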
@@ -335,10 +340,6 @@ public class MockScriptEngine implements ScriptEngine {
             return new MockSimilarityWeightScript(script != null ? script : ctx -> 42d);
         }

-        public MovingFunctionScript createMovingFunctionScript() {
-            return new MockMovingFunctionScript();
-        }
-
         public ScriptedMetricAggContexts.InitScript createMetricAggInitScript(Map<String, Object> params, Map<String, Object> state) {
             return new MockMetricAggInitScript(params, state, script != null ? script : ctx -> 42d);
         }

@@ -544,13 +545,6 @@ public class MockScriptEngine implements ScriptEngine {
         return new Script(ScriptType.INLINE, "mock", script, emptyMap());
     }

-    public class MockMovingFunctionScript extends MovingFunctionScript {
-        @Override
-        public double execute(Map<String, Object> params, double[] values) {
-            return MovingFunctions.unweightedAvg(values);
-        }
-    }
-
     public class MockScoreScript implements ScoreScript.Factory {

         private final Function<Map<String, Object>, Object> script;
@@ -351,7 +351,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
                                                                         Query query,
                                                                         AggregationBuilder builder,
                                                                         MappedFieldType... fieldTypes) throws IOException {
-        return searchAndReduce(searcher, query, builder, DEFAULT_MAX_BUCKETS, null, fieldTypes);
+        return searchAndReduce(searcher, query, builder, DEFAULT_MAX_BUCKETS, fieldTypes);
     }

     /**

@@ -363,7 +363,6 @@ public abstract class AggregatorTestCase extends ESTestCase {
                                                                         Query query,
                                                                         AggregationBuilder builder,
                                                                         int maxBucket,
-                                                                        ScriptService scriptService,
                                                                         MappedFieldType... fieldTypes) throws IOException {
         final IndexReaderContext ctx = searcher.getTopReaderContext();

@@ -408,7 +407,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
                 List<InternalAggregation> toReduce = aggs.subList(0, r);
                 MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumer(maxBucket);
                 InternalAggregation.ReduceContext context =
-                    new InternalAggregation.ReduceContext(root.context().bigArrays(), null,
+                    new InternalAggregation.ReduceContext(root.context().bigArrays(), getMockScriptService(),
                         reduceBucketConsumer, false);
                 A reduced = (A) aggs.get(0).doReduce(toReduce, context);
                 doAssertReducedMultiBucketConsumer(reduced, reduceBucketConsumer);

@@ -418,7 +417,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
             // now do the final reduce
             MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumer(maxBucket);
             InternalAggregation.ReduceContext context =
-                new InternalAggregation.ReduceContext(root.context().bigArrays(), scriptService, reduceBucketConsumer, true);
+                new InternalAggregation.ReduceContext(root.context().bigArrays(), getMockScriptService(), reduceBucketConsumer, true);

             @SuppressWarnings("unchecked")
             A internalAgg = (A) aggs.get(0).doReduce(aggs, context);
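The upshot for test authors: rather than threading a ScriptService argument through searchAndReduce, a test overrides getMockScriptService() and both reduce contexts pick it up automatically. A minimal sketch of the pattern — the SCRIPTS map here stands in for whatever named mock scripts a given test registers, as ScriptedMetricAggregatorTests does above:

    @Override
    protected ScriptService getMockScriptService() {
        // SCRIPTS is the test's own map of script name -> function (an assumption here;
        // each test supplies its own, e.g. ScriptedMetricAggregatorTests.SCRIPTS).
        MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME, SCRIPTS, Collections.emptyMap());
        Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
        return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
    }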