diff --git a/common/src/main/java/com/metamx/druid/aggregation/JavaScriptAggregatorFactory.java b/common/src/main/java/com/metamx/druid/aggregation/JavaScriptAggregatorFactory.java index 1d0bb9da863..9a594febb96 100644 --- a/common/src/main/java/com/metamx/druid/aggregation/JavaScriptAggregatorFactory.java +++ b/common/src/main/java/com/metamx/druid/aggregation/JavaScriptAggregatorFactory.java @@ -48,7 +48,11 @@ public class JavaScriptAggregatorFactory implements AggregatorFactory private final String name; private final List fieldNames; - private final String script; + private final String fnAggregate; + private final String fnReset; + private final String fnCombine; + + private final JavaScriptAggregator.ScriptAggregator compiledScript; @@ -56,13 +60,19 @@ public class JavaScriptAggregatorFactory implements AggregatorFactory public JavaScriptAggregatorFactory( @JsonProperty("name") String name, @JsonProperty("fieldNames") final List fieldNames, - @JsonProperty("script") final String expression + @JsonProperty("fnAggregate") final String fnAggregate, + @JsonProperty("fnReset") final String fnReset, + @JsonProperty("fnCombine") final String fnCombine ) { this.name = name; - this.script = expression; this.fieldNames = fieldNames; - this.compiledScript = compileScript(script); + + this.fnAggregate = fnAggregate; + this.fnReset = fnReset; + this.fnCombine = fnCombine; + + this.compiledScript = compileScript(fnAggregate, fnReset, fnCombine); } @Override @@ -116,7 +126,7 @@ public class JavaScriptAggregatorFactory implements AggregatorFactory @Override public AggregatorFactory getCombiningFactory() { - throw new UnsupportedOperationException(); + return new JavaScriptAggregatorFactory(name, Lists.newArrayList(name), fnCombine, fnReset, fnCombine); } @Override @@ -144,8 +154,21 @@ public class JavaScriptAggregatorFactory implements AggregatorFactory } @JsonProperty - public String getScript() { - return script; + public String getFnAggregate() + { + return fnAggregate; + } + + @JsonProperty + public String getFnReset() + { + return fnReset; + } + + @JsonProperty + public String getFnCombine() + { + return fnCombine; } @Override @@ -160,7 +183,7 @@ public class JavaScriptAggregatorFactory implements AggregatorFactory try { MessageDigest md = MessageDigest.getInstance("SHA-1"); byte[] fieldNameBytes = Joiner.on(",").join(fieldNames).getBytes(); - byte[] sha1 = md.digest(script.getBytes()); + byte[] sha1 = md.digest((fnAggregate+fnReset+fnCombine).getBytes()); return ByteBuffer.allocate(1 + fieldNameBytes.length + sha1.length) .put(CACHE_TYPE_ID) @@ -197,21 +220,13 @@ public class JavaScriptAggregatorFactory implements AggregatorFactory return "JavaScriptAggregatorFactory{" + "name='" + name + '\'' + ", fieldNames=" + fieldNames + - ", script='" + script + '\'' + + ", fnAggregate='" + fnAggregate + '\'' + + ", fnReset='" + fnReset + '\'' + + ", fnCombine='" + fnCombine + '\'' + '}'; } - protected static Function getScriptFunction(String name, ScriptableObject scope) - { - Object fun = scope.get(name, scope); - if (fun instanceof Function) { - return (Function) fun; - } else { - throw new IAE("Function [%s] not defined in script", name); - } - } - - public static JavaScriptAggregator.ScriptAggregator compileScript(final String script) + public static JavaScriptAggregator.ScriptAggregator compileScript(final String aggregate, final String reset, final String combine) { final ContextFactory contextFactory = ContextFactory.getGlobal(); Context context = contextFactory.enterContext(); @@ -219,12 +234,9 @@ public class JavaScriptAggregatorFactory implements AggregatorFactory final ScriptableObject scope = context.initStandardObjects(); - Script compiledScript = context.compileString(script, "script", 1, null); - compiledScript.exec(context, scope); - - final Function fnAggregate = getScriptFunction("aggregate", scope); - final Function fnReset = getScriptFunction("reset", scope); - final Function fnCombine = getScriptFunction("combine", scope); + final Function fnAggregate = context.compileFunction(scope, aggregate, "aggregate", 1, null); + final Function fnReset = context.compileFunction(scope, reset, "reset", 1, null); + final Function fnCombine = context.compileFunction(scope, combine, "combine", 1, null); Context.exit(); return new JavaScriptAggregator.ScriptAggregator() diff --git a/common/src/test/java/com/metamx/druid/aggregation/JavaScriptAggregatorTest.java b/common/src/test/java/com/metamx/druid/aggregation/JavaScriptAggregatorTest.java index bbe97070752..bb1a3a39075 100644 --- a/common/src/test/java/com/metamx/druid/aggregation/JavaScriptAggregatorTest.java +++ b/common/src/test/java/com/metamx/druid/aggregation/JavaScriptAggregatorTest.java @@ -20,25 +20,29 @@ package com.metamx.druid.aggregation; import com.google.common.collect.Lists; -import com.google.common.primitives.Doubles; +import com.google.common.collect.Maps; import com.metamx.druid.processing.FloatMetricSelector; import org.junit.Assert; import org.junit.Test; import java.nio.ByteBuffer; import java.util.Arrays; +import java.util.Map; public class JavaScriptAggregatorTest { - protected static final String sumLogATimesBPlusTen = - "function aggregate(current, a, b) { return current + (Math.log(a) * b) }" - + "function combine(a,b) { return a + b }" - + "function reset() { return 10 }"; + protected static final Map sumLogATimesBPlusTen = Maps.newHashMap(); + protected static final Map scriptDoubleSum = Maps.newHashMap(); - protected static final String scriptDoubleSum = - "function aggregate(current, a) { return current + a }" - + "function combine(a,b) { return a + b }" - + "function reset() { return 0 }"; + static { + sumLogATimesBPlusTen.put("fnAggregate", "function aggregate(current, a, b) { return current + (Math.log(a) * b) }"); + sumLogATimesBPlusTen.put("fnReset", "function reset() { return 10 }"); + sumLogATimesBPlusTen.put("fnCombine", "function combine(a,b) { return a + b }"); + + scriptDoubleSum.put("fnAggregate", "function aggregate(current, a) { return current + a }"); + scriptDoubleSum.put("fnReset", "function reset() { return 0 }"); + scriptDoubleSum.put("fnCombine", "function combine(a,b) { return a + b }"); + } private static void aggregate(TestFloatMetricSelector selector1, TestFloatMetricSelector selector2, Aggregator agg) { @@ -69,10 +73,14 @@ public class JavaScriptAggregatorTest final TestFloatMetricSelector selector1 = new TestFloatMetricSelector(new float[]{42.12f, 9f}); final TestFloatMetricSelector selector2 = new TestFloatMetricSelector(new float[]{2f, 3f}); + Map script = sumLogATimesBPlusTen; + JavaScriptAggregator agg = new JavaScriptAggregator( "billy", Arrays.asList(selector1, selector2), - JavaScriptAggregatorFactory.compileScript(sumLogATimesBPlusTen) + JavaScriptAggregatorFactory.compileScript(script.get("fnAggregate"), + script.get("fnReset"), + script.get("fnCombine")) ); agg.reset(); @@ -103,9 +111,12 @@ public class JavaScriptAggregatorTest final TestFloatMetricSelector selector1 = new TestFloatMetricSelector(new float[]{42.12f, 9f}); final TestFloatMetricSelector selector2 = new TestFloatMetricSelector(new float[]{2f, 3f}); + Map script = sumLogATimesBPlusTen; JavaScriptBufferAggregator agg = new JavaScriptBufferAggregator( Arrays.asList(selector1, selector2), - JavaScriptAggregatorFactory.compileScript(sumLogATimesBPlusTen) + JavaScriptAggregatorFactory.compileScript(script.get("fnAggregate"), + script.get("fnReset"), + script.get("fnCombine")) ); ByteBuffer buf = ByteBuffer.allocateDirect(32); @@ -150,10 +161,13 @@ public class JavaScriptAggregatorTest } */ + Map script = scriptDoubleSum; JavaScriptAggregator aggRhino = new JavaScriptAggregator( "billy", Lists.asList(selector, new FloatMetricSelector[]{}), - JavaScriptAggregatorFactory.compileScript(scriptDoubleSum) + JavaScriptAggregatorFactory.compileScript(script.get("fnAggregate"), + script.get("fnReset"), + script.get("fnCombine")) ); DoubleSumAggregator doubleAgg = new DoubleSumAggregator("billy", selector);