SOLR-12158: Allow the monteCarlo Stream Evaluator to support variables

This commit is contained in:
Joel Bernstein 2018-04-10 12:36:03 -04:00
parent 61c37551ac
commit 9ebe11f1d9
2 changed files with 96 additions and 1 deletions

View File

@ -19,19 +19,47 @@ package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class MonteCarloEvaluator extends RecursiveEvaluator {
protected static final long serialVersionUID = 1L;
private Map variables = new LinkedHashMap();
public MonteCarloEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
//Get all the named params
Set<String> echo = null;
boolean echoAll = false;
String currentName = null;
for(StreamExpressionParameter np : namedParams) {
String name = ((StreamExpressionNamedParameter)np).getName();
currentName = name;
StreamExpressionParameter param = ((StreamExpressionNamedParameter)np).getParameter();
if(factory.isEvaluator((StreamExpression)param)) {
StreamEvaluator evaluator = factory.constructEvaluator((StreamExpression) param);
variables.put(name, evaluator);
} else {
TupleStream tupleStream = factory.constructStream((StreamExpression) param);
variables.put(name, tupleStream);
}
}
init();
}
@ -57,6 +85,7 @@ public class MonteCarloEvaluator extends RecursiveEvaluator {
int it = itNum.intValue();
List<Number> results = new ArrayList();
for(int i=0; i<it; i++) {
populateVariables(tuple);
Number result = (Number)function.evaluate(tuple);
results.add(result);
}
@ -73,4 +102,39 @@ public class MonteCarloEvaluator extends RecursiveEvaluator {
// Nothing to do here
throw new IOException("This call should never occur");
}
private void populateVariables(Tuple contextTuple) throws IOException {
Set<Map.Entry<String, Object>> entries = variables.entrySet();
for(Map.Entry<String, Object> entry : entries) {
String name = entry.getKey();
Object o = entry.getValue();
if(o instanceof TupleStream) {
List<Tuple> tuples = new ArrayList();
TupleStream tStream = (TupleStream)o;
tStream.setStreamContext(streamContext);
try {
tStream.open();
TUPLES:
while(true) {
Tuple tuple = tStream.read();
if (tuple.EOF) {
break TUPLES;
} else {
tuples.add(tuple);
}
}
contextTuple.put(name, tuples);
} finally {
tStream.close();
}
} else {
StreamEvaluator evaluator = (StreamEvaluator)o;
Object eo = evaluator.evaluate(contextTuple);
contextTuple.put(name, eo);
}
}
}
}

View File

@ -3240,6 +3240,37 @@ public class MathExpressionTest extends SolrCloudTestCase {
assertEquals(out.get(9).doubleValue(), 30.0, .0);
}
@Test
public void testMonteCarloWithVariables() throws Exception {
String cexpr = "let(a=constantDistribution(10), " +
" b=constantDistribution(20), " +
" c=monteCarlo(d=sample(a),"+
" e=sample(b),"+
" add(d, add(e, 10)), " +
" 10))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
List<Number> out = (List<Number>)tuples.get(0).get("c");
assertTrue(out.size()==10);
assertEquals(out.get(0).doubleValue(), 40.0, .0);
assertEquals(out.get(1).doubleValue(), 40.0, .0);
assertEquals(out.get(2).doubleValue(), 40.0, .0);
assertEquals(out.get(3).doubleValue(), 40.0, .0);
assertEquals(out.get(4).doubleValue(), 40.0, .0);
assertEquals(out.get(5).doubleValue(), 40.0, .0);
assertEquals(out.get(6).doubleValue(), 40.0, .0);
assertEquals(out.get(7).doubleValue(), 40.0, .0);
assertEquals(out.get(8).doubleValue(), 40.0, .0);
assertEquals(out.get(9).doubleValue(), 40.0, .0);
}
@Test
public void testWeibullDistribution() throws Exception {
String cexpr = "let(echo=true, " +