This removes the deprecated `asMultiBucketAggregator` wrapper from the `scripted_metric` aggregation. Unlike most other such removals, this is unlikely to save much memory, but it does make the internals of the aggregator slightly less twisted. Relates to #56487.
This commit is contained in:
parent
9666a895f7
commit
3b1dfa3b5d
|
@ -23,38 +23,66 @@ import org.apache.lucene.index.LeafReaderContext;
|
|||
import org.apache.lucene.search.Scorable;
|
||||
import org.apache.lucene.search.ScoreMode;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.lease.Releasables;
|
||||
import org.elasticsearch.common.util.CollectionUtils;
|
||||
import org.elasticsearch.common.util.ObjectArray;
|
||||
import org.elasticsearch.script.Script;
|
||||
import org.elasticsearch.script.ScriptedMetricAggContexts;
|
||||
import org.elasticsearch.script.ScriptedMetricAggContexts.MapScript;
|
||||
import org.elasticsearch.search.aggregations.Aggregator;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregation;
|
||||
import org.elasticsearch.search.aggregations.LeafBucketCollector;
|
||||
import org.elasticsearch.search.aggregations.LeafBucketCollectorBase;
|
||||
import org.elasticsearch.search.internal.SearchContext;
|
||||
import org.elasticsearch.search.lookup.SearchLookup;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
class ScriptedMetricAggregator extends MetricsAggregator {
|
||||
/**
|
||||
* Estimated cost to maintain a bucket. Since this aggregator uses
|
||||
* untracked java collections for its state it is going to both be
|
||||
* much "heavier" than a normal metric aggregator and not going to be
|
||||
* tracked by the circuit breakers properly. This is sad. So we pick a big
|
||||
* number and estimate that each bucket costs that. It could be wildly
|
||||
* inaccurate. We're sort of hoping that the real memory breaker saves
|
||||
* us here. Or that folks just don't use the aggregation.
|
||||
*/
|
||||
private static final long BUCKET_COST_ESTIMATE = 1024 * 5;
|
||||
|
||||
private final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript;
|
||||
private final ScriptedMetricAggContexts.CombineScript combineScript;
|
||||
private final SearchLookup lookup;
|
||||
private final Map<String, Object> initialState;
|
||||
private final ScriptedMetricAggContexts.MapScript.Factory mapScriptFactory;
|
||||
private final Map<String, Object> mapScriptParams;
|
||||
private final ScriptedMetricAggContexts.CombineScript.Factory combineScriptFactory;
|
||||
private final Map<String, Object> combineScriptParams;
|
||||
private final Script reduceScript;
|
||||
private Map<String, Object> aggState;
|
||||
private ObjectArray<State> states;
|
||||
|
||||
ScriptedMetricAggregator(String name,
|
||||
ScriptedMetricAggContexts.MapScript.LeafFactory mapScript,
|
||||
ScriptedMetricAggContexts.CombineScript combineScript,
|
||||
Script reduceScript,
|
||||
Map<String, Object> aggState,
|
||||
SearchContext context,
|
||||
Aggregator parent,
|
||||
Map<String, Object> metadata) throws IOException {
|
||||
ScriptedMetricAggregator(
|
||||
String name,
|
||||
SearchLookup lookup,
|
||||
Map<String, Object> initialState,
|
||||
ScriptedMetricAggContexts.MapScript.Factory mapScriptFactory,
|
||||
Map<String, Object> mapScriptParams,
|
||||
ScriptedMetricAggContexts.CombineScript.Factory combineScriptFactory,
|
||||
Map<String, Object> combineScriptParams,
|
||||
Script reduceScript,
|
||||
SearchContext context,
|
||||
Aggregator parent,
|
||||
Map<String, Object> metadata
|
||||
) throws IOException {
|
||||
super(name, context, parent, metadata);
|
||||
this.aggState = aggState;
|
||||
this.mapScript = mapScript;
|
||||
this.combineScript = combineScript;
|
||||
this.lookup = lookup;
|
||||
this.initialState = initialState;
|
||||
this.mapScriptFactory = mapScriptFactory;
|
||||
this.mapScriptParams = mapScriptParams;
|
||||
this.combineScriptFactory = combineScriptFactory;
|
||||
this.combineScriptParams = combineScriptParams;
|
||||
this.reduceScript = reduceScript;
|
||||
states = context.bigArrays().newObjectArray(1);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -63,36 +91,77 @@ class ScriptedMetricAggregator extends MetricsAggregator {
|
|||
}
|
||||
|
||||
@Override
|
||||
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
|
||||
final LeafBucketCollector sub) throws IOException {
|
||||
final ScriptedMetricAggContexts.MapScript leafMapScript = mapScript.newInstance(ctx);
|
||||
return new LeafBucketCollectorBase(sub, leafMapScript) {
|
||||
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
|
||||
// Clear any old leaf scripts so we rebuild them on the new leaf when we first see them.
|
||||
for (long i = 0; i < states.size(); i++) {
|
||||
State state = states.get(i);
|
||||
if (state == null) {
|
||||
continue;
|
||||
}
|
||||
state.leafMapScript = null;
|
||||
}
|
||||
return new LeafBucketCollectorBase(sub, null) {
|
||||
private Scorable scorer;
|
||||
|
||||
@Override
|
||||
public void setScorer(Scorable scorer) throws IOException {
|
||||
leafMapScript.setScorer(scorer);
|
||||
this.scorer = scorer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void collect(int doc, long bucket) throws IOException {
|
||||
assert bucket == 0 : bucket;
|
||||
|
||||
leafMapScript.setDocument(doc);
|
||||
leafMapScript.execute();
|
||||
public void collect(int doc, long owningBucketOrd) throws IOException {
|
||||
states = context.bigArrays().grow(states, owningBucketOrd + 1);
|
||||
State state = states.get(owningBucketOrd);
|
||||
if (state == null) {
|
||||
addRequestCircuitBreakerBytes(BUCKET_COST_ESTIMATE);
|
||||
state = new State();
|
||||
states.set(owningBucketOrd, state);
|
||||
}
|
||||
if (state.leafMapScript == null) {
|
||||
state.leafMapScript = state.mapScript.newInstance(ctx);
|
||||
state.leafMapScript.setScorer(scorer);
|
||||
}
|
||||
state.leafMapScript.setDocument(doc);
|
||||
state.leafMapScript.execute();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public InternalAggregation buildAggregation(long owningBucketOrdinal) {
|
||||
Object aggregation;
|
||||
if (combineScript != null) {
|
||||
aggregation = combineScript.execute();
|
||||
CollectionUtils.ensureNoSelfReferences(aggregation, "Scripted metric aggs combine script");
|
||||
} else {
|
||||
aggregation = aggState;
|
||||
Object result = resultFor(aggStateFor(owningBucketOrdinal));
|
||||
StreamOutput.checkWriteable(result);
|
||||
return new InternalScriptedMetric(name, result, reduceScript, metadata());
|
||||
}
|
||||
|
||||
private Map<String, Object> aggStateFor(long owningBucketOrdinal) {
|
||||
if (owningBucketOrdinal >= states.size()) {
|
||||
return newInitialState();
|
||||
}
|
||||
StreamOutput.checkWriteable(aggregation);
|
||||
return new InternalScriptedMetric(name, aggregation, reduceScript, metadata());
|
||||
State state = states.get(owningBucketOrdinal);
|
||||
if (state == null) {
|
||||
return newInitialState();
|
||||
}
|
||||
// The last script that touched the state at this point is the "map" script
|
||||
CollectionUtils.ensureNoSelfReferences(state.aggState, "Scripted metric aggs map script");
|
||||
return state.aggState;
|
||||
}
|
||||
|
||||
/**
 * Turns the state collected for one bucket into the value that is reported
 * for that bucket.
 *
 * @param aggState the state map for the bucket
 * @return {@code aggState} itself when no combine script factory is
 *         configured, otherwise whatever the combine script returns
 */
private Object resultFor(Map<String, Object> aggState) {
|
||||
// No combine script: the raw state is the result.
if (combineScriptFactory == null) {
|
||||
return aggState;
|
||||
}
|
||||
// Build a fresh combine script instance bound to this bucket's state and run it.
Object result = combineScriptFactory.newInstance(
|
||||
// Send a deep copy of the params because the script is allowed to mutate it
|
||||
ScriptedMetricAggregatorFactory.deepCopyParams(combineScriptParams, context),
|
||||
aggState
|
||||
).execute();
|
||||
// Validate the script's output for self references before handing it back.
CollectionUtils.ensureNoSelfReferences(result, "Scripted metric aggs combine script");
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
 * Creates the starting state map for a bucket.
 *
 * @return a new empty {@link HashMap} when {@code initialState} is
 *         {@code null}, otherwise a deep copy of {@code initialState}
 */
private Map<String, Object> newInitialState() {
|
||||
// Deep copy so the shared initialState is not handed out directly.
return initialState == null ? new HashMap<>() : ScriptedMetricAggregatorFactory.deepCopyParams(initialState, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -101,9 +170,23 @@ class ScriptedMetricAggregator extends MetricsAggregator {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected void doPostCollection() throws IOException {
|
||||
CollectionUtils.ensureNoSelfReferences(aggState, "Scripted metric aggs map script");
|
||||
public void close() {
|
||||
Releasables.close(states);
|
||||
}
|
||||
|
||||
super.doPostCollection();
|
||||
/**
 * Per-bucket bookkeeping: the bucket's state map, the map script factory
 * bound to that state, and the map script instance currently in use.
 * Declared as a non-static inner class because the constructor reads the
 * enclosing aggregator's {@code mapScriptFactory}, {@code mapScriptParams},
 * {@code context} and {@code lookup}, and calls {@code newInitialState()}.
 */
private class State {
|
||||
// Leaf factory created from mapScriptFactory for this bucket's aggState.
private final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript;
|
||||
// The state map passed to this bucket's map script.
private final Map<String, Object> aggState;
|
||||
// Not final and not assigned in the constructor, so it starts out null.
private MapScript leafMapScript;
|
||||
|
||||
/**
 * Creates the state map via {@code newInitialState()} and builds a map
 * script leaf factory over it.
 */
State() {
|
||||
aggState = newInitialState();
|
||||
mapScript = mapScriptFactory.newFactory(
|
||||
// Send a deep copy of the params because the script is allowed to mutate it
|
||||
ScriptedMetricAggregatorFactory.deepCopyParams(mapScriptParams, context),
|
||||
aggState,
|
||||
lookup
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -72,32 +72,35 @@ class ScriptedMetricAggregatorFactory extends AggregatorFactory {
|
|||
Aggregator parent,
|
||||
boolean collectsFromSingleBucket,
|
||||
Map<String, Object> metadata) throws IOException {
|
||||
if (collectsFromSingleBucket == false) {
|
||||
return asMultiBucketAggregator(this, searchContext, parent);
|
||||
}
|
||||
Map<String, Object> aggParams = this.aggParams;
|
||||
if (aggParams != null) {
|
||||
aggParams = deepCopyParams(aggParams, searchContext);
|
||||
} else {
|
||||
aggParams = new HashMap<>();
|
||||
}
|
||||
Map<String, Object> aggParams = this.aggParams == null ? org.elasticsearch.common.collect.Map.of() : this.aggParams;
|
||||
Map<String, Object> initialState = new HashMap<String, Object>();
|
||||
|
||||
Map<String, Object> aggState = new HashMap<String, Object>();
|
||||
|
||||
final ScriptedMetricAggContexts.InitScript initScript = this.initScript.newInstance(
|
||||
mergeParams(aggParams, initScriptParams), aggState);
|
||||
final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript = this.mapScript.newFactory(
|
||||
mergeParams(aggParams, mapScriptParams), aggState, lookup);
|
||||
final ScriptedMetricAggContexts.CombineScript combineScript = this.combineScript.newInstance(
|
||||
mergeParams(aggParams, combineScriptParams), aggState);
|
||||
|
||||
final Script reduceScript = deepCopyScript(this.reduceScript, searchContext, aggParams);
|
||||
ScriptedMetricAggContexts.InitScript initScript = this.initScript.newInstance(
|
||||
mergeParams(aggParams, initScriptParams),
|
||||
initialState
|
||||
);
|
||||
if (initScript != null) {
|
||||
initScript.execute();
|
||||
CollectionUtils.ensureNoSelfReferences(aggState, "Scripted metric aggs init script");
|
||||
CollectionUtils.ensureNoSelfReferences(initialState, "Scripted metric aggs init script");
|
||||
}
|
||||
return new ScriptedMetricAggregator(name, mapScript,
|
||||
combineScript, reduceScript, aggState, searchContext, parent, metadata);
|
||||
|
||||
Map<String, Object> mapParams = mergeParams(aggParams, mapScriptParams);
|
||||
Map<String, Object> combineParams = mergeParams(aggParams, combineScriptParams);
|
||||
Script reduceScript = deepCopyScript(this.reduceScript, searchContext, aggParams);
|
||||
|
||||
return new ScriptedMetricAggregator(
|
||||
name,
|
||||
lookup,
|
||||
initialState,
|
||||
mapScript,
|
||||
mapParams,
|
||||
combineScript,
|
||||
combineParams,
|
||||
reduceScript,
|
||||
searchContext,
|
||||
parent,
|
||||
metadata
|
||||
);
|
||||
}
|
||||
|
||||
private static Script deepCopyScript(Script script, SearchContext context, Map<String, Object> aggParams) {
|
||||
|
@ -110,7 +113,7 @@ class ScriptedMetricAggregatorFactory extends AggregatorFactory {
|
|||
}
|
||||
|
||||
@SuppressWarnings({ "unchecked" })
|
||||
private static <T> T deepCopyParams(T original, SearchContext context) {
|
||||
static <T> T deepCopyParams(T original, SearchContext context) {
|
||||
T clone;
|
||||
if (original instanceof Map) {
|
||||
Map<?, ?> originalMap = (Map<?, ?>) original;
|
||||
|
@ -152,3 +155,4 @@ class ScriptedMetricAggregatorFactory extends AggregatorFactory {
|
|||
return combined;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,12 +20,15 @@
|
|||
package org.elasticsearch.search.aggregations.metrics;
|
||||
|
||||
import org.apache.lucene.document.SortedNumericDocValuesField;
|
||||
import org.apache.lucene.document.SortedSetDocValuesField;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.RandomIndexWriter;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.elasticsearch.common.CheckedConsumer;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.BigArrays;
|
||||
import org.elasticsearch.index.IndexSettings;
|
||||
|
@ -38,9 +41,10 @@ import org.elasticsearch.script.ScriptEngine;
|
|||
import org.elasticsearch.script.ScriptModule;
|
||||
import org.elasticsearch.script.ScriptService;
|
||||
import org.elasticsearch.script.ScriptType;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregatorTestCase;
|
||||
import org.elasticsearch.search.aggregations.support.AggregationUsageService;
|
||||
import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
|
||||
import org.junit.BeforeClass;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -49,11 +53,11 @@ import java.util.Collections;
|
|||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
|
||||
import static java.util.Collections.singleton;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
|
||||
|
||||
|
@ -115,8 +119,8 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
|
|||
return state;
|
||||
});
|
||||
SCRIPTS.put("reduceScript", params -> {
|
||||
Map<String, Object> state = (Map<String, Object>) params.get("state");
|
||||
return state;
|
||||
List<Integer> states = (List<Integer>) params.get("states");
|
||||
return states.stream().mapToInt(Integer::intValue).sum();
|
||||
});
|
||||
|
||||
SCRIPTS.put("initScriptScore", params -> {
|
||||
|
@ -416,6 +420,32 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Runs {@code scripted_metric} as a sub-aggregation of a {@code terms}
 * aggregation on field {@code t} and checks that each terms bucket gets
 * its own scripted metric result.
 */
public void testAsSubAgg() throws IOException {
|
||||
AggregationBuilder aggregationBuilder = new TermsAggregationBuilder("t").field("t")
|
||||
.subAggregation(
|
||||
new ScriptedMetricAggregationBuilder("scripted").initScript(INIT_SCRIPT)
|
||||
.mapScript(MAP_SCRIPT)
|
||||
.combineScript(COMBINE_SCRIPT)
|
||||
.reduceScript(REDUCE_SCRIPT)
|
||||
);
|
||||
// Index 99 docs (i = 0..98) alternating between "even" and "odd": 50 even, 49 odd.
CheckedConsumer<RandomIndexWriter, IOException> buildIndex = iw -> {
|
||||
for (int i = 0; i < 99; i++) {
|
||||
iw.addDocument(singleton(new SortedSetDocValuesField("t", i % 2 == 0 ? new BytesRef("even") : new BytesRef("odd"))));
|
||||
}
|
||||
};
|
||||
// For each bucket, the scripted metric value is expected to equal the bucket's doc count.
Consumer<StringTerms> verify = terms -> {
|
||||
StringTerms.Bucket even = terms.getBucketByKey("even");
|
||||
assertThat(even.getDocCount(), equalTo(50L));
|
||||
ScriptedMetric evenMetric = even.getAggregations().get("scripted");
|
||||
assertThat(evenMetric.aggregation(), equalTo(50));
|
||||
StringTerms.Bucket odd = terms.getBucketByKey("odd");
|
||||
assertThat(odd.getDocCount(), equalTo(49L));
|
||||
ScriptedMetric oddMetric = odd.getAggregations().get("scripted");
|
||||
assertThat(oddMetric.aggregation(), equalTo(49));
|
||||
};
|
||||
testCase(aggregationBuilder, new MatchAllDocsQuery(), buildIndex, verify, keywordField("t"), longField("number"));
|
||||
}
|
||||
|
||||
/**
|
||||
* We cannot use Mockito for mocking QueryShardContext in this case because
|
||||
* script-related methods (e.g. QueryShardContext#getLazyExecutableScript)
|
||||
|
@ -430,12 +460,24 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
|
|||
MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME, SCRIPTS, Collections.emptyMap());
|
||||
Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
|
||||
ScriptService scriptService = new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
|
||||
ValuesSourceRegistry valuesSourceRegistry = mock(ValuesSourceRegistry.class);
|
||||
AggregationUsageService.Builder builder = new AggregationUsageService.Builder();
|
||||
builder.registerAggregationUsage(ScriptedMetricAggregationBuilder.NAME);
|
||||
when(valuesSourceRegistry.getUsageService()).thenReturn(builder.build());
|
||||
return new QueryShardContext(0, indexSettings, BigArrays.NON_RECYCLING_INSTANCE, null,
|
||||
null, mapperService, null, scriptService, xContentRegistry(), writableRegistry(),
|
||||
null, null, System::currentTimeMillis, null, null, () -> true, valuesSourceRegistry);
|
||||
return new QueryShardContext(
|
||||
0,
|
||||
indexSettings,
|
||||
BigArrays.NON_RECYCLING_INSTANCE,
|
||||
null,
|
||||
getIndexFieldDataLookup(mapperService, circuitBreakerService),
|
||||
mapperService,
|
||||
null,
|
||||
scriptService,
|
||||
xContentRegistry(),
|
||||
writableRegistry(),
|
||||
null,
|
||||
null,
|
||||
System::currentTimeMillis,
|
||||
null,
|
||||
null,
|
||||
() -> true,
|
||||
valuesSourceRegistry
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue