Aggregations Refactor: Refactor Scripted Metric Aggregation

This commit is contained in:
Colin Goodheart-Smithe 2015-11-24 10:00:41 +00:00
parent 94e867906c
commit 8499e27dc5
4 changed files with 190 additions and 9 deletions

View File

@ -34,12 +34,24 @@ import org.elasticsearch.script.ScriptService.ScriptType;
import java.io.IOException; import java.io.IOException;
import java.util.Map; import java.util.Map;
import java.util.function.Supplier;
/** /**
* Script holds all the parameters necessary to compile or find in cache and then execute a script. * Script holds all the parameters necessary to compile or find in cache and then execute a script.
*/ */
public class Script implements ToXContent, Streamable { public class Script implements ToXContent, Streamable {
/**
* A {@link Supplier} implementation for use when reading a {@link Script}
* using {@link StreamInput#readOptionalStreamable(Supplier)}
*/
public static final Supplier<Script> SUPPLIER = new Supplier<Script>() {
@Override
public Script get() {
return new Script();
}
};
public static final ScriptType DEFAULT_TYPE = ScriptType.INLINE; public static final ScriptType DEFAULT_TYPE = ScriptType.INLINE;
private static final ScriptParser PARSER = new ScriptParser(); private static final ScriptParser PARSER = new ScriptParser();

View File

@ -20,6 +20,9 @@
package org.elasticsearch.search.aggregations.metrics.scripted; package org.elasticsearch.search.aggregations.metrics.scripted;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.script.ExecutableScript; import org.elasticsearch.script.ExecutableScript;
import org.elasticsearch.script.LeafSearchScript; import org.elasticsearch.script.LeafSearchScript;
import org.elasticsearch.script.Script; import org.elasticsearch.script.Script;
@ -44,6 +47,7 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Objects;
public class ScriptedMetricAggregator extends MetricsAggregator { public class ScriptedMetricAggregator extends MetricsAggregator {
@ -114,13 +118,43 @@ public class ScriptedMetricAggregator extends MetricsAggregator {
private Script reduceScript; private Script reduceScript;
private Map<String, Object> params; private Map<String, Object> params;
public Factory(String name, Script initScript, Script mapScript, Script combineScript, Script reduceScript, public Factory(String name) {
Map<String, Object> params) {
super(name, InternalScriptedMetric.TYPE); super(name, InternalScriptedMetric.TYPE);
}
/**
* Set the <tt>init</tt> script.
*/
public void initScript(Script initScript) {
this.initScript = initScript; this.initScript = initScript;
}
/**
* Set the <tt>map</tt> script.
*/
public void mapScript(Script mapScript) {
this.mapScript = mapScript; this.mapScript = mapScript;
}
/**
* Set the <tt>combine</tt> script.
*/
public void combineScript(Script combineScript) {
this.combineScript = combineScript; this.combineScript = combineScript;
}
/**
* Set the <tt>reduce</tt> script.
*/
public void reduceScript(Script reduceScript) {
this.reduceScript = reduceScript; this.reduceScript = reduceScript;
}
/**
* Set parameters that will be available in the <tt>init</tt>,
* <tt>map</tt> and <tt>combine</tt> phases.
*/
public void params(Map<String, Object> params) {
this.params = params; this.params = params;
} }
@ -189,6 +223,73 @@ public class ScriptedMetricAggregator extends MetricsAggregator {
return clone; return clone;
} }
@Override
protected XContentBuilder internalXContent(XContentBuilder builder, Params builderParams) throws IOException {
builder.startObject();
if (initScript != null) {
builder.field(ScriptedMetricParser.INIT_SCRIPT_FIELD.getPreferredName(), initScript);
}
if (mapScript != null) {
builder.field(ScriptedMetricParser.MAP_SCRIPT_FIELD.getPreferredName(), mapScript);
}
if (combineScript != null) {
builder.field(ScriptedMetricParser.COMBINE_SCRIPT_FIELD.getPreferredName(), combineScript);
}
if (reduceScript != null) {
builder.field(ScriptedMetricParser.REDUCE_SCRIPT_FIELD.getPreferredName(), reduceScript);
}
if (params != null) {
builder.field(ScriptedMetricParser.PARAMS_FIELD.getPreferredName());
builder.map(params);
}
builder.endObject();
return builder;
}
@Override
protected AggregatorFactory doReadFrom(String name, StreamInput in) throws IOException {
Factory factory = new Factory(name);
factory.initScript = in.readOptionalStreamable(Script.SUPPLIER);
factory.mapScript = in.readOptionalStreamable(Script.SUPPLIER);
factory.combineScript = in.readOptionalStreamable(Script.SUPPLIER);
factory.reduceScript = in.readOptionalStreamable(Script.SUPPLIER);
if (in.readBoolean()) {
factory.params = in.readMap();
}
return factory;
}
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeOptionalStreamable(initScript);
out.writeOptionalStreamable(mapScript);
out.writeOptionalStreamable(combineScript);
out.writeOptionalStreamable(reduceScript);
boolean hasParams = params != null;
out.writeBoolean(hasParams);
if (hasParams) {
out.writeMap(params);
}
}
@Override
protected int doHashCode() {
return Objects.hash(initScript, mapScript, combineScript, reduceScript, params);
}
@Override
protected boolean doEquals(Object obj) {
Factory other = (Factory) obj;
return Objects.equals(initScript, other.initScript)
&& Objects.equals(mapScript, other.mapScript)
&& Objects.equals(combineScript, other.combineScript)
&& Objects.equals(reduceScript, other.reduceScript)
&& Objects.equals(params, other.params);
}
} }
} }

