diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java index 69ac175c419..c11c68f9b25 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java @@ -37,6 +37,7 @@ import org.elasticsearch.search.aggregations.AggregatorFactory; import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; +import java.util.Collections; import java.util.Map; import java.util.Objects; @@ -198,20 +199,34 @@ public class ScriptedMetricAggregationBuilder extends AbstractAggregationBuilder Builder subfactoriesBuilder) throws IOException { QueryShardContext queryShardContext = context.getQueryShardContext(); + + // Extract params from scripts and pass them along to ScriptedMetricAggregatorFactory, since it won't have + // access to them for the scripts it's given precompiled. + ExecutableScript.Factory executableInitScript; + Map initScriptParams; if (initScript != null) { executableInitScript = queryShardContext.getScriptService().compile(initScript, ExecutableScript.AGGS_CONTEXT); + initScriptParams = initScript.getParams(); } else { executableInitScript = p -> null; + initScriptParams = Collections.emptyMap(); } + SearchScript.Factory searchMapScript = queryShardContext.getScriptService().compile(mapScript, SearchScript.AGGS_CONTEXT); + Map mapScriptParams = mapScript.getParams(); + ExecutableScript.Factory executableCombineScript; + Map combineScriptParams; if (combineScript != null) { - executableCombineScript =queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT); + executableCombineScript = queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT); + combineScriptParams = combineScript.getParams(); } else { executableCombineScript = p -> null; + combineScriptParams = Collections.emptyMap(); } - return new ScriptedMetricAggregatorFactory(name, searchMapScript, executableInitScript, executableCombineScript, reduceScript, + return new ScriptedMetricAggregatorFactory(name, searchMapScript, mapScriptParams, executableInitScript, initScriptParams, + executableCombineScript, combineScriptParams, reduceScript, params, queryShardContext.lookup(), context, parent, subfactoriesBuilder, metaData); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java index aa7de3e1ab6..0bc6a614e54 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java @@ -35,28 +35,35 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Function; public class ScriptedMetricAggregatorFactory extends AggregatorFactory { private final SearchScript.Factory mapScript; + private final Map mapScriptParams; private final ExecutableScript.Factory combineScript; + private final Map combineScriptParams; private final Script reduceScript; - private final Map params; + private final Map aggParams; private final SearchLookup lookup; private final ExecutableScript.Factory initScript; + private final Map initScriptParams; - public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, ExecutableScript.Factory initScript, - ExecutableScript.Factory combineScript, Script reduceScript, Map params, + public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, Map mapScriptParams, + ExecutableScript.Factory initScript, Map initScriptParams, + ExecutableScript.Factory combineScript, Map combineScriptParams, + Script reduceScript, Map aggParams, SearchLookup lookup, SearchContext context, AggregatorFactory parent, AggregatorFactories.Builder subFactories, Map metaData) throws IOException { super(name, context, parent, subFactories, metaData); this.mapScript = mapScript; + this.mapScriptParams = mapScriptParams; this.initScript = initScript; + this.initScriptParams = initScriptParams; this.combineScript = combineScript; + this.combineScriptParams = combineScriptParams; this.reduceScript = reduceScript; this.lookup = lookup; - this.params = params; + this.aggParams = aggParams; } @Override @@ -65,26 +72,26 @@ public class ScriptedMetricAggregatorFactory extends AggregatorFactory params = this.params; - if (params != null) { - params = deepCopyParams(params, context); + Map aggParams = this.aggParams; + if (aggParams != null) { + aggParams = deepCopyParams(aggParams, context); } else { - params = new HashMap<>(); + aggParams = new HashMap<>(); } - if (params.containsKey("_agg") == false) { - params.put("_agg", new HashMap()); + if (aggParams.containsKey("_agg") == false) { + aggParams.put("_agg", new HashMap()); } - final ExecutableScript initScript = this.initScript.newInstance(params); - final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(params, lookup); - final ExecutableScript combineScript = this.combineScript.newInstance(params); + final ExecutableScript initScript = this.initScript.newInstance(mergeParams(aggParams, initScriptParams)); + final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(mergeParams(aggParams, mapScriptParams), lookup); + final ExecutableScript combineScript = this.combineScript.newInstance(mergeParams(aggParams, combineScriptParams)); final Script reduceScript = deepCopyScript(this.reduceScript, context); if (initScript != null) { initScript.run(); } return new ScriptedMetricAggregator(name, mapScript, - combineScript, reduceScript, params, context, parent, + combineScript, reduceScript, aggParams, context, parent, pipelineAggregators, metaData); } @@ -128,5 +135,18 @@ public class ScriptedMetricAggregatorFactory extends AggregatorFactory mergeParams(Map agg, Map script) { + // Start with script params + Map combined = new HashMap<>(script); + // Add in agg params, throwing an exception if any conflicts are detected + for (Map.Entry aggEntry : agg.entrySet()) { + if (combined.putIfAbsent(aggEntry.getKey(), aggEntry.getValue()) != null) { + throw new IllegalArgumentException("Parameter name \"" + aggEntry.getKey() + + "\" used in both aggregation and script parameters"); + } + } + + return combined; + } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java index 24d94d5a464..9db5b237a85 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java @@ -20,6 +20,8 @@ package org.elasticsearch.search.aggregations.metrics; import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.search.SearchPhaseExecutionException; +import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; @@ -62,6 +64,7 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.scripted import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse; import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -322,11 +325,11 @@ public class ScriptedMetricIT extends ESIntegTestCase { assertThat(numShardsRun, greaterThan(0)); } - public void testMapWithParams() { + public void testExplicitAggParam() { Map params = new HashMap<>(); params.put("_agg", new ArrayList<>()); - Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg.add(1)", params); + Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg.add(1)", Collections.emptyMap()); SearchResponse response = client().prepareSearch("idx") .setQuery(matchAllQuery()) @@ -361,17 +364,17 @@ public class ScriptedMetricIT extends ESIntegTestCase { } public void testMapWithParamsAndImplicitAggMap() { - Map params = new HashMap<>(); - // don't put any _agg map in params - params.put("param1", "12"); - params.put("param2", 1); + // Split the params up between the script and the aggregation. + // Don't put any _agg map in params. + Map scriptParams = Collections.singletonMap("param1", "12"); + Map aggregationParams = Collections.singletonMap("param2", 1); // The _agg hashmap will be available even if not declared in the params map - Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg[param1] = param2", params); + Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg[param1] = param2", scriptParams); SearchResponse response = client().prepareSearch("idx") .setQuery(matchAllQuery()) - .addAggregation(scriptedMetric("scripted").params(params).mapScript(mapScript)) + .addAggregation(scriptedMetric("scripted").params(aggregationParams).mapScript(mapScript)) .get(); assertSearchResponse(response); assertThat(response.getHits().getTotalHits(), equalTo(numDocs)); @@ -1001,4 +1004,16 @@ public class ScriptedMetricIT extends ESIntegTestCase { assertThat(client().admin().indices().prepareStats("cache_test_idx").setRequestCache(true).get().getTotal().getRequestCache() .getMissCount(), equalTo(0L)); } + + public void testConflictingAggAndScriptParams() { + Map params = Collections.singletonMap("param1", "12"); + Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg.add(1)", params); + + SearchRequestBuilder builder = client().prepareSearch("idx") + .setQuery(matchAllQuery()) + .addAggregation(scriptedMetric("scripted").params(params).mapScript(mapScript)); + + SearchPhaseExecutionException ex = expectThrows(SearchPhaseExecutionException.class, builder::get); + assertThat(ex.getCause().getMessage(), containsString("Parameter name \"param1\" used in both aggregation and script parameters")); + } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java index db2feafe6c4..0989b1ce6a3 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java @@ -64,8 +64,16 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { Collections.emptyMap()); private static final Script COMBINE_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptScore", Collections.emptyMap()); - private static final Map, Object>> SCRIPTS = new HashMap<>(); + private static final Script INIT_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "initScriptParams", + Collections.singletonMap("initialValue", 24)); + private static final Script MAP_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScriptParams", + Collections.singletonMap("itemValue", 12)); + private static final Script COMBINE_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptParams", + Collections.singletonMap("divisor", 4)); + private static final String CONFLICTING_PARAM_NAME = "initialValue"; + + private static final Map, Object>> SCRIPTS = new HashMap<>(); @BeforeClass @SuppressWarnings("unchecked") @@ -99,6 +107,26 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { Map agg = (Map) params.get("_agg"); return ((List) agg.get("collector")).stream().mapToDouble(Double::doubleValue).sum(); }); + + SCRIPTS.put("initScriptParams", params -> { + Map agg = (Map) params.get("_agg"); + Integer initialValue = (Integer)params.get("initialValue"); + ArrayList collector = new ArrayList(); + collector.add(initialValue); + agg.put("collector", collector); + return agg; + }); + SCRIPTS.put("mapScriptParams", params -> { + Map agg = (Map) params.get("_agg"); + Integer itemValue = (Integer) params.get("itemValue"); + ((List) agg.get("collector")).add(itemValue); + return agg; + }); + SCRIPTS.put("combineScriptParams", params -> { + Map agg = (Map) params.get("_agg"); + int divisor = ((Integer) params.get("divisor")); + return ((List) agg.get("collector")).stream().mapToInt(Integer::intValue).map(i -> i / divisor).sum(); + }); } @SuppressWarnings("unchecked") @@ -187,6 +215,48 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { } } + public void testScriptParamsPassedThrough() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { + for (int i = 0; i < 100; i++) { + indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i))); + } + } + + try (IndexReader indexReader = DirectoryReader.open(directory)) { + ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME); + aggregationBuilder.initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS).combineScript(COMBINE_SCRIPT_PARAMS); + ScriptedMetric scriptedMetric = search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder); + + // The result value depends on the script params. + assertEquals(306, scriptedMetric.aggregation()); + } + } + } + + public void testConflictingAggAndScriptParams() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { + for (int i = 0; i < 100; i++) { + indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i))); + } + } + + try (IndexReader indexReader = DirectoryReader.open(directory)) { + ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME); + Map aggParams = Collections.singletonMap(CONFLICTING_PARAM_NAME, "blah"); + aggregationBuilder.params(aggParams).initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS). + combineScript(COMBINE_SCRIPT_PARAMS); + + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> + search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder) + ); + assertEquals("Parameter name \"" + CONFLICTING_PARAM_NAME + "\" used in both aggregation and script parameters", + ex.getMessage()); + } + } + } + /** * We cannot use Mockito for mocking QueryShardContext in this case because * script-related methods (e.g. QueryShardContext#getLazyExecutableScript)