Pass through script params in scripted metric agg (#29154)

* Pass script level params into scripted metric aggs (#28819)

Now params that are passed at the script level and at the aggregation level
are merged and can both be used in the aggregation scripts. If there are
any conflicts, aggregation level params will win. This may be followed
by another change detecting that case and throwing an exception to
disallow such conflicts.

* Disallow duplicate parameter names between scripted agg and script (#28819)

If a scripted metric aggregation has aggregation params and script params
which have the same name, throw an IllegalArgumentException when merging
the parameter lists.
This commit is contained in:
rationull 2018-04-03 01:57:49 -07:00 committed by Colin Goodheart-Smithe
parent f8602b1c7e
commit 0028563aac
4 changed files with 146 additions and 26 deletions

View File

@ -37,6 +37,7 @@ import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException; import java.io.IOException;
import java.util.Collections;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -198,20 +199,34 @@ public class ScriptedMetricAggregationBuilder extends AbstractAggregationBuilder
Builder subfactoriesBuilder) throws IOException { Builder subfactoriesBuilder) throws IOException {
QueryShardContext queryShardContext = context.getQueryShardContext(); QueryShardContext queryShardContext = context.getQueryShardContext();
// Extract params from scripts and pass them along to ScriptedMetricAggregatorFactory, since it won't have
// access to them for the scripts it's given precompiled.
ExecutableScript.Factory executableInitScript; ExecutableScript.Factory executableInitScript;
Map<String, Object> initScriptParams;
if (initScript != null) { if (initScript != null) {
executableInitScript = queryShardContext.getScriptService().compile(initScript, ExecutableScript.AGGS_CONTEXT); executableInitScript = queryShardContext.getScriptService().compile(initScript, ExecutableScript.AGGS_CONTEXT);
initScriptParams = initScript.getParams();
} else { } else {
executableInitScript = p -> null; executableInitScript = p -> null;
initScriptParams = Collections.emptyMap();
} }
SearchScript.Factory searchMapScript = queryShardContext.getScriptService().compile(mapScript, SearchScript.AGGS_CONTEXT); SearchScript.Factory searchMapScript = queryShardContext.getScriptService().compile(mapScript, SearchScript.AGGS_CONTEXT);
Map<String, Object> mapScriptParams = mapScript.getParams();
ExecutableScript.Factory executableCombineScript; ExecutableScript.Factory executableCombineScript;
Map<String, Object> combineScriptParams;
if (combineScript != null) { if (combineScript != null) {
executableCombineScript =queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT); executableCombineScript = queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT);
combineScriptParams = combineScript.getParams();
} else { } else {
executableCombineScript = p -> null; executableCombineScript = p -> null;
combineScriptParams = Collections.emptyMap();
} }
return new ScriptedMetricAggregatorFactory(name, searchMapScript, executableInitScript, executableCombineScript, reduceScript, return new ScriptedMetricAggregatorFactory(name, searchMapScript, mapScriptParams, executableInitScript, initScriptParams,
executableCombineScript, combineScriptParams, reduceScript,
params, queryShardContext.lookup(), context, parent, subfactoriesBuilder, metaData); params, queryShardContext.lookup(), context, parent, subfactoriesBuilder, metaData);
} }

View File