View File

@ -147,13 +147,19 @@ public class ScriptedMetricParser implements Aggregator.Parser {
if (mapScript == null) { if (mapScript == null) {
throw new SearchParseException(context, "map_script field is required in [" + aggregationName + "].", parser.getTokenLocation()); throw new SearchParseException(context, "map_script field is required in [" + aggregationName + "].", parser.getTokenLocation());
} }
return new ScriptedMetricAggregator.Factory(aggregationName, initScript, mapScript, combineScript, reduceScript, params);
ScriptedMetricAggregator.Factory factory = new ScriptedMetricAggregator.Factory(aggregationName);
factory.initScript(initScript);
factory.mapScript(mapScript);
factory.combineScript(combineScript);
factory.reduceScript(reduceScript);
factory.params(params);
return factory;
} }
// NORELEASE implement this method when refactoring this aggregation
@Override @Override
public AggregatorFactory[] getFactoryPrototypes() { public AggregatorFactory[] getFactoryPrototypes() {
return null; return new AggregatorFactory[] { new ScriptedMetricAggregator.Factory(null) };
} }
} }

View File

@ -0,0 +1,62 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.metrics;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptService.ScriptType;
import org.elasticsearch.search.aggregations.BaseAggregationTestCase;
import org.elasticsearch.search.aggregations.metrics.scripted.ScriptedMetricAggregator;
import org.elasticsearch.search.aggregations.metrics.scripted.ScriptedMetricAggregator.Factory;
import java.util.HashMap;
import java.util.Map;
public class ScriptedMetricTests extends BaseAggregationTestCase<ScriptedMetricAggregator.Factory> {
@Override
protected Factory createTestAggregatorFactory() {
Factory factory = new Factory(randomAsciiOfLengthBetween(1, 20));
if (randomBoolean()) {
factory.initScript(randomScript("initScript"));
}
factory.mapScript(randomScript("mapScript"));
if (randomBoolean()) {
factory.combineScript(randomScript("combineScript"));
}
if (randomBoolean()) {
factory.reduceScript(randomScript("reduceScript"));
}
if (randomBoolean()) {
Map<String, Object> params = new HashMap<String, Object>();
params.put("foo", "bar");
factory.params(params);
}
return factory;
}
private Script randomScript(String script) {
if (randomBoolean()) {
return new Script(script);
} else {
return new Script(script, randomFrom(ScriptType.values()), randomFrom("my_lang", null), null);
}
}
}