mirror of https://github.com/apache/lucene.git
SOLR-11212: Allow the predict StreamEvaluator to work on arrays as well as a single numeric parameter
This commit is contained in:
parent
2a8930cf83
commit
ea85543ace
|
@ -19,6 +19,8 @@ package org.apache.solr.client.solrj.io.eval;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Locale;
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
|
||||
|
@ -38,17 +40,27 @@ public class PredictEvaluator extends ComplexEvaluator implements Expressible {
|
|||
if(2 != subEvaluators.size()){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting two values (regression result and a number) but found %d",expression,subEvaluators.size()));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public Number evaluate(Tuple tuple) throws IOException {
|
||||
public Object evaluate(Tuple tuple) throws IOException {
|
||||
|
||||
StreamEvaluator r = subEvaluators.get(0);
|
||||
StreamEvaluator d = subEvaluators.get(1);
|
||||
|
||||
RegressionEvaluator.RegressionTuple rt= (RegressionEvaluator.RegressionTuple)r.evaluate(tuple);
|
||||
Number n = (Number)d.evaluate(tuple);
|
||||
return rt.predict(n.doubleValue());
|
||||
|
||||
Object o = d.evaluate(tuple);
|
||||
if(o instanceof Number) {
|
||||
Number n = (Number)o;
|
||||
return rt.predict(n.doubleValue());
|
||||
} else {
|
||||
List<Number> list = (List<Number>)o;
|
||||
List<Number> predications = new ArrayList();
|
||||
for(Number n : list) {
|
||||
predications.add(rt.predict(n.doubleValue()));
|
||||
}
|
||||
return predications;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -6270,7 +6270,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
String expr1 = "search("+COLLECTIONORALIAS+", q=\"col_s:a\", fl=\"price_f, order_i\", sort=\"order_i asc\")";
|
||||
String expr2 = "search("+COLLECTIONORALIAS+", q=\"col_s:b\", fl=\"price_f, order_i\", sort=\"order_i asc\")";
|
||||
|
||||
String cexpr = "let(a="+expr1+", b="+expr2+", c=col(a, price_f), d=col(b, price_f), e=regress(c, d), tuple(regress=e, p=predict(e, 300)))";
|
||||
String cexpr = "let(a="+expr1+", b="+expr2+", c=col(a, price_f), d=col(b, price_f), e=regress(c, d), tuple(regress=e, p=predict(e, 300), pl=predict(e, c)))";
|
||||
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
|
@ -6293,6 +6293,8 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
assertTrue(rSquare == 1.0D);
|
||||
double prediction = tuple.getDouble("p");
|
||||
assertTrue(prediction == 600.0D);
|
||||
List<Number> predictions = (List<Number>)tuple.get("pl");
|
||||
assertList(predictions, 200.0, 400.0, 600.0, 200.0, 400.0, 800.0, 1200.0);
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue