From 9ebe11f1d952b083ad5e60bc589efb6aa4148a48 Mon Sep 17 00:00:00 2001 From: Joel Bernstein Date: Tue, 10 Apr 2018 12:36:03 -0400 Subject: [PATCH] SOLR-12158: Allow the monteCarlo Stream Evaluator to support variables --- .../solrj/io/eval/MonteCarloEvaluator.java | 66 ++++++++++++++++++- .../solrj/io/stream/MathExpressionTest.java | 31 +++++++++ 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MonteCarloEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MonteCarloEvaluator.java index 27e806f17a1..b0ec8c5d1e1 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MonteCarloEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MonteCarloEvaluator.java @@ -19,19 +19,47 @@ package org.apache.solr.client.solrj.io.eval; import java.io.IOException; import java.io.UncheckedIOException; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; +import java.util.Map; +import java.util.Set; import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.stream.TupleStream; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; public class MonteCarloEvaluator extends RecursiveEvaluator { protected static final long serialVersionUID = 1L; + private Map variables = new LinkedHashMap(); + public MonteCarloEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ super(expression, factory); + List namedParams = factory.getNamedOperands(expression); + //Get all the named params + Set echo = null; + boolean echoAll = false; + String currentName = null; + for(StreamExpressionParameter np : namedParams) { + String name = ((StreamExpressionNamedParameter)np).getName(); + currentName = name; + + + StreamExpressionParameter param = ((StreamExpressionNamedParameter)np).getParameter(); + if(factory.isEvaluator((StreamExpression)param)) { + StreamEvaluator evaluator = factory.constructEvaluator((StreamExpression) param); + variables.put(name, evaluator); + } else { + TupleStream tupleStream = factory.constructStream((StreamExpression) param); + variables.put(name, tupleStream); + } + } + init(); } @@ -49,7 +77,7 @@ public class MonteCarloEvaluator extends RecursiveEvaluator { @Override public Object evaluate(Tuple tuple) throws IOException { - try{ + try { StreamEvaluator function = containedEvaluators.get(0); StreamEvaluator iterationsEvaluator = containedEvaluators.get(1); @@ -57,6 +85,7 @@ public class MonteCarloEvaluator extends RecursiveEvaluator { int it = itNum.intValue(); List results = new ArrayList(); for(int i=0; i> entries = variables.entrySet(); + + for(Map.Entry entry : entries) { + String name = entry.getKey(); + Object o = entry.getValue(); + if(o instanceof TupleStream) { + List tuples = new ArrayList(); + TupleStream tStream = (TupleStream)o; + tStream.setStreamContext(streamContext); + try { + tStream.open(); + TUPLES: + while(true) { + Tuple tuple = tStream.read(); + if (tuple.EOF) { + break TUPLES; + } else { + tuples.add(tuple); + } + } + contextTuple.put(name, tuples); + } finally { + tStream.close(); + } + } else { + StreamEvaluator evaluator = (StreamEvaluator)o; + Object eo = evaluator.evaluate(contextTuple); + contextTuple.put(name, eo); + } + } + } } diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java index 41116e2472f..abc1c214b79 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java @@ -3240,6 +3240,37 @@ public class MathExpressionTest extends SolrCloudTestCase { assertEquals(out.get(9).doubleValue(), 30.0, .0); } + @Test + public void testMonteCarloWithVariables() throws Exception { + String cexpr = "let(a=constantDistribution(10), " + + " b=constantDistribution(20), " + + " c=monteCarlo(d=sample(a),"+ + " e=sample(b),"+ + " add(d, add(e, 10)), " + + " 10))"; + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + List out = (List)tuples.get(0).get("c"); + assertTrue(out.size()==10); + assertEquals(out.get(0).doubleValue(), 40.0, .0); + assertEquals(out.get(1).doubleValue(), 40.0, .0); + assertEquals(out.get(2).doubleValue(), 40.0, .0); + assertEquals(out.get(3).doubleValue(), 40.0, .0); + assertEquals(out.get(4).doubleValue(), 40.0, .0); + assertEquals(out.get(5).doubleValue(), 40.0, .0); + assertEquals(out.get(6).doubleValue(), 40.0, .0); + assertEquals(out.get(7).doubleValue(), 40.0, .0); + assertEquals(out.get(8).doubleValue(), 40.0, .0); + assertEquals(out.get(9).doubleValue(), 40.0, .0); + } + @Test public void testWeibullDistribution() throws Exception { String cexpr = "let(echo=true, " +