Aggregations: Fixes scripted metrics aggregation when used as a sub aggregation
The scripted metric aggregation is now a PER_BUCKET aggregation so that parent buckets are evaluated independently. Also the params and reduceParams are copied for each instance of the aggregator (each parent bucket) so modifications to the values are kept only within the scope of its parent bucket Closes #8036
This commit is contained in:
parent
101ea31fdf
commit
e9f05eed80
|
@ -26,6 +26,10 @@ import org.elasticsearch.search.aggregations.support.AggregationContext;
|
||||||
public abstract class MetricsAggregator extends Aggregator {
|
public abstract class MetricsAggregator extends Aggregator {
|
||||||
|
|
||||||
protected MetricsAggregator(String name, long estimatedBucketsCount, AggregationContext context, Aggregator parent) {
|
protected MetricsAggregator(String name, long estimatedBucketsCount, AggregationContext context, Aggregator parent) {
|
||||||
super(name, BucketAggregationMode.MULTI_BUCKETS, AggregatorFactories.EMPTY, estimatedBucketsCount, context, parent);
|
this(name, estimatedBucketsCount, BucketAggregationMode.MULTI_BUCKETS, context, parent);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected MetricsAggregator(String name, long estimatedBucketsCount, BucketAggregationMode bucketAggregationMode, AggregationContext context, Aggregator parent) {
|
||||||
|
super(name, bucketAggregationMode, AggregatorFactories.EMPTY, estimatedBucketsCount, context, parent);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,15 +24,17 @@ import org.elasticsearch.script.ExecutableScript;
|
||||||
import org.elasticsearch.script.ScriptService;
|
import org.elasticsearch.script.ScriptService;
|
||||||
import org.elasticsearch.script.ScriptService.ScriptType;
|
import org.elasticsearch.script.ScriptService.ScriptType;
|
||||||
import org.elasticsearch.script.SearchScript;
|
import org.elasticsearch.script.SearchScript;
|
||||||
|
import org.elasticsearch.search.SearchParseException;
|
||||||
import org.elasticsearch.search.aggregations.Aggregator;
|
import org.elasticsearch.search.aggregations.Aggregator;
|
||||||
import org.elasticsearch.search.aggregations.AggregatorFactory;
|
import org.elasticsearch.search.aggregations.AggregatorFactory;
|
||||||
import org.elasticsearch.search.aggregations.InternalAggregation;
|
import org.elasticsearch.search.aggregations.InternalAggregation;
|
||||||
import org.elasticsearch.search.aggregations.metrics.MetricsAggregator;
|
import org.elasticsearch.search.aggregations.metrics.MetricsAggregator;
|
||||||
import org.elasticsearch.search.aggregations.support.AggregationContext;
|
import org.elasticsearch.search.aggregations.support.AggregationContext;
|
||||||
|
import org.elasticsearch.search.internal.SearchContext;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.HashMap;
|
import java.util.*;
|
||||||
import java.util.Map;
|
import java.util.Map.Entry;
|
||||||
|
|
||||||
public class ScriptedMetricAggregator extends MetricsAggregator {
|
public class ScriptedMetricAggregator extends MetricsAggregator {
|
||||||
|
|
||||||
|
@ -51,7 +53,7 @@ public class ScriptedMetricAggregator extends MetricsAggregator {
|
||||||
protected ScriptedMetricAggregator(String name, String scriptLang, ScriptType initScriptType, String initScript,
|
protected ScriptedMetricAggregator(String name, String scriptLang, ScriptType initScriptType, String initScript,
|
||||||
ScriptType mapScriptType, String mapScript, ScriptType combineScriptType, String combineScript, ScriptType reduceScriptType,
|
ScriptType mapScriptType, String mapScript, ScriptType combineScriptType, String combineScript, ScriptType reduceScriptType,
|
||||||
String reduceScript, Map<String, Object> params, Map<String, Object> reduceParams, AggregationContext context, Aggregator parent) {
|
String reduceScript, Map<String, Object> params, Map<String, Object> reduceParams, AggregationContext context, Aggregator parent) {
|
||||||
super(name, 1, context, parent);
|
super(name, 1, BucketAggregationMode.PER_BUCKET, context, parent);
|
||||||
this.scriptService = context.searchContext().scriptService();
|
this.scriptService = context.searchContext().scriptService();
|
||||||
this.scriptLang = scriptLang;
|
this.scriptLang = scriptLang;
|
||||||
this.reduceScriptType = reduceScriptType;
|
this.reduceScriptType = reduceScriptType;
|
||||||
|
@ -59,7 +61,7 @@ public class ScriptedMetricAggregator extends MetricsAggregator {
|
||||||
this.params = new HashMap<>();
|
this.params = new HashMap<>();
|
||||||
this.params.put("_agg", new HashMap<>());
|
this.params.put("_agg", new HashMap<>());
|
||||||
} else {
|
} else {
|
||||||
this.params = params;
|
this.params = new HashMap<>(params);
|
||||||
}
|
}
|
||||||
if (reduceParams == null) {
|
if (reduceParams == null) {
|
||||||
this.reduceParams = new HashMap<>();
|
this.reduceParams = new HashMap<>();
|
||||||
|
@ -142,9 +144,45 @@ public class ScriptedMetricAggregator extends MetricsAggregator {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Aggregator create(AggregationContext context, Aggregator parent, long expectedBucketsCount) {
|
public Aggregator create(AggregationContext context, Aggregator parent, long expectedBucketsCount) {
|
||||||
|
Map<String, Object> params = null;
|
||||||
|
if (this.params != null) {
|
||||||
|
params = deepCopyParams(this.params, context.searchContext());
|
||||||
|
}
|
||||||
|
Map<String, Object> reduceParams = null;
|
||||||
|
if (this.reduceParams != null) {
|
||||||
|
reduceParams = deepCopyParams(this.reduceParams, context.searchContext());
|
||||||
|
}
|
||||||
return new ScriptedMetricAggregator(name, scriptLang, initScriptType, initScript, mapScriptType, mapScript, combineScriptType,
|
return new ScriptedMetricAggregator(name, scriptLang, initScriptType, initScript, mapScriptType, mapScript, combineScriptType,
|
||||||
combineScript, reduceScriptType, reduceScript, params, reduceParams, context, parent);
|
combineScript, reduceScriptType, reduceScript, params, reduceParams, context, parent);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings({ "unchecked" })
|
||||||
|
private static <T> T deepCopyParams(T original, SearchContext context) {
|
||||||
|
T clone;
|
||||||
|
if (original instanceof Map) {
|
||||||
|
Map<?, ?> originalMap = (Map<?, ?>) original;
|
||||||
|
Map<Object, Object> clonedMap = new HashMap<>();
|
||||||
|
for (Entry<?, ?> e : originalMap.entrySet()) {
|
||||||
|
clonedMap.put(deepCopyParams(e.getKey(), context), deepCopyParams(e.getValue(), context));
|
||||||
|
}
|
||||||
|
clone = (T) clonedMap;
|
||||||
|
} else if (original instanceof List) {
|
||||||
|
List<?> originalList = (List<?>) original;
|
||||||
|
List<Object> clonedList = new ArrayList<Object>();
|
||||||
|
for (Object o : originalList) {
|
||||||
|
clonedList.add(deepCopyParams(o, context));
|
||||||
|
}
|
||||||
|
clone = (T) clonedList;
|
||||||
|
} else if (original instanceof String || original instanceof Integer || original instanceof Long || original instanceof Short
|
||||||
|
|| original instanceof Byte || original instanceof Float || original instanceof Double || original instanceof Character
|
||||||
|
|| original instanceof Boolean) {
|
||||||
|
clone = original;
|
||||||
|
} else {
|
||||||
|
throw new SearchParseException(context, "Can only clone primitives, String, ArrayList, and HashMap. Found: "
|
||||||
|
+ original.getClass().getCanonicalName());
|
||||||
|
}
|
||||||
|
return clone;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,9 @@ import org.elasticsearch.action.search.SearchResponse;
|
||||||
import org.elasticsearch.common.settings.ImmutableSettings;
|
import org.elasticsearch.common.settings.ImmutableSettings;
|
||||||
import org.elasticsearch.common.settings.Settings;
|
import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.search.aggregations.Aggregation;
|
import org.elasticsearch.search.aggregations.Aggregation;
|
||||||
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.search.aggregations.bucket.histogram.Histogram;
|
import org.elasticsearch.search.aggregations.bucket.histogram.Histogram;
|
||||||
|
import org.elasticsearch.search.aggregations.bucket.histogram.Histogram.Bucket;
|
||||||
import org.elasticsearch.search.aggregations.metrics.scripted.ScriptedMetric;
|
import org.elasticsearch.search.aggregations.metrics.scripted.ScriptedMetric;
|
||||||
import org.elasticsearch.test.ElasticsearchIntegrationTest;
|
import org.elasticsearch.test.ElasticsearchIntegrationTest;
|
||||||
import org.elasticsearch.test.ElasticsearchIntegrationTest.ClusterScope;
|
import org.elasticsearch.test.ElasticsearchIntegrationTest.ClusterScope;
|
||||||
|
@ -59,7 +61,8 @@ public class ScriptedMetricTests extends ElasticsearchIntegrationTest {
|
||||||
numDocs = randomIntBetween(10, 100);
|
numDocs = randomIntBetween(10, 100);
|
||||||
for (int i = 0; i < numDocs; i++) {
|
for (int i = 0; i < numDocs; i++) {
|
||||||
builders.add(client().prepareIndex("idx", "type", "" + i).setSource(
|
builders.add(client().prepareIndex("idx", "type", "" + i).setSource(
|
||||||
jsonBuilder().startObject().field("value", randomAsciiOfLengthBetween(5, 15)).endObject()));
|
jsonBuilder().startObject().field("value", randomAsciiOfLengthBetween(5, 15))
|
||||||
|
.field("l_value", i).endObject()));
|
||||||
}
|
}
|
||||||
indexRandom(true, builders);
|
indexRandom(true, builders);
|
||||||
|
|
||||||
|
@ -561,6 +564,62 @@ public class ScriptedMetricTests extends ElasticsearchIntegrationTest {
|
||||||
assertThat(((Number) object).longValue(), equalTo(numDocs * 3));
|
assertThat(((Number) object).longValue(), equalTo(numDocs * 3));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInitMapCombineReduce_withParams_asSubAgg() {
|
||||||
|
Map<String, Object> varsMap = new HashMap<>();
|
||||||
|
varsMap.put("multiplier", 1);
|
||||||
|
Map<String, Object> params = new HashMap<>();
|
||||||
|
params.put("_agg", new ArrayList<>());
|
||||||
|
params.put("vars", varsMap);
|
||||||
|
|
||||||
|
SearchResponse response = client()
|
||||||
|
.prepareSearch("idx")
|
||||||
|
.setQuery(matchAllQuery()).setSize(1000)
|
||||||
|
.addAggregation(
|
||||||
|
histogram("histo")
|
||||||
|
.field("l_value")
|
||||||
|
.interval(1)
|
||||||
|
.subAggregation(
|
||||||
|
scriptedMetric("scripted")
|
||||||
|
.params(params)
|
||||||
|
.initScript("vars.multiplier = 3")
|
||||||
|
.mapScript("_agg.add(vars.multiplier)")
|
||||||
|
.combineScript(
|
||||||
|
"newaggregation = []; sum = 0;for (a in _agg) { sum += a}; newaggregation.add(sum); return newaggregation")
|
||||||
|
.reduceScript(
|
||||||
|
"newaggregation = []; sum = 0;for (aggregation in _aggs) { for (a in aggregation) { sum += a} }; newaggregation.add(sum); return newaggregation")))
|
||||||
|
.execute().actionGet();
|
||||||
|
assertSearchResponse(response);
|
||||||
|
assertThat(response.getHits().getTotalHits(), equalTo(numDocs));
|
||||||
|
Aggregation aggregation = response.getAggregations().get("histo");
|
||||||
|
assertThat(aggregation, notNullValue());
|
||||||
|
assertThat(aggregation, instanceOf(Histogram.class));
|
||||||
|
Histogram histoAgg = (Histogram) aggregation;
|
||||||
|
assertThat(histoAgg.getName(), equalTo("histo"));
|
||||||
|
List<? extends Bucket> buckets = histoAgg.getBuckets();
|
||||||
|
assertThat(buckets, notNullValue());
|
||||||
|
for (Bucket b : buckets) {
|
||||||
|
assertThat(b, notNullValue());
|
||||||
|
assertThat(b.getDocCount(), equalTo(1l));
|
||||||
|
Aggregations subAggs = b.getAggregations();
|
||||||
|
assertThat(subAggs, notNullValue());
|
||||||
|
assertThat(subAggs.asList().size(), equalTo(1));
|
||||||
|
Aggregation subAgg = subAggs.get("scripted");
|
||||||
|
assertThat(subAgg, notNullValue());
|
||||||
|
assertThat(subAgg, instanceOf(ScriptedMetric.class));
|
||||||
|
ScriptedMetric scriptedMetricAggregation = (ScriptedMetric) subAgg;
|
||||||
|
assertThat(scriptedMetricAggregation.getName(), equalTo("scripted"));
|
||||||
|
assertThat(scriptedMetricAggregation.aggregation(), notNullValue());
|
||||||
|
assertThat(scriptedMetricAggregation.aggregation(), instanceOf(ArrayList.class));
|
||||||
|
List<?> aggregationList = (List<?>) scriptedMetricAggregation.aggregation();
|
||||||
|
assertThat(aggregationList.size(), equalTo(1));
|
||||||
|
Object object = aggregationList.get(0);
|
||||||
|
assertThat(object, notNullValue());
|
||||||
|
assertThat(object, instanceOf(Number.class));
|
||||||
|
assertThat(((Number) object).longValue(), equalTo(3l));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEmptyAggregation() throws Exception {
|
public void testEmptyAggregation() throws Exception {
|
||||||
Map<String, Object> varsMap = new HashMap<>();
|
Map<String, Object> varsMap = new HashMap<>();
|
||||||
|
|
Loading…
Reference in New Issue