From 3b1dfa3b5dbfce1c1cd311abc3977f218598aec4 Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Fri, 5 Jun 2020 16:14:28 -0400 Subject: [PATCH] Remove deprecated wrapped from scripted_metric (backport of #57627) (#57763) This removes the deprecated `asMultiBucketAggregator` wrapper from `scripted_metric`. Unlike most other such removals, this isn't likely to save much memory. But it does make the internals of the aggregator slightly less twisted. Relates to #56487 --- .../metrics/ScriptedMetricAggregator.java | 153 ++++++++++++++---- .../ScriptedMetricAggregatorFactory.java | 50 +++--- .../ScriptedMetricAggregatorTests.java | 68 ++++++-- 3 files changed, 200 insertions(+), 71 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java index d88a26a7567..0fa9c8cfae7 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java @@ -23,38 +23,66 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreMode; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptedMetricAggContexts; +import org.elasticsearch.script.ScriptedMetricAggContexts.MapScript; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.LeafBucketCollector; import org.elasticsearch.search.aggregations.LeafBucketCollectorBase; import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.search.lookup.SearchLookup; import java.io.IOException; +import java.util.HashMap; import java.util.Map; class ScriptedMetricAggregator extends MetricsAggregator { + /** + * Estimated cost to maintain a bucket. Since this aggregator uses + * untracked java collections for its state it is going to both be + * much "heavier" than a normal metric aggregator and not going to be + * tracked by the circuit breakers properly. This is sad. So we pick a big + * number and estimate that each bucket costs that. It could be wildly + * inaccurate. We're sort of hoping that the real memory breaker saves + * us here. Or that folks just don't use the aggregation. + */ + private static final long BUCKET_COST_ESTIMATE = 1024 * 5; - private final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript; - private final ScriptedMetricAggContexts.CombineScript combineScript; + private final SearchLookup lookup; + private final Map initialState; + private final ScriptedMetricAggContexts.MapScript.Factory mapScriptFactory; + private final Map mapScriptParams; + private final ScriptedMetricAggContexts.CombineScript.Factory combineScriptFactory; + private final Map combineScriptParams; private final Script reduceScript; - private Map aggState; + private ObjectArray states; - ScriptedMetricAggregator(String name, - ScriptedMetricAggContexts.MapScript.LeafFactory mapScript, - ScriptedMetricAggContexts.CombineScript combineScript, - Script reduceScript, - Map aggState, - SearchContext context, - Aggregator parent, - Map metadata) throws IOException { + ScriptedMetricAggregator( + String name, + SearchLookup lookup, + Map initialState, + ScriptedMetricAggContexts.MapScript.Factory mapScriptFactory, + Map mapScriptParams, + ScriptedMetricAggContexts.CombineScript.Factory combineScriptFactory, + Map combineScriptParams, + Script reduceScript, + SearchContext context, + Aggregator parent, + Map metadata + ) throws IOException { super(name, context, parent, metadata); - this.aggState = aggState; - this.mapScript = mapScript; - this.combineScript = combineScript; + this.lookup = lookup; + this.initialState = initialState; + this.mapScriptFactory = mapScriptFactory; + this.mapScriptParams = mapScriptParams; + this.combineScriptFactory = combineScriptFactory; + this.combineScriptParams = combineScriptParams; this.reduceScript = reduceScript; + states = context.bigArrays().newObjectArray(1); } @Override @@ -63,36 +91,77 @@ class ScriptedMetricAggregator extends MetricsAggregator { } @Override - public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, - final LeafBucketCollector sub) throws IOException { - final ScriptedMetricAggContexts.MapScript leafMapScript = mapScript.newInstance(ctx); - return new LeafBucketCollectorBase(sub, leafMapScript) { + public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { + // Clear any old leaf scripts so we rebuild them on the new leaf when we first see them. + for (long i = 0; i < states.size(); i++) { + State state = states.get(i); + if (state == null) { + continue; + } + state.leafMapScript = null; + } + return new LeafBucketCollectorBase(sub, null) { + private Scorable scorer; + @Override public void setScorer(Scorable scorer) throws IOException { - leafMapScript.setScorer(scorer); + this.scorer = scorer; } @Override - public void collect(int doc, long bucket) throws IOException { - assert bucket == 0 : bucket; - - leafMapScript.setDocument(doc); - leafMapScript.execute(); + public void collect(int doc, long owningBucketOrd) throws IOException { + states = context.bigArrays().grow(states, owningBucketOrd + 1); + State state = states.get(owningBucketOrd); + if (state == null) { + addRequestCircuitBreakerBytes(BUCKET_COST_ESTIMATE); + state = new State(); + states.set(owningBucketOrd, state); + } + if (state.leafMapScript == null) { + state.leafMapScript = state.mapScript.newInstance(ctx); + state.leafMapScript.setScorer(scorer); + } + state.leafMapScript.setDocument(doc); + state.leafMapScript.execute(); } }; } @Override public InternalAggregation buildAggregation(long owningBucketOrdinal) { - Object aggregation; - if (combineScript != null) { - aggregation = combineScript.execute(); - CollectionUtils.ensureNoSelfReferences(aggregation, "Scripted metric aggs combine script"); - } else { - aggregation = aggState; + Object result = resultFor(aggStateFor(owningBucketOrdinal)); + StreamOutput.checkWriteable(result); + return new InternalScriptedMetric(name, result, reduceScript, metadata()); + } + + private Map aggStateFor(long owningBucketOrdinal) { + if (owningBucketOrdinal >= states.size()) { + return newInitialState(); } - StreamOutput.checkWriteable(aggregation); - return new InternalScriptedMetric(name, aggregation, reduceScript, metadata()); + State state = states.get(owningBucketOrdinal); + if (state == null) { + return newInitialState(); + } + // The last script that touched the state at this point is the "map" script + CollectionUtils.ensureNoSelfReferences(state.aggState, "Scripted metric aggs map script"); + return state.aggState; + } + + private Object resultFor(Map aggState) { + if (combineScriptFactory == null) { + return aggState; + } + Object result = combineScriptFactory.newInstance( + // Send a deep copy of the params because the script is allowed to mutate it + ScriptedMetricAggregatorFactory.deepCopyParams(combineScriptParams, context), + aggState + ).execute(); + CollectionUtils.ensureNoSelfReferences(result, "Scripted metric aggs combine script"); + return result; + } + + private Map newInitialState() { + return initialState == null ? new HashMap<>() : ScriptedMetricAggregatorFactory.deepCopyParams(initialState, context); } @Override @@ -101,9 +170,23 @@ class ScriptedMetricAggregator extends MetricsAggregator { } @Override - protected void doPostCollection() throws IOException { - CollectionUtils.ensureNoSelfReferences(aggState, "Scripted metric aggs map script"); + public void close() { + Releasables.close(states); + } - super.doPostCollection(); + private class State { + private final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript; + private final Map aggState; + private MapScript leafMapScript; + + State() { + aggState = newInitialState(); + mapScript = mapScriptFactory.newFactory( + // Send a deep copy of the params because the script is allowed to mutate it + ScriptedMetricAggregatorFactory.deepCopyParams(mapScriptParams, context), + aggState, + lookup + ); + } } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorFactory.java index 162182b8c4d..984d41c0e57 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorFactory.java @@ -72,32 +72,35 @@ class ScriptedMetricAggregatorFactory extends AggregatorFactory { Aggregator parent, boolean collectsFromSingleBucket, Map metadata) throws IOException { - if (collectsFromSingleBucket == false) { - return asMultiBucketAggregator(this, searchContext, parent); - } - Map aggParams = this.aggParams; - if (aggParams != null) { - aggParams = deepCopyParams(aggParams, searchContext); - } else { - aggParams = new HashMap<>(); - } + Map aggParams = this.aggParams == null ? org.elasticsearch.common.collect.Map.of() : this.aggParams; + Map initialState = new HashMap(); - Map aggState = new HashMap(); - - final ScriptedMetricAggContexts.InitScript initScript = this.initScript.newInstance( - mergeParams(aggParams, initScriptParams), aggState); - final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript = this.mapScript.newFactory( - mergeParams(aggParams, mapScriptParams), aggState, lookup); - final ScriptedMetricAggContexts.CombineScript combineScript = this.combineScript.newInstance( - mergeParams(aggParams, combineScriptParams), aggState); - - final Script reduceScript = deepCopyScript(this.reduceScript, searchContext, aggParams); + ScriptedMetricAggContexts.InitScript initScript = this.initScript.newInstance( + mergeParams(aggParams, initScriptParams), + initialState + ); if (initScript != null) { initScript.execute(); - CollectionUtils.ensureNoSelfReferences(aggState, "Scripted metric aggs init script"); + CollectionUtils.ensureNoSelfReferences(initialState, "Scripted metric aggs init script"); } - return new ScriptedMetricAggregator(name, mapScript, - combineScript, reduceScript, aggState, searchContext, parent, metadata); + + Map mapParams = mergeParams(aggParams, mapScriptParams); + Map combineParams = mergeParams(aggParams, combineScriptParams); + Script reduceScript = deepCopyScript(this.reduceScript, searchContext, aggParams); + + return new ScriptedMetricAggregator( + name, + lookup, + initialState, + mapScript, + mapParams, + combineScript, + combineParams, + reduceScript, + searchContext, + parent, + metadata + ); } private static Script deepCopyScript(Script script, SearchContext context, Map aggParams) { @@ -110,7 +113,7 @@ class ScriptedMetricAggregatorFactory extends AggregatorFactory { } @SuppressWarnings({ "unchecked" }) - private static T deepCopyParams(T original, SearchContext context) { + static T deepCopyParams(T original, SearchContext context) { T clone; if (original instanceof Map) { Map originalMap = (Map) original; @@ -152,3 +155,4 @@ class ScriptedMetricAggregatorFactory extends AggregatorFactory { return combined; } } + diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java index 7f79a267f79..1d6655f8c3d 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java @@ -20,12 +20,15 @@ package org.elasticsearch.search.aggregations.metrics; import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.document.SortedSetDocValuesField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.store.Directory; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.CheckedConsumer; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.index.IndexSettings; @@ -38,9 +41,10 @@ import org.elasticsearch.script.ScriptEngine; import org.elasticsearch.script.ScriptModule; import org.elasticsearch.script.ScriptService; import org.elasticsearch.script.ScriptType; +import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregatorTestCase; -import org.elasticsearch.search.aggregations.support.AggregationUsageService; -import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; +import org.elasticsearch.search.aggregations.bucket.terms.StringTerms; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.junit.BeforeClass; import java.io.IOException; @@ -49,11 +53,11 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Consumer; import java.util.function.Function; import static java.util.Collections.singleton; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.hamcrest.Matchers.equalTo; public class ScriptedMetricAggregatorTests extends AggregatorTestCase { @@ -115,8 +119,8 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { return state; }); SCRIPTS.put("reduceScript", params -> { - Map state = (Map) params.get("state"); - return state; + List states = (List) params.get("states"); + return states.stream().mapToInt(Integer::intValue).sum(); }); SCRIPTS.put("initScriptScore", params -> { @@ -416,6 +420,32 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { } } + public void testAsSubAgg() throws IOException { + AggregationBuilder aggregationBuilder = new TermsAggregationBuilder("t").field("t") + .subAggregation( + new ScriptedMetricAggregationBuilder("scripted").initScript(INIT_SCRIPT) + .mapScript(MAP_SCRIPT) + .combineScript(COMBINE_SCRIPT) + .reduceScript(REDUCE_SCRIPT) + ); + CheckedConsumer buildIndex = iw -> { + for (int i = 0; i < 99; i++) { + iw.addDocument(singleton(new SortedSetDocValuesField("t", i % 2 == 0 ? new BytesRef("even") : new BytesRef("odd")))); + } + }; + Consumer verify = terms -> { + StringTerms.Bucket even = terms.getBucketByKey("even"); + assertThat(even.getDocCount(), equalTo(50L)); + ScriptedMetric evenMetric = even.getAggregations().get("scripted"); + assertThat(evenMetric.aggregation(), equalTo(50)); + StringTerms.Bucket odd = terms.getBucketByKey("odd"); + assertThat(odd.getDocCount(), equalTo(49L)); + ScriptedMetric oddMetric = odd.getAggregations().get("scripted"); + assertThat(oddMetric.aggregation(), equalTo(49)); + }; + testCase(aggregationBuilder, new MatchAllDocsQuery(), buildIndex, verify, keywordField("t"), longField("number")); + } + /** * We cannot use Mockito for mocking QueryShardContext in this case because * script-related methods (e.g. QueryShardContext#getLazyExecutableScript) @@ -430,12 +460,24 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME, SCRIPTS, Collections.emptyMap()); Map engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine); ScriptService scriptService = new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS); - ValuesSourceRegistry valuesSourceRegistry = mock(ValuesSourceRegistry.class); - AggregationUsageService.Builder builder = new AggregationUsageService.Builder(); - builder.registerAggregationUsage(ScriptedMetricAggregationBuilder.NAME); - when(valuesSourceRegistry.getUsageService()).thenReturn(builder.build()); - return new QueryShardContext(0, indexSettings, BigArrays.NON_RECYCLING_INSTANCE, null, - null, mapperService, null, scriptService, xContentRegistry(), writableRegistry(), - null, null, System::currentTimeMillis, null, null, () -> true, valuesSourceRegistry); + return new QueryShardContext( + 0, + indexSettings, + BigArrays.NON_RECYCLING_INSTANCE, + null, + getIndexFieldDataLookup(mapperService, circuitBreakerService), + mapperService, + null, + scriptService, + xContentRegistry(), + writableRegistry(), + null, + null, + System::currentTimeMillis, + null, + null, + () -> true, + valuesSourceRegistry + ); } }