SOLR-11212: Allow the predict StreamEvaluator to work on arrays as well as a single numeric parameter

This commit is contained in:
Joel Bernstein 2017-08-08 14:53:45 -04:00
parent 2a8930cf83
commit ea85543ace
2 changed files with 19 additions and 5 deletions

View File

@ -19,6 +19,8 @@ package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import java.util.List;
import java.util.ArrayList;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
@ -38,17 +40,27 @@ public class PredictEvaluator extends ComplexEvaluator implements Expressible {
if(2 != subEvaluators.size()){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting two values (regression result and a number) but found %d",expression,subEvaluators.size()));
}
}
public Number evaluate(Tuple tuple) throws IOException {
public Object evaluate(Tuple tuple) throws IOException {
StreamEvaluator r = subEvaluators.get(0);
StreamEvaluator d = subEvaluators.get(1);
RegressionEvaluator.RegressionTuple rt= (RegressionEvaluator.RegressionTuple)r.evaluate(tuple);
Number n = (Number)d.evaluate(tuple);
return rt.predict(n.doubleValue());
Object o = d.evaluate(tuple);
if(o instanceof Number) {
Number n = (Number)o;
return rt.predict(n.doubleValue());
} else {
List<Number> list = (List<Number>)o;
List<Number> predications = new ArrayList();
for(Number n : list) {
predications.add(rt.predict(n.doubleValue()));
}
return predications;
}
}
@Override

View File

@ -6270,7 +6270,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
String expr1 = "search("+COLLECTIONORALIAS+", q=\"col_s:a\", fl=\"price_f, order_i\", sort=\"order_i asc\")";
String expr2 = "search("+COLLECTIONORALIAS+", q=\"col_s:b\", fl=\"price_f, order_i\", sort=\"order_i asc\")";
String cexpr = "let(a="+expr1+", b="+expr2+", c=col(a, price_f), d=col(b, price_f), e=regress(c, d), tuple(regress=e, p=predict(e, 300)))";
String cexpr = "let(a="+expr1+", b="+expr2+", c=col(a, price_f), d=col(b, price_f), e=regress(c, d), tuple(regress=e, p=predict(e, 300), pl=predict(e, c)))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
@ -6293,6 +6293,8 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertTrue(rSquare == 1.0D);
double prediction = tuple.getDouble("p");
assertTrue(prediction == 600.0D);
List<Number> predictions = (List<Number>)tuple.get("pl");
assertList(predictions, 200.0, 400.0, 600.0, 200.0, 400.0, 800.0, 1200.0);
}