Remove deprecated wrapper from scripted_metric (backport of #57627) (#57763)

This removes the deprecated `asMultiBucketAggregator` wrapper from
`scripted_metric`. Unlike most other such removals, this isn't likely to
save much memory. But it does make the internals of the aggregator
slightly less twisted.

Relates to #56487
This commit is contained in:
Nik Everett 2020-06-05 16:14:28 -04:00 committed by GitHub
parent 9666a895f7
commit 3b1dfa3b5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 200 additions and 71 deletions

View File

@ -23,38 +23,66 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Scorable; import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.common.util.ObjectArray;
import org.elasticsearch.script.Script; import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptedMetricAggContexts; import org.elasticsearch.script.ScriptedMetricAggContexts;
import org.elasticsearch.script.ScriptedMetricAggContexts.MapScript;
import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector; import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.LeafBucketCollectorBase; import org.elasticsearch.search.aggregations.LeafBucketCollectorBase;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.lookup.SearchLookup;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
class ScriptedMetricAggregator extends MetricsAggregator { class ScriptedMetricAggregator extends MetricsAggregator {
/**
* Estimated cost to maintain a bucket. Since this aggregator uses
* untracked java collections for its state it is going to both be
* much "heavier" than a normal metric aggregator and not going to be
* tracked by the circuit breakers properly. This is sad. So we pick a big
* number and estimate that each bucket costs that. It could be wildly
* inaccurate. We're sort of hoping that the real memory breaker saves
* us here. Or that folks just don't use the aggregation.
*/
private static final long BUCKET_COST_ESTIMATE = 1024 * 5;
private final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript; private final SearchLookup lookup;
private final ScriptedMetricAggContexts.CombineScript combineScript; private final Map<String, Object> initialState;
private final ScriptedMetricAggContexts.MapScript.Factory mapScriptFactory;
private final Map<String, Object> mapScriptParams;
private final ScriptedMetricAggContexts.CombineScript.Factory combineScriptFactory;
private final Map<String, Object> combineScriptParams;
private final Script reduceScript; private final Script reduceScript;
private Map<String, Object> aggState; private ObjectArray<State> states;
ScriptedMetricAggregator(String name, ScriptedMetricAggregator(
ScriptedMetricAggContexts.MapScript.LeafFactory mapScript, String name,
ScriptedMetricAggContexts.CombineScript combineScript, SearchLookup lookup,
Map<String, Object> initialState,
ScriptedMetricAggContexts.MapScript.Factory mapScriptFactory,
Map<String, Object> mapScriptParams,
ScriptedMetricAggContexts.CombineScript.Factory combineScriptFactory,
Map<String, Object> combineScriptParams,
Script reduceScript, Script reduceScript,
Map<String, Object> aggState,
SearchContext context, SearchContext context,
Aggregator parent, Aggregator parent,
Map<String, Object> metadata) throws IOException { Map<String, Object> metadata
) throws IOException {
super(name, context, parent, metadata); super(name, context, parent, metadata);
this.aggState = aggState; this.lookup = lookup;
this.mapScript = mapScript; this.initialState = initialState;
this.combineScript = combineScript; this.mapScriptFactory = mapScriptFactory;
this.mapScriptParams = mapScriptParams;
this.combineScriptFactory = combineScriptFactory;
this.combineScriptParams = combineScriptParams;
this.reduceScript = reduceScript; this.reduceScript = reduceScript;
states = context.bigArrays().newObjectArray(1);
} }
@Override @Override
@ -63,36 +91,77 @@ class ScriptedMetricAggregator extends MetricsAggregator {
} }
@Override @Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
final LeafBucketCollector sub) throws IOException { // Clear any old leaf scripts so we rebuild them on the new leaf when we first see them.
final ScriptedMetricAggContexts.MapScript leafMapScript = mapScript.newInstance(ctx); for (long i = 0; i < states.size(); i++) {
return new LeafBucketCollectorBase(sub, leafMapScript) { State state = states.get(i);
if (state == null) {
continue;
}
state.leafMapScript = null;
}
return new LeafBucketCollectorBase(sub, null) {
private Scorable scorer;
@Override @Override
public void setScorer(Scorable scorer) throws IOException { public void setScorer(Scorable scorer) throws IOException {
leafMapScript.setScorer(scorer); this.scorer = scorer;
} }
@Override @Override
public void collect(int doc, long bucket) throws IOException { public void collect(int doc, long owningBucketOrd) throws IOException {
assert bucket == 0 : bucket; states = context.bigArrays().grow(states, owningBucketOrd + 1);
State state = states.get(owningBucketOrd);
leafMapScript.setDocument(doc); if (state == null) {
leafMapScript.execute(); addRequestCircuitBreakerBytes(BUCKET_COST_ESTIMATE);
state = new State();
states.set(owningBucketOrd, state);
}
if (state.leafMapScript == null) {
state.leafMapScript = state.mapScript.newInstance(ctx);
state.leafMapScript.setScorer(scorer);
}
state.leafMapScript.setDocument(doc);
state.leafMapScript.execute();
} }
}; };
} }
@Override @Override
public InternalAggregation buildAggregation(long owningBucketOrdinal) { public InternalAggregation buildAggregation(long owningBucketOrdinal) {
Object aggregation; Object result = resultFor(aggStateFor(owningBucketOrdinal));
if (combineScript != null) { StreamOutput.checkWriteable(result);
aggregation = combineScript.execute(); return new InternalScriptedMetric(name, result, reduceScript, metadata());
CollectionUtils.ensureNoSelfReferences(aggregation, "Scripted metric aggs combine script");
} else {
aggregation = aggState;
} }
StreamOutput.checkWriteable(aggregation);
return new InternalScriptedMetric(name, aggregation, reduceScript, metadata()); private Map<String, Object> aggStateFor(long owningBucketOrdinal) {
if (owningBucketOrdinal >= states.size()) {
return newInitialState();
}
State state = states.get(owningBucketOrdinal);
if (state == null) {
return newInitialState();
}
// The last script that touched the state at this point is the "map" script
CollectionUtils.ensureNoSelfReferences(state.aggState, "Scripted metric aggs map script");
return state.aggState;
}
private Object resultFor(Map<String, Object> aggState) {
if (combineScriptFactory == null) {
return aggState;
}
Object result = combineScriptFactory.newInstance(
// Send a deep copy of the params because the script is allowed to mutate it
ScriptedMetricAggregatorFactory.deepCopyParams(combineScriptParams, context),
aggState
).execute();
CollectionUtils.ensureNoSelfReferences(result, "Scripted metric aggs combine script");
return result;
}
private Map<String, Object> newInitialState() {
return initialState == null ? new HashMap<>() : ScriptedMetricAggregatorFactory.deepCopyParams(initialState, context);
} }
@Override @Override
@ -101,9 +170,23 @@ class ScriptedMetricAggregator extends MetricsAggregator {
} }
@Override @Override
protected void doPostCollection() throws IOException { public void close() {
CollectionUtils.ensureNoSelfReferences(aggState, "Scripted metric aggs map script"); Releasables.close(states);
}
super.doPostCollection(); private class State {
private final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript;
private final Map<String, Object> aggState;
private MapScript leafMapScript;
State() {
aggState = newInitialState();
mapScript = mapScriptFactory.newFactory(
// Send a deep copy of the params because the script is allowed to mutate it
ScriptedMetricAggregatorFactory.deepCopyParams(mapScriptParams, context),
aggState,
lookup
);
}
} }
} }