@ -35,28 +35,35 @@ import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Function;
public class ScriptedMetricAggregatorFactory extends AggregatorFactory<ScriptedMetricAggregatorFactory> { public class ScriptedMetricAggregatorFactory extends AggregatorFactory<ScriptedMetricAggregatorFactory> {
private final SearchScript.Factory mapScript; private final SearchScript.Factory mapScript;
private final Map<String, Object> mapScriptParams;
private final ExecutableScript.Factory combineScript; private final ExecutableScript.Factory combineScript;
private final Map<String, Object> combineScriptParams;
private final Script reduceScript; private final Script reduceScript;
private final Map<String, Object> params; private final Map<String, Object> aggParams;
private final SearchLookup lookup; private final SearchLookup lookup;
private final ExecutableScript.Factory initScript; private final ExecutableScript.Factory initScript;
private final Map<String, Object> initScriptParams;
public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, ExecutableScript.Factory initScript, public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, Map<String, Object> mapScriptParams,
ExecutableScript.Factory combineScript, Script reduceScript, Map<String, Object> params, ExecutableScript.Factory initScript, Map<String, Object> initScriptParams,
ExecutableScript.Factory combineScript, Map<String, Object> combineScriptParams,
Script reduceScript, Map<String, Object> aggParams,
SearchLookup lookup, SearchContext context, AggregatorFactory<?> parent, SearchLookup lookup, SearchContext context, AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactories, Map<String, Object> metaData) throws IOException { AggregatorFactories.Builder subFactories, Map<String, Object> metaData) throws IOException {
super(name, context, parent, subFactories, metaData); super(name, context, parent, subFactories, metaData);
this.mapScript = mapScript; this.mapScript = mapScript;
this.mapScriptParams = mapScriptParams;
this.initScript = initScript; this.initScript = initScript;
this.initScriptParams = initScriptParams;
this.combineScript = combineScript; this.combineScript = combineScript;
this.combineScriptParams = combineScriptParams;
this.reduceScript = reduceScript; this.reduceScript = reduceScript;
this.lookup = lookup; this.lookup = lookup;
this.params = params; this.aggParams = aggParams;
} }
@Override @Override
@ -65,26 +72,26 @@ public class ScriptedMetricAggregatorFactory extends AggregatorFactory<ScriptedM
if (collectsFromSingleBucket == false) { if (collectsFromSingleBucket == false) {
return asMultiBucketAggregator(this, context, parent); return asMultiBucketAggregator(this, context, parent);
} }
Map<String, Object> params = this.params; Map<String, Object> aggParams = this.aggParams;
if (params != null) { if (aggParams != null) {
params = deepCopyParams(params, context); aggParams = deepCopyParams(aggParams, context);
} else { } else {
params = new HashMap<>(); aggParams = new HashMap<>();
} }
if (params.containsKey("_agg") == false) { if (aggParams.containsKey("_agg") == false) {
params.put("_agg", new HashMap<String, Object>()); aggParams.put("_agg", new HashMap<String, Object>());
} }
final ExecutableScript initScript = this.initScript.newInstance(params); final ExecutableScript initScript = this.initScript.newInstance(mergeParams(aggParams, initScriptParams));
final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(params, lookup); final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(mergeParams(aggParams, mapScriptParams), lookup);
final ExecutableScript combineScript = this.combineScript.newInstance(params); final ExecutableScript combineScript = this.combineScript.newInstance(mergeParams(aggParams, combineScriptParams));
final Script reduceScript = deepCopyScript(this.reduceScript, context); final Script reduceScript = deepCopyScript(this.reduceScript, context);
if (initScript != null) { if (initScript != null) {
initScript.run(); initScript.run();
} }
return new ScriptedMetricAggregator(name, mapScript, return new ScriptedMetricAggregator(name, mapScript,
combineScript, reduceScript, params, context, parent, combineScript, reduceScript, aggParams, context, parent,
pipelineAggregators, metaData); pipelineAggregators, metaData);
} }
@ -128,5 +135,18 @@ public class ScriptedMetricAggregatorFactory extends AggregatorFactory<ScriptedM
return clone; return clone;
} }
private static Map<String, Object> mergeParams(Map<String, Object> agg, Map<String, Object> script) {
// Start with script params
Map<String, Object> combined = new HashMap<>(script);
// Add in agg params, throwing an exception if any conflicts are detected
for (Map.Entry<String, Object> aggEntry : agg.entrySet()) {
if (combined.putIfAbsent(aggEntry.getKey(), aggEntry.getValue()) != null) {
throw new IllegalArgumentException("Parameter name \"" + aggEntry.getKey() +
"\" used in both aggregation and script parameters");
}
}
return combined;
}
} }

View File

