mirror of https://github.com/apache/lucene.git
SOLR-12158: Allow the monteCarlo Stream Evaluator to support variables
This commit is contained in:
parent
61c37551ac
commit
9ebe11f1d9
|
@ -19,19 +19,47 @@ package org.apache.solr.client.solrj.io.eval;
|
|||
import java.io.IOException;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.TupleStream;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class MonteCarloEvaluator extends RecursiveEvaluator {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
private Map variables = new LinkedHashMap();
|
||||
|
||||
public MonteCarloEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
|
||||
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
|
||||
//Get all the named params
|
||||
Set<String> echo = null;
|
||||
boolean echoAll = false;
|
||||
String currentName = null;
|
||||
for(StreamExpressionParameter np : namedParams) {
|
||||
String name = ((StreamExpressionNamedParameter)np).getName();
|
||||
currentName = name;
|
||||
|
||||
|
||||
StreamExpressionParameter param = ((StreamExpressionNamedParameter)np).getParameter();
|
||||
if(factory.isEvaluator((StreamExpression)param)) {
|
||||
StreamEvaluator evaluator = factory.constructEvaluator((StreamExpression) param);
|
||||
variables.put(name, evaluator);
|
||||
} else {
|
||||
TupleStream tupleStream = factory.constructStream((StreamExpression) param);
|
||||
variables.put(name, tupleStream);
|
||||
}
|
||||
}
|
||||
|
||||
init();
|
||||
}
|
||||
|
||||
|
@ -49,7 +77,7 @@ public class MonteCarloEvaluator extends RecursiveEvaluator {
|
|||
|
||||
@Override
|
||||
public Object evaluate(Tuple tuple) throws IOException {
|
||||
try{
|
||||
try {
|
||||
|
||||
StreamEvaluator function = containedEvaluators.get(0);
|
||||
StreamEvaluator iterationsEvaluator = containedEvaluators.get(1);
|
||||
|
@ -57,6 +85,7 @@ public class MonteCarloEvaluator extends RecursiveEvaluator {
|
|||
int it = itNum.intValue();
|
||||
List<Number> results = new ArrayList();
|
||||
for(int i=0; i<it; i++) {
|
||||
populateVariables(tuple);
|
||||
Number result = (Number)function.evaluate(tuple);
|
||||
results.add(result);
|
||||
}
|
||||
|
@ -73,4 +102,39 @@ public class MonteCarloEvaluator extends RecursiveEvaluator {
|
|||
// Nothing to do here
|
||||
throw new IOException("This call should never occur");
|
||||
}
|
||||
|
||||
|
||||
private void populateVariables(Tuple contextTuple) throws IOException {
|
||||
|
||||
Set<Map.Entry<String, Object>> entries = variables.entrySet();
|
||||
|
||||
for(Map.Entry<String, Object> entry : entries) {
|
||||
String name = entry.getKey();
|
||||
Object o = entry.getValue();
|
||||
if(o instanceof TupleStream) {
|
||||
List<Tuple> tuples = new ArrayList();
|
||||
TupleStream tStream = (TupleStream)o;
|
||||
tStream.setStreamContext(streamContext);
|
||||
try {
|
||||
tStream.open();
|
||||
TUPLES:
|
||||
while(true) {
|
||||
Tuple tuple = tStream.read();
|
||||
if (tuple.EOF) {
|
||||
break TUPLES;
|
||||
} else {
|
||||
tuples.add(tuple);
|
||||
}
|
||||
}
|
||||
contextTuple.put(name, tuples);
|
||||
} finally {
|
||||
tStream.close();
|
||||
}
|
||||
} else {
|
||||
StreamEvaluator evaluator = (StreamEvaluator)o;
|
||||
Object eo = evaluator.evaluate(contextTuple);
|
||||
contextTuple.put(name, eo);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3240,6 +3240,37 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
assertEquals(out.get(9).doubleValue(), 30.0, .0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMonteCarloWithVariables() throws Exception {
|
||||
String cexpr = "let(a=constantDistribution(10), " +
|
||||
" b=constantDistribution(20), " +
|
||||
" c=monteCarlo(d=sample(a),"+
|
||||
" e=sample(b),"+
|
||||
" add(d, add(e, 10)), " +
|
||||
" 10))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
List<Number> out = (List<Number>)tuples.get(0).get("c");
|
||||
assertTrue(out.size()==10);
|
||||
assertEquals(out.get(0).doubleValue(), 40.0, .0);
|
||||
assertEquals(out.get(1).doubleValue(), 40.0, .0);
|
||||
assertEquals(out.get(2).doubleValue(), 40.0, .0);
|
||||
assertEquals(out.get(3).doubleValue(), 40.0, .0);
|
||||
assertEquals(out.get(4).doubleValue(), 40.0, .0);
|
||||
assertEquals(out.get(5).doubleValue(), 40.0, .0);
|
||||
assertEquals(out.get(6).doubleValue(), 40.0, .0);
|
||||
assertEquals(out.get(7).doubleValue(), 40.0, .0);
|
||||
assertEquals(out.get(8).doubleValue(), 40.0, .0);
|
||||
assertEquals(out.get(9).doubleValue(), 40.0, .0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWeibullDistribution() throws Exception {
|
||||
String cexpr = "let(echo=true, " +
|
||||
|
|
Loading…
Reference in New Issue