mirror of https://github.com/apache/lucene.git
SOLR-12158: Allow the monteCarlo Stream Evaluator to support variables
This commit is contained in:
parent
61c37551ac
commit
9ebe11f1d9
|
@ -19,19 +19,47 @@ package org.apache.solr.client.solrj.io.eval;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.UncheckedIOException;
|
import java.io.UncheckedIOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Locale;
|
import java.util.Locale;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
import org.apache.solr.client.solrj.io.Tuple;
|
import org.apache.solr.client.solrj.io.Tuple;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.TupleStream;
|
||||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
|
||||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
|
||||||
public class MonteCarloEvaluator extends RecursiveEvaluator {
|
public class MonteCarloEvaluator extends RecursiveEvaluator {
|
||||||
protected static final long serialVersionUID = 1L;
|
protected static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
private Map variables = new LinkedHashMap();
|
||||||
|
|
||||||
public MonteCarloEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
public MonteCarloEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||||
super(expression, factory);
|
super(expression, factory);
|
||||||
|
|
||||||
|
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
|
||||||
|
//Get all the named params
|
||||||
|
Set<String> echo = null;
|
||||||
|
boolean echoAll = false;
|
||||||
|
String currentName = null;
|
||||||
|
for(StreamExpressionParameter np : namedParams) {
|
||||||
|
String name = ((StreamExpressionNamedParameter)np).getName();
|
||||||
|
currentName = name;
|
||||||
|
|
||||||
|
|
||||||
|
StreamExpressionParameter param = ((StreamExpressionNamedParameter)np).getParameter();
|
||||||
|
if(factory.isEvaluator((StreamExpression)param)) {
|
||||||
|
StreamEvaluator evaluator = factory.constructEvaluator((StreamExpression) param);
|
||||||
|
variables.put(name, evaluator);
|
||||||
|
} else {
|
||||||
|
TupleStream tupleStream = factory.constructStream((StreamExpression) param);
|
||||||
|
variables.put(name, tupleStream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
init();
|
init();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,7 +77,7 @@ public class MonteCarloEvaluator extends RecursiveEvaluator {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Object evaluate(Tuple tuple) throws IOException {
|
public Object evaluate(Tuple tuple) throws IOException {
|
||||||
try{
|
try {
|
||||||
|
|
||||||
StreamEvaluator function = containedEvaluators.get(0);
|
StreamEvaluator function = containedEvaluators.get(0);
|
||||||
StreamEvaluator iterationsEvaluator = containedEvaluators.get(1);
|
StreamEvaluator iterationsEvaluator = containedEvaluators.get(1);
|
||||||
|
@ -57,6 +85,7 @@ public class MonteCarloEvaluator extends RecursiveEvaluator {
|
||||||
int it = itNum.intValue();
|
int it = itNum.intValue();
|
||||||
List<Number> results = new ArrayList();
|
List<Number> results = new ArrayList();
|
||||||
for(int i=0; i<it; i++) {
|
for(int i=0; i<it; i++) {
|
||||||
|
populateVariables(tuple);
|
||||||
Number result = (Number)function.evaluate(tuple);
|
Number result = (Number)function.evaluate(tuple);
|
||||||
results.add(result);
|
results.add(result);
|
||||||
}
|
}
|
||||||
|
@ -73,4 +102,39 @@ public class MonteCarloEvaluator extends RecursiveEvaluator {
|
||||||
// Nothing to do here
|
// Nothing to do here
|
||||||
throw new IOException("This call should never occur");
|
throw new IOException("This call should never occur");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private void populateVariables(Tuple contextTuple) throws IOException {
|
||||||
|
|
||||||
|
Set<Map.Entry<String, Object>> entries = variables.entrySet();
|
||||||
|
|
||||||
|
for(Map.Entry<String, Object> entry : entries) {
|
||||||
|
String name = entry.getKey();
|
||||||
|
Object o = entry.getValue();
|
||||||
|
if(o instanceof TupleStream) {
|
||||||
|
List<Tuple> tuples = new ArrayList();
|
||||||
|
TupleStream tStream = (TupleStream)o;
|
||||||
|
tStream.setStreamContext(streamContext);
|
||||||
|
try {
|
||||||
|
tStream.open();
|
||||||
|
TUPLES:
|
||||||
|
while(true) {
|
||||||
|
Tuple tuple = tStream.read();
|
||||||
|
if (tuple.EOF) {
|
||||||
|
break TUPLES;
|
||||||
|
} else {
|
||||||
|
tuples.add(tuple);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
contextTuple.put(name, tuples);
|
||||||
|
} finally {
|
||||||
|
tStream.close();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
StreamEvaluator evaluator = (StreamEvaluator)o;
|
||||||
|
Object eo = evaluator.evaluate(contextTuple);
|
||||||
|
contextTuple.put(name, eo);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3240,6 +3240,37 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
||||||
assertEquals(out.get(9).doubleValue(), 30.0, .0);
|
assertEquals(out.get(9).doubleValue(), 30.0, .0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMonteCarloWithVariables() throws Exception {
|
||||||
|
String cexpr = "let(a=constantDistribution(10), " +
|
||||||
|
" b=constantDistribution(20), " +
|
||||||
|
" c=monteCarlo(d=sample(a),"+
|
||||||
|
" e=sample(b),"+
|
||||||
|
" add(d, add(e, 10)), " +
|
||||||
|
" 10))";
|
||||||
|
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||||
|
paramsLoc.set("expr", cexpr);
|
||||||
|
paramsLoc.set("qt", "/stream");
|
||||||
|
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||||
|
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||||
|
StreamContext context = new StreamContext();
|
||||||
|
solrStream.setStreamContext(context);
|
||||||
|
List<Tuple> tuples = getTuples(solrStream);
|
||||||
|
assertTrue(tuples.size() == 1);
|
||||||
|
List<Number> out = (List<Number>)tuples.get(0).get("c");
|
||||||
|
assertTrue(out.size()==10);
|
||||||
|
assertEquals(out.get(0).doubleValue(), 40.0, .0);
|
||||||
|
assertEquals(out.get(1).doubleValue(), 40.0, .0);
|
||||||
|
assertEquals(out.get(2).doubleValue(), 40.0, .0);
|
||||||
|
assertEquals(out.get(3).doubleValue(), 40.0, .0);
|
||||||
|
assertEquals(out.get(4).doubleValue(), 40.0, .0);
|
||||||
|
assertEquals(out.get(5).doubleValue(), 40.0, .0);
|
||||||
|
assertEquals(out.get(6).doubleValue(), 40.0, .0);
|
||||||
|
assertEquals(out.get(7).doubleValue(), 40.0, .0);
|
||||||
|
assertEquals(out.get(8).doubleValue(), 40.0, .0);
|
||||||
|
assertEquals(out.get(9).doubleValue(), 40.0, .0);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWeibullDistribution() throws Exception {
|
public void testWeibullDistribution() throws Exception {
|
||||||
String cexpr = "let(echo=true, " +
|
String cexpr = "let(echo=true, " +
|
||||||
|
|
Loading…
Reference in New Issue