View File

@ -72,32 +72,35 @@ class ScriptedMetricAggregatorFactory extends AggregatorFactory {
Aggregator parent, Aggregator parent,
boolean collectsFromSingleBucket, boolean collectsFromSingleBucket,
Map<String, Object> metadata) throws IOException { Map<String, Object> metadata) throws IOException {
if (collectsFromSingleBucket == false) { Map<String, Object> aggParams = this.aggParams == null ? org.elasticsearch.common.collect.Map.of() : this.aggParams;
return asMultiBucketAggregator(this, searchContext, parent); Map<String, Object> initialState = new HashMap<String, Object>();
}
Map<String, Object> aggParams = this.aggParams;
if (aggParams != null) {
aggParams = deepCopyParams(aggParams, searchContext);
} else {
aggParams = new HashMap<>();
}
Map<String, Object> aggState = new HashMap<String, Object>(); ScriptedMetricAggContexts.InitScript initScript = this.initScript.newInstance(
mergeParams(aggParams, initScriptParams),
final ScriptedMetricAggContexts.InitScript initScript = this.initScript.newInstance( initialState
mergeParams(aggParams, initScriptParams), aggState); );
final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript = this.mapScript.newFactory(
mergeParams(aggParams, mapScriptParams), aggState, lookup);
final ScriptedMetricAggContexts.CombineScript combineScript = this.combineScript.newInstance(
mergeParams(aggParams, combineScriptParams), aggState);
final Script reduceScript = deepCopyScript(this.reduceScript, searchContext, aggParams);
if (initScript != null) { if (initScript != null) {
initScript.execute(); initScript.execute();
CollectionUtils.ensureNoSelfReferences(aggState, "Scripted metric aggs init script"); CollectionUtils.ensureNoSelfReferences(initialState, "Scripted metric aggs init script");
} }
return new ScriptedMetricAggregator(name, mapScript,
combineScript, reduceScript, aggState, searchContext, parent, metadata); Map<String, Object> mapParams = mergeParams(aggParams, mapScriptParams);
Map<String, Object> combineParams = mergeParams(aggParams, combineScriptParams);
Script reduceScript = deepCopyScript(this.reduceScript, searchContext, aggParams);
return new ScriptedMetricAggregator(
name,
lookup,
initialState,
mapScript,
mapParams,
combineScript,
combineParams,
reduceScript,
searchContext,
parent,
metadata
);
} }
private static Script deepCopyScript(Script script, SearchContext context, Map<String, Object> aggParams) { private static Script deepCopyScript(Script script, SearchContext context, Map<String, Object> aggParams) {
@ -110,7 +113,7 @@ class ScriptedMetricAggregatorFactory extends AggregatorFactory {
} }
@SuppressWarnings({ "unchecked" }) @SuppressWarnings({ "unchecked" })
private static <T> T deepCopyParams(T original, SearchContext context) { static <T> T deepCopyParams(T original, SearchContext context) {
T clone; T clone;
if (original instanceof Map) { if (original instanceof Map) {
Map<?, ?> originalMap = (Map<?, ?>) original; Map<?, ?> originalMap = (Map<?, ?>) original;
@ -152,3 +155,4 @@ class ScriptedMetricAggregatorFactory extends AggregatorFactory {
return combined; return combined;
} }
} }

View File

@ -20,12 +20,15 @@
package org.elasticsearch.search.aggregations.metrics; package org.elasticsearch.search.aggregations.metrics;
import org.apache.lucene.document.SortedNumericDocValuesField; import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.CheckedConsumer;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexSettings;
@ -38,9 +41,10 @@ import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.ScriptModule; import org.elasticsearch.script.ScriptModule;
import org.elasticsearch.script.ScriptService; import org.elasticsearch.script.ScriptService;
import org.elasticsearch.script.ScriptType; import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorTestCase; import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.support.AggregationUsageService; import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.junit.BeforeClass; import org.junit.BeforeClass;
import java.io.IOException; import java.io.IOException;
@ -49,11 +53,11 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
import static java.util.Collections.singleton; import static java.util.Collections.singleton;
import static org.mockito.Mockito.mock; import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.when;
public class ScriptedMetricAggregatorTests extends AggregatorTestCase { public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
@ -115,8 +119,8 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
return state; return state;
}); });
SCRIPTS.put("reduceScript", params -> { SCRIPTS.put("reduceScript", params -> {
Map<String, Object> state = (Map<String, Object>) params.get("state"); List<Integer> states = (List<Integer>) params.get("states");
return state; return states.stream().mapToInt(Integer::intValue).sum();
}); });
SCRIPTS.put("initScriptScore", params -> { SCRIPTS.put("initScriptScore", params -> {
@ -416,6 +420,32 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
} }
} }
public void testAsSubAgg() throws IOException {
AggregationBuilder aggregationBuilder = new TermsAggregationBuilder("t").field("t")
.subAggregation(
new ScriptedMetricAggregationBuilder("scripted").initScript(INIT_SCRIPT)
.mapScript(MAP_SCRIPT)
.combineScript(COMBINE_SCRIPT)
.reduceScript(REDUCE_SCRIPT)
);
CheckedConsumer<RandomIndexWriter, IOException> buildIndex = iw -> {
for (int i = 0; i < 99; i++) {
iw.addDocument(singleton(new SortedSetDocValuesField("t", i % 2 == 0 ? new BytesRef("even") : new BytesRef("odd"))));
}
};
Consumer<StringTerms> verify = terms -> {
StringTerms.Bucket even = terms.getBucketByKey("even");
assertThat(even.getDocCount(), equalTo(50L));
ScriptedMetric evenMetric = even.getAggregations().get("scripted");
assertThat(evenMetric.aggregation(), equalTo(50));
StringTerms.Bucket odd = terms.getBucketByKey("odd");
assertThat(odd.getDocCount(), equalTo(49L));
ScriptedMetric oddMetric = odd.getAggregations().get("scripted");
assertThat(oddMetric.aggregation(), equalTo(49));
};
testCase(aggregationBuilder, new MatchAllDocsQuery(), buildIndex, verify, keywordField("t"), longField("number"));
}
/** /**
* We cannot use Mockito for mocking QueryShardContext in this case because * We cannot use Mockito for mocking QueryShardContext in this case because
* script-related methods (e.g. QueryShardContext#getLazyExecutableScript) * script-related methods (e.g. QueryShardContext#getLazyExecutableScript)
@ -430,12 +460,24 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME, SCRIPTS, Collections.emptyMap()); MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME, SCRIPTS, Collections.emptyMap());
Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine); Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
ScriptService scriptService = new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS); ScriptService scriptService = new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
ValuesSourceRegistry valuesSourceRegistry = mock(ValuesSourceRegistry.class); return new QueryShardContext(
AggregationUsageService.Builder builder = new AggregationUsageService.Builder(); 0,
builder.registerAggregationUsage(ScriptedMetricAggregationBuilder.NAME); indexSettings,
when(valuesSourceRegistry.getUsageService()).thenReturn(builder.build()); BigArrays.NON_RECYCLING_INSTANCE,
return new QueryShardContext(0, indexSettings, BigArrays.NON_RECYCLING_INSTANCE, null, null,
null, mapperService, null, scriptService, xContentRegistry(), writableRegistry(), getIndexFieldDataLookup(mapperService, circuitBreakerService),
null, null, System::currentTimeMillis, null, null, () -> true, valuesSourceRegistry); mapperService,
null,
scriptService,
xContentRegistry(),
writableRegistry(),
null,
null,
System::currentTimeMillis,
null,
null,
() -> true,
valuesSourceRegistry
);
} }
} }