diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java index 2b4685e5070..2e799b2903b 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java @@ -449,7 +449,7 @@ public class ScriptedMetricIT extends ESIntegTestCase { assertThat(numShardsRun, greaterThan(0)); } - public void testInitMapWithParams() { + public void testInitMutatesParams() { Map varsMap = new HashMap<>(); varsMap.put("multiplier", 1); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregationBuilder.java index 4d7fe6a66f6..ff33a5f5585 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregationBuilder.java @@ -232,7 +232,7 @@ public class ScriptedMetricAggregationBuilder extends AbstractAggregationBuilder compiledInitScript = queryShardContext.compile(initScript, ScriptedMetricAggContexts.InitScript.CONTEXT); initScriptParams = initScript.getParams(); } else { - compiledInitScript = (p, a) -> null; + compiledInitScript = null; initScriptParams = Collections.emptyMap(); } @@ -241,12 +241,9 @@ public class ScriptedMetricAggregationBuilder extends AbstractAggregationBuilder Map mapScriptParams = mapScript.getParams(); - ScriptedMetricAggContexts.CombineScript.Factory compiledCombineScript; - Map combineScriptParams; - - compiledCombineScript = queryShardContext.compile(combineScript, + ScriptedMetricAggContexts.CombineScript.Factory compiledCombineScript = queryShardContext.compile(combineScript, ScriptedMetricAggContexts.CombineScript.CONTEXT); - combineScriptParams = combineScript.getParams(); + Map combineScriptParams = combineScript.getParams(); return new ScriptedMetricAggregatorFactory(name, compiledMapScript, mapScriptParams, compiledInitScript, initScriptParams, compiledCombineScript, combineScriptParams, reduceScript, diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java index 0fa9c8cfae7..c718024a9b2 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java @@ -22,6 +22,7 @@ package org.elasticsearch.search.aggregations.metrics; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreMode; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.util.CollectionUtils; @@ -53,7 +54,10 @@ class ScriptedMetricAggregator extends MetricsAggregator { private static final long BUCKET_COST_ESTIMATE = 1024 * 5; private final SearchLookup lookup; - private final Map initialState; + private final Map aggParams; + @Nullable + private final ScriptedMetricAggContexts.InitScript.Factory initScriptFactory; + private final Map initScriptParams; private final ScriptedMetricAggContexts.MapScript.Factory mapScriptFactory; private final Map mapScriptParams; private final ScriptedMetricAggContexts.CombineScript.Factory combineScriptFactory; @@ -64,7 +68,9 @@ class ScriptedMetricAggregator extends MetricsAggregator { ScriptedMetricAggregator( String name, SearchLookup lookup, - Map initialState, + Map aggParams, + @Nullable ScriptedMetricAggContexts.InitScript.Factory initScriptFactory, + Map initScriptParams, ScriptedMetricAggContexts.MapScript.Factory mapScriptFactory, Map mapScriptParams, ScriptedMetricAggContexts.CombineScript.Factory combineScriptFactory, @@ -76,7 +82,9 @@ class ScriptedMetricAggregator extends MetricsAggregator { ) throws IOException { super(name, context, parent, metadata); this.lookup = lookup; - this.initialState = initialState; + this.aggParams = aggParams; + this.initScriptFactory = initScriptFactory; + this.initScriptParams = initScriptParams; this.mapScriptFactory = mapScriptFactory; this.mapScriptParams = mapScriptParams; this.combineScriptFactory = combineScriptFactory; @@ -129,39 +137,22 @@ class ScriptedMetricAggregator extends MetricsAggregator { @Override public InternalAggregation buildAggregation(long owningBucketOrdinal) { - Object result = resultFor(aggStateFor(owningBucketOrdinal)); + Object result = aggStateForResult(owningBucketOrdinal).combine(); StreamOutput.checkWriteable(result); return new InternalScriptedMetric(name, result, reduceScript, metadata()); } - private Map aggStateFor(long owningBucketOrdinal) { + private State aggStateForResult(long owningBucketOrdinal) { if (owningBucketOrdinal >= states.size()) { - return newInitialState(); + return new State(); } State state = states.get(owningBucketOrdinal); if (state == null) { - return newInitialState(); + return new State(); } // The last script that touched the state at this point is the "map" script CollectionUtils.ensureNoSelfReferences(state.aggState, "Scripted metric aggs map script"); - return state.aggState; - } - - private Object resultFor(Map aggState) { - if (combineScriptFactory == null) { - return aggState; - } - Object result = combineScriptFactory.newInstance( - // Send a deep copy of the params because the script is allowed to mutate it - ScriptedMetricAggregatorFactory.deepCopyParams(combineScriptParams, context), - aggState - ).execute(); - CollectionUtils.ensureNoSelfReferences(result, "Scripted metric aggs combine script"); - return result; - } - - private Map newInitialState() { - return initialState == null ? new HashMap<>() : ScriptedMetricAggregatorFactory.deepCopyParams(initialState, context); + return state; } @Override @@ -170,23 +161,48 @@ class ScriptedMetricAggregator extends MetricsAggregator { } @Override - public void close() { + public void doClose() { Releasables.close(states); } private class State { private final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript; + private final Map mapScriptParamsForState; + private final Map combineScriptParamsForState; private final Map aggState; private MapScript leafMapScript; State() { - aggState = newInitialState(); + // Its possible for building the initial state to mutate the parameters as a side effect + Map aggParamsForState = ScriptedMetricAggregatorFactory.deepCopyParams(aggParams, context); + mapScriptParamsForState = ScriptedMetricAggregatorFactory.mergeParams(aggParamsForState, mapScriptParams); + combineScriptParamsForState = ScriptedMetricAggregatorFactory.mergeParams(aggParamsForState, combineScriptParams); + aggState = newInitialState(ScriptedMetricAggregatorFactory.mergeParams(aggParamsForState, initScriptParams)); mapScript = mapScriptFactory.newFactory( - // Send a deep copy of the params because the script is allowed to mutate it - ScriptedMetricAggregatorFactory.deepCopyParams(mapScriptParams, context), + ScriptedMetricAggregatorFactory.deepCopyParams(mapScriptParamsForState, context), aggState, lookup ); } + + private Map newInitialState(Map initScriptParamsForState) { + if (initScriptFactory == null) { + return new HashMap<>(); + } + Map initialState = new HashMap<>(); + initScriptFactory.newInstance(initScriptParamsForState, initialState).execute(); + CollectionUtils.ensureNoSelfReferences(initialState, "Scripted metric aggs init script"); + return initialState; + } + + private Object combine() { + if (combineScriptFactory == null) { + return aggState; + } + Object result = combineScriptFactory.newInstance(combineScriptParamsForState, aggState).execute(); + CollectionUtils.ensureNoSelfReferences(result, "Scripted metric aggs combine script"); + return result; + } + } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorFactory.java index 984d41c0e57..ce6f5f636bf 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorFactory.java @@ -19,7 +19,7 @@ package org.elasticsearch.search.aggregations.metrics; -import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.common.Nullable; import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptedMetricAggContexts; @@ -45,16 +45,26 @@ class ScriptedMetricAggregatorFactory extends AggregatorFactory { private final Script reduceScript; private final Map aggParams; private final SearchLookup lookup; + @Nullable private final ScriptedMetricAggContexts.InitScript.Factory initScript; private final Map initScriptParams; - ScriptedMetricAggregatorFactory(String name, - ScriptedMetricAggContexts.MapScript.Factory mapScript, Map mapScriptParams, - ScriptedMetricAggContexts.InitScript.Factory initScript, Map initScriptParams, - ScriptedMetricAggContexts.CombineScript.Factory combineScript, - Map combineScriptParams, Script reduceScript, Map aggParams, - SearchLookup lookup, QueryShardContext queryShardContext, AggregatorFactory parent, - AggregatorFactories.Builder subFactories, Map metadata) throws IOException { + ScriptedMetricAggregatorFactory( + String name, + ScriptedMetricAggContexts.MapScript.Factory mapScript, + Map mapScriptParams, + @Nullable ScriptedMetricAggContexts.InitScript.Factory initScript, + Map initScriptParams, + ScriptedMetricAggContexts.CombineScript.Factory combineScript, + Map combineScriptParams, + Script reduceScript, + Map aggParams, + SearchLookup lookup, + QueryShardContext queryShardContext, + AggregatorFactory parent, + AggregatorFactories.Builder subFactories, + Map metadata + ) throws IOException { super(name, queryShardContext, parent, subFactories, metadata); this.mapScript = mapScript; this.mapScriptParams = mapScriptParams; @@ -73,29 +83,19 @@ class ScriptedMetricAggregatorFactory extends AggregatorFactory { boolean collectsFromSingleBucket, Map metadata) throws IOException { Map aggParams = this.aggParams == null ? org.elasticsearch.common.collect.Map.of() : this.aggParams; - Map initialState = new HashMap(); - ScriptedMetricAggContexts.InitScript initScript = this.initScript.newInstance( - mergeParams(aggParams, initScriptParams), - initialState - ); - if (initScript != null) { - initScript.execute(); - CollectionUtils.ensureNoSelfReferences(initialState, "Scripted metric aggs init script"); - } - - Map mapParams = mergeParams(aggParams, mapScriptParams); - Map combineParams = mergeParams(aggParams, combineScriptParams); Script reduceScript = deepCopyScript(this.reduceScript, searchContext, aggParams); return new ScriptedMetricAggregator( name, lookup, - initialState, + aggParams, + initScript, + initScriptParams, mapScript, - mapParams, + mapScriptParams, combineScript, - combineParams, + combineScriptParams, reduceScript, searchContext, parent, @@ -140,7 +140,7 @@ class ScriptedMetricAggregatorFactory extends AggregatorFactory { return clone; } - private static Map mergeParams(Map agg, Map script) { + static Map mergeParams(Map agg, Map script) { // Start with script params Map combined = new HashMap<>(script); diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java index 1d6655f8c3d..d03e64e1b12 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.search.aggregations.metrics; +import org.apache.lucene.document.Document; import org.apache.lucene.document.SortedNumericDocValuesField; import org.apache.lucene.document.SortedSetDocValuesField; import org.apache.lucene.index.DirectoryReader; @@ -26,12 +27,17 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; import org.apache.lucene.store.Directory; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.CheckedConsumer; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.indices.breaker.CircuitBreakerService; @@ -42,9 +48,13 @@ import org.elasticsearch.script.ScriptModule; import org.elasticsearch.script.ScriptService; import org.elasticsearch.script.ScriptType; import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.AggregatorTestCase; +import org.elasticsearch.search.aggregations.MultiBucketConsumerService.MultiBucketConsumer; import org.elasticsearch.search.aggregations.bucket.terms.StringTerms; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.elasticsearch.search.internal.SearchContext; +import org.junit.Before; import org.junit.BeforeClass; import java.io.IOException; @@ -58,6 +68,8 @@ import java.util.function.Function; import static java.util.Collections.singleton; import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class ScriptedMetricAggregatorTests extends AggregatorTestCase { @@ -95,6 +107,13 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { private static final Script COMBINE_SCRIPT_SELF_REF = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptSelfRef", Collections.emptyMap()); + private static final Script INIT_SCRIPT_MAKING_ARRAY = new Script( + ScriptType.INLINE, + MockScriptEngine.NAME, + "initScriptMakingArray", + Collections.emptyMap() + ); + private static final Map, Object>> SCRIPTS = new HashMap<>(); @BeforeClass @@ -181,6 +200,46 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { state.put("selfRef", state); return state; }); + SCRIPTS.put("initScriptMakingArray", params -> { + Map state = (Map) params.get("state"); + state.put("array", new String[] {"foo", "bar"}); + state.put("collector", new ArrayList()); + return state; + }); + } + + private CircuitBreakerService circuitBreakerService; + + @Before + public void mockBreaker() { + circuitBreakerService = mock(CircuitBreakerService.class); + when(circuitBreakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(new NoopCircuitBreaker(CircuitBreaker.REQUEST) { + private long total = 0; + + @Override + public double addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException { + logger.debug("Used {} grabbing {} for {}", total, bytes, label); + total += bytes; + return total; + } + + @Override + public long addWithoutBreaking(long bytes) { + logger.debug("Used {} grabbing {}", total, bytes); + total += bytes; + return total; + } + + @Override + public long getUsed() { + return total; + } + }); + } + + @Override + protected void afterClose() { + assertThat(circuitBreakerService.getBreaker(CircuitBreaker.REQUEST).getUsed(), equalTo(0L)); } @Override @@ -420,8 +479,19 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { } } + public void testInitScriptMakesArray() throws IOException { + ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME); + aggregationBuilder.initScript(INIT_SCRIPT_MAKING_ARRAY).mapScript(MAP_SCRIPT) + .combineScript(COMBINE_SCRIPT).reduceScript(REDUCE_SCRIPT); + testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> { + iw.addDocument(new Document()); + }, (InternalScriptedMetric r) -> { + assertEquals(1, r.aggregation()); + }); + } + public void testAsSubAgg() throws IOException { - AggregationBuilder aggregationBuilder = new TermsAggregationBuilder("t").field("t") + AggregationBuilder aggregationBuilder = new TermsAggregationBuilder("t").field("t").executionHint("map") .subAggregation( new ScriptedMetricAggregationBuilder("scripted").initScript(INIT_SCRIPT) .mapScript(MAP_SCRIPT) @@ -446,6 +516,25 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { testCase(aggregationBuilder, new MatchAllDocsQuery(), buildIndex, verify, keywordField("t"), longField("number")); } + protected A createAggregator( + Query query, + AggregationBuilder aggregationBuilder, + IndexSearcher indexSearcher, + IndexSettings indexSettings, + MultiBucketConsumer bucketConsumer, + MappedFieldType... fieldTypes + ) throws IOException { + SearchContext searchContext = createSearchContext( + indexSearcher, + indexSettings, + query, + bucketConsumer, + circuitBreakerService, + fieldTypes + ); + return createAggregator(aggregationBuilder, searchContext); + } + /** * We cannot use Mockito for mocking QueryShardContext in this case because * script-related methods (e.g. QueryShardContext#getLazyExecutableScript) diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java index 2cf9978bc08..f122e1d37a8 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java @@ -237,9 +237,13 @@ public abstract class AggregatorTestCase extends ESTestCase { MultiBucketConsumer bucketConsumer, MappedFieldType... fieldTypes) throws IOException { SearchContext searchContext = createSearchContext(indexSearcher, indexSettings, query, bucketConsumer, fieldTypes); + return createAggregator(aggregationBuilder, searchContext); + } + + protected A createAggregator(AggregationBuilder aggregationBuilder, SearchContext searchContext) + throws IOException { @SuppressWarnings("unchecked") - A aggregator = (A) aggregationBuilder - .rewrite(searchContext.getQueryShardContext()) + A aggregator = (A) aggregationBuilder.rewrite(searchContext.getQueryShardContext()) .build(searchContext.getQueryShardContext(), null) .create(searchContext, null, true); return aggregator; @@ -876,6 +880,11 @@ public abstract class AggregatorTestCase extends ESTestCase { releasables.clear(); } + /** + * Hook for checking things after all {@link Aggregator}s have been closed. + */ + protected void afterClose() {} + /** * Make a {@linkplain DateFieldMapper.DateFieldType} for a {@code date}. */