@ -20,6 +20,8 @@
package org.elasticsearch.search.aggregations.metrics; package org.elasticsearch.search.aggregations.metrics;
import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
@ -62,6 +64,7 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.scripted
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@ -322,11 +325,11 @@ public class ScriptedMetricIT extends ESIntegTestCase {
assertThat(numShardsRun, greaterThan(0)); assertThat(numShardsRun, greaterThan(0));
} }
public void testMapWithParams() { public void testExplicitAggParam() {
Map<String, Object> params = new HashMap<>(); Map<String, Object> params = new HashMap<>();
params.put("_agg", new ArrayList<>()); params.put("_agg", new ArrayList<>());
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg.add(1)", params); Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg.add(1)", Collections.emptyMap());
SearchResponse response = client().prepareSearch("idx") SearchResponse response = client().prepareSearch("idx")
.setQuery(matchAllQuery()) .setQuery(matchAllQuery())
@ -361,17 +364,17 @@ public class ScriptedMetricIT extends ESIntegTestCase {
} }
public void testMapWithParamsAndImplicitAggMap() { public void testMapWithParamsAndImplicitAggMap() {
Map<String, Object> params = new HashMap<>(); // Split the params up between the script and the aggregation.
// don't put any _agg map in params // Don't put any _agg map in params.
params.put("param1", "12"); Map<String, Object> scriptParams = Collections.singletonMap("param1", "12");
params.put("param2", 1); Map<String, Object> aggregationParams = Collections.singletonMap("param2", 1);
// The _agg hashmap will be available even if not declared in the params map // The _agg hashmap will be available even if not declared in the params map
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg[param1] = param2", params); Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg[param1] = param2", scriptParams);
SearchResponse response = client().prepareSearch("idx") SearchResponse response = client().prepareSearch("idx")
.setQuery(matchAllQuery()) .setQuery(matchAllQuery())
.addAggregation(scriptedMetric("scripted").params(params).mapScript(mapScript)) .addAggregation(scriptedMetric("scripted").params(aggregationParams).mapScript(mapScript))
.get(); .get();
assertSearchResponse(response); assertSearchResponse(response);
assertThat(response.getHits().getTotalHits(), equalTo(numDocs)); assertThat(response.getHits().getTotalHits(), equalTo(numDocs));
@ -1001,4 +1004,16 @@ public class ScriptedMetricIT extends ESIntegTestCase {
assertThat(client().admin().indices().prepareStats("cache_test_idx").setRequestCache(true).get().getTotal().getRequestCache() assertThat(client().admin().indices().prepareStats("cache_test_idx").setRequestCache(true).get().getTotal().getRequestCache()
.getMissCount(), equalTo(0L)); .getMissCount(), equalTo(0L));
} }
public void testConflictingAggAndScriptParams() {
Map<String, Object> params = Collections.singletonMap("param1", "12");
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg.add(1)", params);
SearchRequestBuilder builder = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(scriptedMetric("scripted").params(params).mapScript(mapScript));
SearchPhaseExecutionException ex = expectThrows(SearchPhaseExecutionException.class, builder::get);
assertThat(ex.getCause().getMessage(), containsString("Parameter name \"param1\" used in both aggregation and script parameters"));
}
} }

View File

@ -64,8 +64,16 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
Collections.emptyMap()); Collections.emptyMap());
private static final Script COMBINE_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptScore", private static final Script COMBINE_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptScore",
Collections.emptyMap()); Collections.emptyMap());
private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();
private static final Script INIT_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "initScriptParams",
Collections.singletonMap("initialValue", 24));
private static final Script MAP_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScriptParams",
Collections.singletonMap("itemValue", 12));
private static final Script COMBINE_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptParams",
Collections.singletonMap("divisor", 4));
private static final String CONFLICTING_PARAM_NAME = "initialValue";
private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();
@BeforeClass @BeforeClass
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -99,6 +107,26 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg"); Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
return ((List<Double>) agg.get("collector")).stream().mapToDouble(Double::doubleValue).sum(); return ((List<Double>) agg.get("collector")).stream().mapToDouble(Double::doubleValue).sum();
}); });
SCRIPTS.put("initScriptParams", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
Integer initialValue = (Integer)params.get("initialValue");
ArrayList<Integer> collector = new ArrayList();
collector.add(initialValue);
agg.put("collector", collector);
return agg;
});
SCRIPTS.put("mapScriptParams", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
Integer itemValue = (Integer) params.get("itemValue");
((List<Integer>) agg.get("collector")).add(itemValue);
return agg;
});
SCRIPTS.put("combineScriptParams", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
int divisor = ((Integer) params.get("divisor"));
return ((List<Integer>) agg.get("collector")).stream().mapToInt(Integer::intValue).map(i -> i / divisor).sum();
});
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -187,6 +215,48 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
} }
} }
public void testScriptParamsPassedThrough() throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
for (int i = 0; i < 100; i++) {
indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i)));
}
}
try (IndexReader indexReader = DirectoryReader.open(directory)) {
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
aggregationBuilder.initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS).combineScript(COMBINE_SCRIPT_PARAMS);
ScriptedMetric scriptedMetric = search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder);
// The result value depends on the script params.
assertEquals(306, scriptedMetric.aggregation());
}
}
}
public void testConflictingAggAndScriptParams() throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
for (int i = 0; i < 100; i++) {
indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i)));
}
}
try (IndexReader indexReader = DirectoryReader.open(directory)) {
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
Map<String, Object> aggParams = Collections.singletonMap(CONFLICTING_PARAM_NAME, "blah");
aggregationBuilder.params(aggParams).initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS).
combineScript(COMBINE_SCRIPT_PARAMS);
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () ->
search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder)
);
assertEquals("Parameter name \"" + CONFLICTING_PARAM_NAME + "\" used in both aggregation and script parameters",
ex.getMessage());
}
}
}
/** /**
* We cannot use Mockito for mocking QueryShardContext in this case because * We cannot use Mockito for mocking QueryShardContext in this case because
* script-related methods (e.g. QueryShardContext#getLazyExecutableScript) * script-related methods (e.g. QueryShardContext#getLazyExecutableScript)