Migrate scripted metric aggregation scripts to ScriptContext design (#30111)

* Migrate scripted metric aggregation scripts to ScriptContext design #29328

* Rename new script context container class and add clarifying comments to remaining references to params._agg(s)

* Misc cleanup: make mock metric agg script inner classes static

* Move _score to an accessor rather than an arg for scripted metric agg scripts

This causes the score to be evaluated only when it's used.

* Documentation changes for params._agg -> agg

* Migration doc addition for scripted metric aggs _agg object change

* Rename "agg" Scripted Metric Aggregation script context variable to "state"

* Rename a private base class from ...Agg to ...State that I missed in my last commit

* Clean up imports after merge
Authored by Jonathan Little on 2018-06-25 04:01:33 -07:00; committed by Colin Goodheart-Smithe
parent 8b698f0bce
commit 8e4768890a
13 changed files with 619 additions and 111 deletions

View File

@ -13,8 +13,8 @@ Here is an example on how to create the aggregation request:
--------------------------------------------------
ScriptedMetricAggregationBuilder aggregation = AggregationBuilders
.scriptedMetric("agg")
-.initScript(new Script("params._agg.heights = []"))
-.mapScript(new Script("params._agg.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)"));
+.initScript(new Script("state.heights = []"))
+.mapScript(new Script("state.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)"));
--------------------------------------------------
You can also specify a `combine` script which will be executed on each shard:
@ -23,9 +23,9 @@ You can also specify a `combine` script which will be executed on each shard:
--------------------------------------------------
ScriptedMetricAggregationBuilder aggregation = AggregationBuilders
.scriptedMetric("agg")
-.initScript(new Script("params._agg.heights = []"))
-.mapScript(new Script("params._agg.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)"))
-.combineScript(new Script("double heights_sum = 0.0; for (t in params._agg.heights) { heights_sum += t } return heights_sum"));
+.initScript(new Script("state.heights = []"))
+.mapScript(new Script("state.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)"))
+.combineScript(new Script("double heights_sum = 0.0; for (t in state.heights) { heights_sum += t } return heights_sum"));
--------------------------------------------------
You can also specify a `reduce` script which will be executed on the node which gets the request:
@ -34,10 +34,10 @@ You can also specify a `reduce` script which will be executed on the node which
--------------------------------------------------
ScriptedMetricAggregationBuilder aggregation = AggregationBuilders
.scriptedMetric("agg")
-.initScript(new Script("params._agg.heights = []"))
-.mapScript(new Script("params._agg.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)"))
-.combineScript(new Script("double heights_sum = 0.0; for (t in params._agg.heights) { heights_sum += t } return heights_sum"))
-.reduceScript(new Script("double heights_sum = 0.0; for (a in params._aggs) { heights_sum += a } return heights_sum"));
+.initScript(new Script("state.heights = []"))
+.mapScript(new Script("state.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)"))
+.combineScript(new Script("double heights_sum = 0.0; for (t in state.heights) { heights_sum += t } return heights_sum"))
+.reduceScript(new Script("double heights_sum = 0.0; for (a in states) { heights_sum += a } return heights_sum"));
--------------------------------------------------
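The combine and reduce scripts above are plain summations, so the data flow can be sketched with standard Java collections. This is a simulation of the script stages, not the Elasticsearch API, and the shard values are made up for illustration:

```java
import java.util.ArrayList;
import java.util.List;

public class HeightsFlowSketch {
    public static void main(String[] args) {
        // Hypothetical per-shard "state.heights" lists produced by the map script:
        // male heights stay positive, other heights are negated.
        List<List<Double>> shardHeights = List.of(
                List.of(1.75, -1.5),  // shard A
                List.of(2.0));        // shard B

        // combine script: sum the heights collected on each shard
        List<Double> states = new ArrayList<>();
        for (List<Double> heights : shardHeights) {
            double sum = 0.0;
            for (double h : heights) {
                sum += h;
            }
            states.add(sum);
        }

        // reduce script: sum the per-shard results exposed via `states`
        double total = 0.0;
        for (double s : states) {
            total += s;
        }
        System.out.println(total); // 0.25 + 2.0 = 2.25
    }
}
```

The per-shard sums play the role of the `states` list that the reduce script receives on the coordinating node.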

View File

@ -15,10 +15,10 @@ POST ledger/_search?size=0
"aggs": {
"profit": {
"scripted_metric": {
-"init_script" : "params._agg.transactions = []",
-"map_script" : "params._agg.transactions.add(doc.type.value == 'sale' ? doc.amount.value : -1 * doc.amount.value)", <1>
-"combine_script" : "double profit = 0; for (t in params._agg.transactions) { profit += t } return profit",
-"reduce_script" : "double profit = 0; for (a in params._aggs) { profit += a } return profit"
+"init_script" : "state.transactions = []",
+"map_script" : "state.transactions.add(doc.type.value == 'sale' ? doc.amount.value : -1 * doc.amount.value)", <1>
+"combine_script" : "double profit = 0; for (t in state.transactions) { profit += t } return profit",
+"reduce_script" : "double profit = 0; for (a in states) { profit += a } return profit"
}
}
}
@ -67,8 +67,7 @@ POST ledger/_search?size=0
"id": "my_combine_script"
},
"params": {
-"field": "amount", <1>
-"_agg": {} <2>
+"field": "amount" <1>
},
"reduce_script" : {
"id": "my_reduce_script"
@ -82,8 +81,7 @@ POST ledger/_search?size=0
// TEST[setup:ledger,stored_scripted_metric_script]
<1> script parameters for `init`, `map` and `combine` scripts must be specified
-in a global `params` object so that it can be share between the scripts.
-<2> if you specify script parameters then you must specify `"_agg": {}`.
+in a global `params` object so that it can be shared between the scripts.
////
Verify this response as well but in a hidden block.
@ -108,7 +106,7 @@ For more details on specifying scripts see <<modules-scripting, script documenta
==== Allowed return types
-Whilst any valid script object can be used within a single script, the scripts must return or store in the `_agg` object only the following types:
+Whilst any valid script object can be used within a single script, the scripts must return or store in the `state` object only the following types:
* primitive types
* String
@ -121,10 +119,10 @@ The scripted metric aggregation uses scripts at 4 stages of its execution:
init_script:: Executed prior to any collection of documents. Allows the aggregation to set up any initial state.
+
-In the above example, the `init_script` creates an array `transactions` in the `_agg` object.
+In the above example, the `init_script` creates an array `transactions` in the `state` object.
map_script:: Executed once per document collected. This is the only required script. If no combine_script is specified, the resulting state
-needs to be stored in an object named `_agg`.
+needs to be stored in the `state` object.
+
In the above example, the `map_script` checks the value of the type field. If the value is 'sale' the value of the amount field
is added to the transactions array. If the value of the type field is not 'sale' the negated value of the amount field is added
@ -137,8 +135,8 @@ In the above example, the `combine_script` iterates through all the stored trans
and finally returns `profit`.
reduce_script:: Executed once on the coordinating node after all shards have returned their results. The script is provided with access to a
-variable `_aggs` which is an array of the result of the combine_script on each shard. If a reduce_script is not provided
-the reduce phase will return the `_aggs` variable.
+variable `states` which is an array of the result of the combine_script on each shard. If a reduce_script is not provided
+the reduce phase will return the `states` variable.
+
In the above example, the `reduce_script` iterates through the `profit` returned by each shard summing the values before returning the
final combined profit which will be returned in the response of the aggregation.
@ -166,13 +164,11 @@ at each stage of the example above.
===== Before init_script
-No params object was specified so the default params object is used:
+`state` is initialized as a new empty object.
[source,js]
--------------------------------------------------
-"params" : {
-"_agg" : {}
-}
+"state" : {}
--------------------------------------------------
// NOTCONSOLE
@ -184,10 +180,8 @@ Shard A::
+
[source,js]
--------------------------------------------------
-"params" : {
-"_agg" : {
-"transactions" : []
-}
+"state" : {
+"transactions" : []
+}
--------------------------------------------------
// NOTCONSOLE
@ -196,10 +190,8 @@ Shard B::
+
[source,js]
--------------------------------------------------
-"params" : {
-"_agg" : {
-"transactions" : []
-}
+"state" : {
+"transactions" : []
+}
--------------------------------------------------
// NOTCONSOLE
@ -212,10 +204,8 @@ Shard A::
+
[source,js]
--------------------------------------------------
-"params" : {
-"_agg" : {
-"transactions" : [ 80, -30 ]
-}
+"state" : {
+"transactions" : [ 80, -30 ]
+}
--------------------------------------------------
// NOTCONSOLE
@ -224,10 +214,8 @@ Shard B::
+
[source,js]
--------------------------------------------------
-"params" : {
-"_agg" : {
-"transactions" : [ -10, 130 ]
-}
+"state" : {
+"transactions" : [ -10, 130 ]
+}
--------------------------------------------------
// NOTCONSOLE
@ -242,11 +230,11 @@ Shard B:: 120
===== After reduce_script
-The reduce_script receives an `_aggs` array containing the result of the combine script for each shard:
+The reduce_script receives a `states` array containing the result of the combine script for each shard:
[source,js]
--------------------------------------------------
-"_aggs" : [
+"states" : [
50,
120
]
@ -279,14 +267,12 @@ params:: Optional. An object whose contents will be passed as variable
+
[source,js]
--------------------------------------------------
-"params" : {
-"_agg" : {}
-}
+"params" : {}
--------------------------------------------------
// NOTCONSOLE
==== Empty Buckets
If a parent bucket of the scripted metric aggregation does not collect any documents an empty aggregation response will be returned from the
-shard with a `null` value. In this case the `reduce_script`'s `_aggs` variable will contain `null` as a response from that shard.
+shard with a `null` value. In this case the `reduce_script`'s `states` variable will contain `null` as a response from that shard.
A `reduce_script` should therefore expect and deal with `null` responses from shards.
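The four stages described above can be traced end to end with a plain-Java simulation of the ledger example. The doc values below match the walkthrough's shard contents; this mimics the script lifecycle only, not the real aggregator:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.HashMap;
import java.util.Map;

public class ProfitSimulation {
    public static void main(String[] args) {
        // Two shards, each holding (type, amount) docs as in the ledger example.
        List<List<Object[]>> shards = List.of(
                List.of(new Object[]{"sale", 80.0}, new Object[]{"cost", 30.0}),
                List.of(new Object[]{"cost", 10.0}, new Object[]{"sale", 130.0}));

        List<Object> states = new ArrayList<>(); // what the reduce script will see
        for (List<Object[]> docs : shards) {
            // init_script: state.transactions = []
            Map<String, Object> state = new HashMap<>();
            List<Double> transactions = new ArrayList<>();
            state.put("transactions", transactions);

            // map_script: add the amount, negated for anything that is not a sale
            for (Object[] doc : docs) {
                double amount = (double) doc[1];
                transactions.add("sale".equals(doc[0]) ? amount : -amount);
            }

            // combine_script: sum this shard's transactions into one profit value
            double profit = 0;
            for (double t : transactions) {
                profit += t;
            }
            states.add(profit);
        }

        // reduce_script: sum the per-shard profits from the `states` array
        double total = 0;
        for (Object s : states) {
            total += (double) s;
        }
        System.out.println(total); // 170.0, matching 50 (shard A) + 120 (shard B)
    }
}
```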

View File

@ -14,4 +14,12 @@ Requests that try to return more than the limit will fail with an exception.
==== `missing` option of the `composite` aggregation has been removed
The `missing` option of the `composite` aggregation, deprecated in 6.x,
-has been removed. `missing_bucket` should be used instead.
+has been removed. `missing_bucket` should be used instead.
+==== Replaced `params._agg` with `state` context variable in scripted metric aggregations
+The object used to share aggregation state between the scripts in a Scripted Metric
+Aggregation is now a variable called `state` available in the script context, rather than
+being provided via the `params` object as `params._agg`.
+The old `params._agg` variable is still available as well.

View File

@ -0,0 +1,126 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.painless;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Scorer;
import org.elasticsearch.painless.spi.Whitelist;
import org.elasticsearch.script.ScriptedMetricAggContexts;
import org.elasticsearch.script.ScriptContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class ScriptedMetricAggContextsTests extends ScriptTestCase {
@Override
protected Map<ScriptContext<?>, List<Whitelist>> scriptContexts() {
Map<ScriptContext<?>, List<Whitelist>> contexts = new HashMap<>();
contexts.put(ScriptedMetricAggContexts.InitScript.CONTEXT, Whitelist.BASE_WHITELISTS);
contexts.put(ScriptedMetricAggContexts.MapScript.CONTEXT, Whitelist.BASE_WHITELISTS);
contexts.put(ScriptedMetricAggContexts.CombineScript.CONTEXT, Whitelist.BASE_WHITELISTS);
contexts.put(ScriptedMetricAggContexts.ReduceScript.CONTEXT, Whitelist.BASE_WHITELISTS);
return contexts;
}
public void testInitBasic() {
ScriptedMetricAggContexts.InitScript.Factory factory = scriptEngine.compile("test",
"state.testField = params.initialVal", ScriptedMetricAggContexts.InitScript.CONTEXT, Collections.emptyMap());
Map<String, Object> params = new HashMap<>();
Map<String, Object> state = new HashMap<>();
params.put("initialVal", 10);
ScriptedMetricAggContexts.InitScript script = factory.newInstance(params, state);
script.execute();
assert(state.containsKey("testField"));
assertEquals(10, state.get("testField"));
}
public void testMapBasic() {
ScriptedMetricAggContexts.MapScript.Factory factory = scriptEngine.compile("test",
"state.testField = 2*_score", ScriptedMetricAggContexts.MapScript.CONTEXT, Collections.emptyMap());
Map<String, Object> params = new HashMap<>();
Map<String, Object> state = new HashMap<>();
Scorer scorer = new Scorer(null) {
@Override
public int docID() { return 0; }
@Override
public float score() { return 0.5f; }
@Override
public DocIdSetIterator iterator() { return null; }
};
ScriptedMetricAggContexts.MapScript.LeafFactory leafFactory = factory.newFactory(params, state, null);
ScriptedMetricAggContexts.MapScript script = leafFactory.newInstance(null);
script.setScorer(scorer);
script.execute();
assert(state.containsKey("testField"));
assertEquals(1.0, state.get("testField"));
}
public void testCombineBasic() {
ScriptedMetricAggContexts.CombineScript.Factory factory = scriptEngine.compile("test",
"state.testField = params.initialVal; return state.testField + params.inc", ScriptedMetricAggContexts.CombineScript.CONTEXT,
Collections.emptyMap());
Map<String, Object> params = new HashMap<>();
Map<String, Object> state = new HashMap<>();
params.put("initialVal", 10);
params.put("inc", 2);
ScriptedMetricAggContexts.CombineScript script = factory.newInstance(params, state);
Object res = script.execute();
assert(state.containsKey("testField"));
assertEquals(10, state.get("testField"));
assertEquals(12, res);
}
public void testReduceBasic() {
ScriptedMetricAggContexts.ReduceScript.Factory factory = scriptEngine.compile("test",
"states[0].testField + states[1].testField", ScriptedMetricAggContexts.ReduceScript.CONTEXT, Collections.emptyMap());
Map<String, Object> params = new HashMap<>();
List<Object> states = new ArrayList<>();
Map<String, Object> state1 = new HashMap<>(), state2 = new HashMap<>();
state1.put("testField", 1);
state2.put("testField", 2);
states.add(state1);
states.add(state2);
ScriptedMetricAggContexts.ReduceScript script = factory.newInstance(params, states);
Object res = script.execute();
assertEquals(3, res);
}
}
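The `testMapBasic` case above exercises the new `_score` accessor which, per the commit message, evaluates the score only when the script reads it. A minimal sketch of that accessor pattern, using hypothetical `Scorer` and `MapScript` stand-ins rather than the real Lucene/Elasticsearch classes:

```java
import java.io.IOException;
import java.io.UncheckedIOException;

public class LazyScoreSketch {
    // Hypothetical stand-in for a Lucene scorer.
    interface Scorer {
        float score() throws IOException;
    }

    // Mirrors the accessor idea: the score is fetched from the scorer only
    // when the script actually reads _score, not eagerly for every document.
    static class MapScript {
        private Scorer scorer;

        void setScorer(Scorer scorer) {
            this.scorer = scorer;
        }

        double get_score() {
            if (scorer == null) {
                return 0.0; // no scorer bound yet
            }
            try {
                return scorer.score();
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }
    }

    public static void main(String[] args) {
        MapScript script = new MapScript();
        System.out.println(script.get_score()); // 0.0 before a scorer is set
        script.setScorer(() -> 0.5f);
        System.out.println(2 * script.get_score()); // 1.0, as in testMapBasic
    }
}
```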

View File

@ -53,7 +53,11 @@ public class ScriptModule {
SimilarityScript.CONTEXT,
SimilarityWeightScript.CONTEXT,
TemplateScript.CONTEXT,
-MovingFunctionScript.CONTEXT
+MovingFunctionScript.CONTEXT,
+ScriptedMetricAggContexts.InitScript.CONTEXT,
+ScriptedMetricAggContexts.MapScript.CONTEXT,
+ScriptedMetricAggContexts.CombineScript.CONTEXT,
+ScriptedMetricAggContexts.ReduceScript.CONTEXT
).collect(Collectors.toMap(c -> c.name, Function.identity()));
}

View File

@ -0,0 +1,161 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.script;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Scorer;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.search.lookup.LeafSearchLookup;
import org.elasticsearch.search.lookup.SearchLookup;
import java.io.IOException;
import java.util.List;
import java.util.Map;
public class ScriptedMetricAggContexts {
private abstract static class ParamsAndStateBase {
private final Map<String, Object> params;
private final Object state;
ParamsAndStateBase(Map<String, Object> params, Object state) {
this.params = params;
this.state = state;
}
public Map<String, Object> getParams() {
return params;
}
public Object getState() {
return state;
}
}
public abstract static class InitScript extends ParamsAndStateBase {
public InitScript(Map<String, Object> params, Object state) {
super(params, state);
}
public abstract void execute();
public interface Factory {
InitScript newInstance(Map<String, Object> params, Object state);
}
public static String[] PARAMETERS = {};
public static ScriptContext<Factory> CONTEXT = new ScriptContext<>("aggs_init", Factory.class);
}
public abstract static class MapScript extends ParamsAndStateBase {
private final LeafSearchLookup leafLookup;
private Scorer scorer;
public MapScript(Map<String, Object> params, Object state, SearchLookup lookup, LeafReaderContext leafContext) {
super(params, state);
this.leafLookup = leafContext == null ? null : lookup.getLeafSearchLookup(leafContext);
}
// Return the doc as a map (instead of LeafDocLookup) in order to abide by type whitelisting rules for
// Painless scripts.
public Map<String, ScriptDocValues<?>> getDoc() {
return leafLookup == null ? null : leafLookup.doc();
}
public void setDocument(int docId) {
if (leafLookup != null) {
leafLookup.setDocument(docId);
}
}
public void setScorer(Scorer scorer) {
this.scorer = scorer;
}
// get_score() is named this way so that it's picked up by Painless as '_score'
public double get_score() {
if (scorer == null) {
return 0.0;
}
try {
return scorer.score();
} catch (IOException e) {
throw new ElasticsearchException("Couldn't look up score", e);
}
}
public abstract void execute();
public interface LeafFactory {
MapScript newInstance(LeafReaderContext ctx);
}
public interface Factory {
LeafFactory newFactory(Map<String, Object> params, Object state, SearchLookup lookup);
}
public static String[] PARAMETERS = new String[] {};
public static ScriptContext<Factory> CONTEXT = new ScriptContext<>("aggs_map", Factory.class);
}
public abstract static class CombineScript extends ParamsAndStateBase {
public CombineScript(Map<String, Object> params, Object state) {
super(params, state);
}
public abstract Object execute();
public interface Factory {
CombineScript newInstance(Map<String, Object> params, Object state);
}
public static String[] PARAMETERS = {};
public static ScriptContext<Factory> CONTEXT = new ScriptContext<>("aggs_combine", Factory.class);
}
public abstract static class ReduceScript {
private final Map<String, Object> params;
private final List<Object> states;
public ReduceScript(Map<String, Object> params, List<Object> states) {
this.params = params;
this.states = states;
}
public Map<String, Object> getParams() {
return params;
}
public List<Object> getStates() {
return states;
}
public abstract Object execute();
public interface Factory {
ReduceScript newInstance(Map<String, Object> params, List<Object> states);
}
public static String[] PARAMETERS = {};
public static ScriptContext<Factory> CONTEXT = new ScriptContext<>("aggs_reduce", Factory.class);
}
}
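The classes above all follow the same ScriptContext shape: a compiled script is a `Factory` that stamps out instances bound to `params` and a shared `state` object. A self-contained sketch of that pattern with simplified stand-in types (not the real `ScriptContext` plumbing):

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

public class ContextPatternSketch {
    // Stand-in mirroring the InitScript shape above: a compiled script is a
    // Factory; each instance is bound to params plus a shared state object.
    abstract static class InitScript {
        final Map<String, Object> params;
        final Map<String, Object> state;

        InitScript(Map<String, Object> params, Map<String, Object> state) {
            this.params = params;
            this.state = state;
        }

        abstract void execute();
    }

    interface Factory {
        InitScript newInstance(Map<String, Object> params, Map<String, Object> state);
    }

    public static void main(String[] args) {
        // A "compiled" init script equivalent to "state.transactions = []".
        Factory compiled = (p, s) -> new InitScript(p, s) {
            @Override
            void execute() {
                state.put("transactions", new ArrayList<Double>());
            }
        };

        Map<String, Object> state = new HashMap<>();
        compiled.newInstance(Map.of(), state).execute();
        System.out.println(state.containsKey("transactions")); // true
    }
}
```

Compiling once to a factory and instantiating per execution is what lets the aggregator bind a fresh `state` per bucket while reusing the compiled script.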

View File

@ -23,7 +23,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.common.xcontent.XContentBuilder;
-import org.elasticsearch.script.ExecutableScript;
+import org.elasticsearch.script.ScriptedMetricAggContexts;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
@ -90,16 +90,19 @@ public class InternalScriptedMetric extends InternalAggregation implements Scrip
InternalScriptedMetric firstAggregation = ((InternalScriptedMetric) aggregations.get(0));
List<Object> aggregation;
if (firstAggregation.reduceScript != null && reduceContext.isFinalReduce()) {
-Map<String, Object> vars = new HashMap<>();
-vars.put("_aggs", aggregationObjects);
+Map<String, Object> params = new HashMap<>();
if (firstAggregation.reduceScript.getParams() != null) {
-vars.putAll(firstAggregation.reduceScript.getParams());
+params.putAll(firstAggregation.reduceScript.getParams());
}
-ExecutableScript.Factory factory = reduceContext.scriptService().compile(
-firstAggregation.reduceScript, ExecutableScript.AGGS_CONTEXT);
-ExecutableScript script = factory.newInstance(vars);
-Object scriptResult = script.run();
+// Add _aggs to params map for backwards compatibility (redundant with a context variable on the ReduceScript created below).
+params.put("_aggs", aggregationObjects);
+ScriptedMetricAggContexts.ReduceScript.Factory factory = reduceContext.scriptService().compile(
+firstAggregation.reduceScript, ScriptedMetricAggContexts.ReduceScript.CONTEXT);
+ScriptedMetricAggContexts.ReduceScript script = factory.newInstance(params, aggregationObjects);
+Object scriptResult = script.execute();
CollectionUtils.ensureNoSelfReferences(scriptResult, "reduce script");
aggregation = Collections.singletonList(scriptResult);

View File

@ -26,9 +26,8 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryShardContext;
-import org.elasticsearch.script.ExecutableScript;
+import org.elasticsearch.script.ScriptedMetricAggContexts;
import org.elasticsearch.script.Script;
-import org.elasticsearch.script.SearchScript;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories.Builder;
@ -202,30 +201,32 @@ public class ScriptedMetricAggregationBuilder extends AbstractAggregationBuilder
// Extract params from scripts and pass them along to ScriptedMetricAggregatorFactory, since it won't have
// access to them for the scripts it's given precompiled.
-ExecutableScript.Factory executableInitScript;
+ScriptedMetricAggContexts.InitScript.Factory compiledInitScript;
Map<String, Object> initScriptParams;
if (initScript != null) {
-executableInitScript = queryShardContext.getScriptService().compile(initScript, ExecutableScript.AGGS_CONTEXT);
+compiledInitScript = queryShardContext.getScriptService().compile(initScript, ScriptedMetricAggContexts.InitScript.CONTEXT);
initScriptParams = initScript.getParams();
} else {
-executableInitScript = p -> null;
+compiledInitScript = (p, a) -> null;
initScriptParams = Collections.emptyMap();
}
-SearchScript.Factory searchMapScript = queryShardContext.getScriptService().compile(mapScript, SearchScript.AGGS_CONTEXT);
+ScriptedMetricAggContexts.MapScript.Factory compiledMapScript = queryShardContext.getScriptService().compile(mapScript,
+ScriptedMetricAggContexts.MapScript.CONTEXT);
Map<String, Object> mapScriptParams = mapScript.getParams();
-ExecutableScript.Factory executableCombineScript;
+ScriptedMetricAggContexts.CombineScript.Factory compiledCombineScript;
Map<String, Object> combineScriptParams;
if (combineScript != null) {
-executableCombineScript = queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT);
+compiledCombineScript = queryShardContext.getScriptService().compile(combineScript,
+ScriptedMetricAggContexts.CombineScript.CONTEXT);
combineScriptParams = combineScript.getParams();
} else {
-executableCombineScript = p -> null;
+compiledCombineScript = (p, a) -> null;
combineScriptParams = Collections.emptyMap();
}
-return new ScriptedMetricAggregatorFactory(name, searchMapScript, mapScriptParams, executableInitScript, initScriptParams,
-executableCombineScript, combineScriptParams, reduceScript,
+return new ScriptedMetricAggregatorFactory(name, compiledMapScript, mapScriptParams, compiledInitScript,
+initScriptParams, compiledCombineScript, combineScriptParams, reduceScript,
params, queryShardContext.lookup(), context, parent, subfactoriesBuilder, metaData);
}

View File

@ -20,10 +20,10 @@
package org.elasticsearch.search.aggregations.metrics.scripted;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Scorer;
import org.elasticsearch.common.util.CollectionUtils;
-import org.elasticsearch.script.ExecutableScript;
+import org.elasticsearch.script.ScriptedMetricAggContexts;
import org.elasticsearch.script.Script;
-import org.elasticsearch.script.SearchScript;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
@ -38,17 +38,17 @@ import java.util.Map;
public class ScriptedMetricAggregator extends MetricsAggregator {
-private final SearchScript.LeafFactory mapScript;
-private final ExecutableScript combineScript;
+private final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript;
+private final ScriptedMetricAggContexts.CombineScript combineScript;
private final Script reduceScript;
-private Map<String, Object> params;
+private Object aggState;
-protected ScriptedMetricAggregator(String name, SearchScript.LeafFactory mapScript, ExecutableScript combineScript,
-Script reduceScript,
-Map<String, Object> params, SearchContext context, Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData)
-throws IOException {
+protected ScriptedMetricAggregator(String name, ScriptedMetricAggContexts.MapScript.LeafFactory mapScript, ScriptedMetricAggContexts.CombineScript combineScript,
+Script reduceScript, Object aggState, SearchContext context, Aggregator parent,
+List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData)
+throws IOException {
super(name, context, parent, pipelineAggregators, metaData);
-this.params = params;
+this.aggState = aggState;
this.mapScript = mapScript;
this.combineScript = combineScript;
this.reduceScript = reduceScript;
@ -62,14 +62,20 @@ public class ScriptedMetricAggregator extends MetricsAggregator {
@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
final LeafBucketCollector sub) throws IOException {
-final SearchScript leafMapScript = mapScript.newInstance(ctx);
+final ScriptedMetricAggContexts.MapScript leafMapScript = mapScript.newInstance(ctx);
return new LeafBucketCollectorBase(sub, leafMapScript) {
+@Override
+public void setScorer(Scorer scorer) throws IOException {
+leafMapScript.setScorer(scorer);
+}
@Override
public void collect(int doc, long bucket) throws IOException {
assert bucket == 0 : bucket;
leafMapScript.setDocument(doc);
-leafMapScript.run();
-CollectionUtils.ensureNoSelfReferences(params, "Scripted metric aggs map script");
+leafMapScript.execute();
+CollectionUtils.ensureNoSelfReferences(aggState, "Scripted metric aggs map script");
}
};
}
@ -78,10 +84,10 @@ public class ScriptedMetricAggregator extends MetricsAggregator {
public InternalAggregation buildAggregation(long owningBucketOrdinal) {
Object aggregation;
if (combineScript != null) {
-aggregation = combineScript.run();
+aggregation = combineScript.execute();
CollectionUtils.ensureNoSelfReferences(aggregation, "Scripted metric aggs combine script");
} else {
-aggregation = params.get("_agg");
+aggregation = aggState;
}
return new InternalScriptedMetric(name, aggregation, reduceScript, pipelineAggregators(),
metaData());

View File

@ -19,10 +19,9 @@
package org.elasticsearch.search.aggregations.metrics.scripted;
+import org.elasticsearch.script.ScriptedMetricAggContexts;
import org.elasticsearch.common.util.CollectionUtils;
-import org.elasticsearch.script.ExecutableScript;
import org.elasticsearch.script.Script;
-import org.elasticsearch.script.SearchScript;
import org.elasticsearch.search.SearchParseException;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
@ -39,20 +38,21 @@ import java.util.Map;
public class ScriptedMetricAggregatorFactory extends AggregatorFactory<ScriptedMetricAggregatorFactory> {
-private final SearchScript.Factory mapScript;
+private final ScriptedMetricAggContexts.MapScript.Factory mapScript;
private final Map<String, Object> mapScriptParams;
-private final ExecutableScript.Factory combineScript;
+private final ScriptedMetricAggContexts.CombineScript.Factory combineScript;
private final Map<String, Object> combineScriptParams;
private final Script reduceScript;
private final Map<String, Object> aggParams;
private final SearchLookup lookup;
-private final ExecutableScript.Factory initScript;
+private final ScriptedMetricAggContexts.InitScript.Factory initScript;
private final Map<String, Object> initScriptParams;
-public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, Map<String, Object> mapScriptParams,
-ExecutableScript.Factory initScript, Map<String, Object> initScriptParams,
-ExecutableScript.Factory combineScript, Map<String, Object> combineScriptParams,
-Script reduceScript, Map<String, Object> aggParams,
+public ScriptedMetricAggregatorFactory(String name,
+ScriptedMetricAggContexts.MapScript.Factory mapScript, Map<String, Object> mapScriptParams,
+ScriptedMetricAggContexts.InitScript.Factory initScript, Map<String, Object> initScriptParams,
+ScriptedMetricAggContexts.CombineScript.Factory combineScript,
+Map<String, Object> combineScriptParams, Script reduceScript, Map<String, Object> aggParams,
SearchLookup lookup, SearchContext context, AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactories, Map<String, Object> metaData) throws IOException {
super(name, context, parent, subFactories, metaData);
@ -79,21 +79,29 @@ public class ScriptedMetricAggregatorFactory extends AggregatorFactory<ScriptedM
} else {
aggParams = new HashMap<>();
}
+// Add _agg to params map for backwards compatibility (redundant with context variables on the scripts created below).
+// When this is removed, aggState (as passed to ScriptedMetricAggregator) can be changed to Map<String, Object>, since
+// it won't be possible to completely replace it with another type as is possible when it's an entry in params.
+if (aggParams.containsKey("_agg") == false) {
+aggParams.put("_agg", new HashMap<String, Object>());
+}
+Object aggState = aggParams.get("_agg");
-final ExecutableScript initScript = this.initScript.newInstance(mergeParams(aggParams, initScriptParams));
-final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(mergeParams(aggParams, mapScriptParams), lookup);
-final ExecutableScript combineScript = this.combineScript.newInstance(mergeParams(aggParams, combineScriptParams));
+final ScriptedMetricAggContexts.InitScript initScript = this.initScript.newInstance(
+mergeParams(aggParams, initScriptParams), aggState);
+final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript = this.mapScript.newFactory(
+mergeParams(aggParams, mapScriptParams), aggState, lookup);
+final ScriptedMetricAggContexts.CombineScript combineScript = this.combineScript.newInstance(
+mergeParams(aggParams, combineScriptParams), aggState);
final Script reduceScript = deepCopyScript(this.reduceScript, context);
if (initScript != null) {
-initScript.run();
-CollectionUtils.ensureNoSelfReferences(aggParams.get("_agg"), "Scripted metric aggs init script");
+initScript.execute();
+CollectionUtils.ensureNoSelfReferences(aggState, "Scripted metric aggs init script");
}
return new ScriptedMetricAggregator(name, mapScript,
combineScript, reduceScript, aggState, context, parent,
pipelineAggregators, metaData);
}
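The factory above threads one mutable `state` object through the init, map, and combine scripts in turn, with a separate reduce step over the per-shard results. A dependency-free sketch of that lifecycle (class and method names here are hypothetical, not the Elasticsearch API):

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Illustrative sketch: one mutable `state` per shard flows through
// init -> map -> combine, then reduce folds the per-shard results together.
public class StateLifecycleSketch {

    // init + map + combine for a single shard
    static int runShard(int docsOnShard) {
        Map<String, Object> state = new HashMap<>();   // the object scripts see as `state`
        state.put("items", new ArrayList<Integer>());  // init script: state.items = []
        @SuppressWarnings("unchecked")
        List<Integer> items = (List<Integer>) state.get("items");
        for (int doc = 0; doc < docsOnShard; doc++) {
            items.add(1);                              // map script: state.items.add(1)
        }
        int sum = 0;                                   // combine script: sum this shard's values
        for (int x : items) {
            sum += x;
        }
        return sum;
    }

    // reduce over all shards' combined results
    static int reduce(List<Integer> shardResults) {
        int total = 0;
        for (int s : shardResults) {
            total += s;
        }
        return total;
    }
}
```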
@@ -193,14 +193,55 @@ public class ScriptedMetricIT extends ESIntegTestCase {
return newAggregation;
});
scripts.put("state.items = new ArrayList()", vars ->
aggContextScript(vars, state -> ((HashMap) state).put("items", new ArrayList())));
scripts.put("state.items.add(1)", vars ->
aggContextScript(vars, state -> {
HashMap stateMap = (HashMap) state;
List items = (List) stateMap.get("items");
items.add(1);
}));
scripts.put("sum context state values", vars -> {
int sum = 0;
HashMap state = (HashMap) vars.get("state");
List items = (List) state.get("items");
for (Object x : items) {
sum += (Integer)x;
}
return sum;
});
scripts.put("sum context states", vars -> {
Integer sum = 0;
List<?> states = (List<?>) vars.get("states");
for (Object state : states) {
sum += ((Number) state).intValue();
}
return sum;
});
return scripts;
}
static <T> Object aggScript(Map<String, Object> vars, Consumer<T> fn) {
return aggScript(vars, fn, "_agg");
}
static <T> Object aggContextScript(Map<String, Object> vars, Consumer<T> fn) {
return aggScript(vars, fn, "state");
}
@SuppressWarnings("unchecked")
private static <T> Object aggScript(Map<String, Object> vars, Consumer<T> fn, String stateVarName) {
T aggState = (T) vars.get(stateVarName);
fn.accept(aggState);
return aggState;
}
}
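The helper refactor above keeps a single generic implementation parameterized by the vars key, with thin entry points for the legacy `_agg` name and the new `state` name. A stand-alone, runnable copy of the same delegation pattern (the class name is illustrative):

```java
import java.util.Map;
import java.util.function.Consumer;

// One generic implementation, two named entry points: the legacy "_agg"
// key and the new "state" key both route through the same helper.
public class VarNameDelegation {

    static <T> Object aggScript(Map<String, Object> vars, Consumer<T> fn) {
        return aggScript(vars, fn, "_agg");
    }

    static <T> Object aggContextScript(Map<String, Object> vars, Consumer<T> fn) {
        return aggScript(vars, fn, "state");
    }

    @SuppressWarnings("unchecked")
    private static <T> Object aggScript(Map<String, Object> vars, Consumer<T> fn, String stateVarName) {
        T state = (T) vars.get(stateVarName);  // pull the state object out by name
        fn.accept(state);                      // let the "script" mutate it
        return state;
    }
}
```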
@@ -1015,4 +1056,37 @@ public class ScriptedMetricIT extends ESIntegTestCase {
SearchPhaseExecutionException ex = expectThrows(SearchPhaseExecutionException.class, builder::get);
assertThat(ex.getCause().getMessage(), containsString("Parameter name \"param1\" used in both aggregation and script parameters"));
}
public void testAggFromContext() {
Script initScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "state.items = new ArrayList()", Collections.emptyMap());
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "state.items.add(1)", Collections.emptyMap());
Script combineScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "sum context state values", Collections.emptyMap());
Script reduceScript =
new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "sum context states",
Collections.emptyMap());
SearchResponse response = client()
.prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(
scriptedMetric("scripted")
.initScript(initScript)
.mapScript(mapScript)
.combineScript(combineScript)
.reduceScript(reduceScript))
.get();
Aggregation aggregation = response.getAggregations().get("scripted");
assertThat(aggregation, notNullValue());
assertThat(aggregation, instanceOf(ScriptedMetric.class));
ScriptedMetric scriptedMetricAggregation = (ScriptedMetric) aggregation;
assertThat(scriptedMetricAggregation.getName(), equalTo("scripted"));
assertThat(scriptedMetricAggregation.aggregation(), notNullValue());
assertThat(scriptedMetricAggregation.aggregation(), instanceOf(Integer.class));
Integer aggResult = (Integer) scriptedMetricAggregation.aggregation();
long totalAgg = aggResult.longValue();
assertThat(totalAgg, equalTo(numDocs));
}
}
@@ -31,7 +31,6 @@ import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.script.MockScriptEngine;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.ScriptModule;
@@ -107,7 +106,7 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
});
SCRIPTS.put("mapScriptScore", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
((List<Double>) agg.get("collector")).add(((Number) params.get("_score")).doubleValue());
return agg;
});
SCRIPTS.put("combineScriptScore", params -> {
@@ -33,6 +33,7 @@ import org.elasticsearch.search.lookup.SearchLookup;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
@@ -115,6 +116,18 @@ public class MockScriptEngine implements ScriptEngine {
} else if (context.instanceClazz.equals(ScoreScript.class)) {
ScoreScript.Factory factory = new MockScoreScript(script);
return context.factoryClazz.cast(factory);
} else if (context.instanceClazz.equals(ScriptedMetricAggContexts.InitScript.class)) {
ScriptedMetricAggContexts.InitScript.Factory factory = mockCompiled::createMetricAggInitScript;
return context.factoryClazz.cast(factory);
} else if (context.instanceClazz.equals(ScriptedMetricAggContexts.MapScript.class)) {
ScriptedMetricAggContexts.MapScript.Factory factory = mockCompiled::createMetricAggMapScript;
return context.factoryClazz.cast(factory);
} else if (context.instanceClazz.equals(ScriptedMetricAggContexts.CombineScript.class)) {
ScriptedMetricAggContexts.CombineScript.Factory factory = mockCompiled::createMetricAggCombineScript;
return context.factoryClazz.cast(factory);
} else if (context.instanceClazz.equals(ScriptedMetricAggContexts.ReduceScript.class)) {
ScriptedMetricAggContexts.ReduceScript.Factory factory = mockCompiled::createMetricAggReduceScript;
return context.factoryClazz.cast(factory);
}
throw new IllegalArgumentException("mock script engine does not know how to handle context [" + context.name + "]");
}
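The chain above dispatches on the script context's class and casts a method reference to the matching factory interface. A minimal stand-alone sketch of that dispatch pattern (the `Context`, `InitFactory`, and `MapFactory` types are hypothetical stand-ins, not the real Elasticsearch `ScriptContext` API):

```java
// compile() switches on the context's factory class and uses Class.cast()
// to return a matching lambda as the caller's expected factory type.
public class ContextDispatchSketch {

    interface InitFactory { String create(); }
    interface MapFactory { Integer create(); }

    static class Context<T> {
        final Class<T> factoryClazz;
        Context(Class<T> factoryClazz) { this.factoryClazz = factoryClazz; }
    }

    static <T> T compile(Context<T> context) {
        if (context.factoryClazz.equals(InitFactory.class)) {
            InitFactory factory = () -> "init";
            return context.factoryClazz.cast(factory);
        } else if (context.factoryClazz.equals(MapFactory.class)) {
            MapFactory factory = () -> 42;
            return context.factoryClazz.cast(factory);
        }
        throw new IllegalArgumentException("unknown context");
    }
}
```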
@@ -179,6 +192,23 @@ public class MockScriptEngine implements ScriptEngine {
public MovingFunctionScript createMovingFunctionScript() {
return new MockMovingFunctionScript();
}
public ScriptedMetricAggContexts.InitScript createMetricAggInitScript(Map<String, Object> params, Object state) {
return new MockMetricAggInitScript(params, state, script != null ? script : ctx -> 42d);
}
public ScriptedMetricAggContexts.MapScript.LeafFactory createMetricAggMapScript(Map<String, Object> params, Object state,
SearchLookup lookup) {
return new MockMetricAggMapScript(params, state, lookup, script != null ? script : ctx -> 42d);
}
public ScriptedMetricAggContexts.CombineScript createMetricAggCombineScript(Map<String, Object> params, Object state) {
return new MockMetricAggCombineScript(params, state, script != null ? script : ctx -> 42d);
}
public ScriptedMetricAggContexts.ReduceScript createMetricAggReduceScript(Map<String, Object> params, List<Object> states) {
return new MockMetricAggReduceScript(params, states, script != null ? script : ctx -> 42d);
}
}
public class MockExecutableScript implements ExecutableScript {
@@ -333,6 +363,108 @@ public class MockScriptEngine implements ScriptEngine {
}
}
public static class MockMetricAggInitScript extends ScriptedMetricAggContexts.InitScript {
private final Function<Map<String, Object>, Object> script;
MockMetricAggInitScript(Map<String, Object> params, Object state,
Function<Map<String, Object>, Object> script) {
super(params, state);
this.script = script;
}
public void execute() {
Map<String, Object> map = new HashMap<>();
if (getParams() != null) {
map.putAll(getParams()); // TODO: remove this once scripts know to look for params under params key
map.put("params", getParams());
}
map.put("state", getState());
script.apply(map);
}
}
public static class MockMetricAggMapScript implements ScriptedMetricAggContexts.MapScript.LeafFactory {
private final Map<String, Object> params;
private final Object state;
private final SearchLookup lookup;
private final Function<Map<String, Object>, Object> script;
MockMetricAggMapScript(Map<String, Object> params, Object state, SearchLookup lookup,
Function<Map<String, Object>, Object> script) {
this.params = params;
this.state = state;
this.lookup = lookup;
this.script = script;
}
@Override
public ScriptedMetricAggContexts.MapScript newInstance(LeafReaderContext context) {
return new ScriptedMetricAggContexts.MapScript(params, state, lookup, context) {
@Override
public void execute() {
Map<String, Object> map = new HashMap<>();
if (getParams() != null) {
map.putAll(getParams()); // TODO: remove this once scripts know to look for params under params key
map.put("params", getParams());
}
map.put("state", getState());
map.put("doc", getDoc());
map.put("_score", get_score());
script.apply(map);
}
};
}
}
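The mock map script above exposes the score through a `get_score()` accessor rather than receiving it as a precomputed argument, so scoring only happens when a script actually reads it. A sketch of that lazy-accessor idea (names and the evaluation counter are illustrative only):

```java
import java.util.function.DoubleSupplier;

// The script context holds a supplier, not a value; the accessor pulls the
// score lazily, so scripts that ignore it never pay for an evaluation.
public class LazyScoreSketch {

    static int scoreEvaluations = 0;

    // stand-in for the scorer: potentially expensive to call
    static double expensiveScore() {
        scoreEvaluations++;
        return 1.5;
    }

    static final DoubleSupplier SCORE = LazyScoreSketch::expensiveScore;

    // the accessor the script calls, mirroring get_score()
    static double get_score() {
        return SCORE.getAsDouble();
    }

    // a map script that ignores the score never triggers scoring
    static void mapScriptWithoutScore() {
        // doc-only logic would go here
    }

    // a map script that reads the score pays for exactly one evaluation
    static double mapScriptWithScore() {
        return get_score() * 2.0;
    }
}
```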
public static class MockMetricAggCombineScript extends ScriptedMetricAggContexts.CombineScript {
private final Function<Map<String, Object>, Object> script;
MockMetricAggCombineScript(Map<String, Object> params, Object state,
Function<Map<String, Object>, Object> script) {
super(params, state);
this.script = script;
}
public Object execute() {
Map<String, Object> map = new HashMap<>();
if (getParams() != null) {
map.putAll(getParams()); // TODO: remove this once scripts know to look for params under params key
map.put("params", getParams());
}
map.put("state", getState());
return script.apply(map);
}
}
public static class MockMetricAggReduceScript extends ScriptedMetricAggContexts.ReduceScript {
private final Function<Map<String, Object>, Object> script;
MockMetricAggReduceScript(Map<String, Object> params, List<Object> states,
Function<Map<String, Object>, Object> script) {
super(params, states);
this.script = script;
}
public Object execute() {
Map<String, Object> map = new HashMap<>();
if (getParams() != null) {
map.putAll(getParams()); // TODO: remove this once scripts know to look for params under params key
map.put("params", getParams());
}
map.put("states", getStates());
return script.apply(map);
}
}
public static Script mockInlineScript(final String script) {
return new Script(ScriptType.INLINE, "mock", script, emptyMap());
}
@@ -343,15 +475,15 @@ public class MockScriptEngine implements ScriptEngine {
return MovingFunctions.unweightedAvg(values);
}
}
public class MockScoreScript implements ScoreScript.Factory {
private final Function<Map<String, Object>, Object> scripts;
MockScoreScript(Function<Map<String, Object>, Object> scripts) {
this.scripts = scripts;
}
@Override
public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup) {
return new ScoreScript.LeafFactory() {
@@ -359,7 +491,7 @@ public class MockScriptEngine implements ScriptEngine {
public boolean needs_score() {
return true;
}
@Override
public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
Scorer[] scorerHolder = new Scorer[1];
@@ -373,7 +505,7 @@ public class MockScriptEngine implements ScriptEngine {
}
return ((Number) scripts.apply(vars)).doubleValue();
}
@Override
public void setScorer(Scorer scorer) {
scorerHolder[0] = scorer;