mirror of https://github.com/apache/lucene.git
SOLR-11212: Allow the predict StreamEvaluator to work on arrays as well as a single numeric parameter
This commit is contained in:
parent
2a8930cf83
commit
ea85543ace
|
@ -19,6 +19,8 @@ package org.apache.solr.client.solrj.io.eval;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Locale;
|
import java.util.Locale;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
|
||||||
import org.apache.solr.client.solrj.io.Tuple;
|
import org.apache.solr.client.solrj.io.Tuple;
|
||||||
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
|
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
|
||||||
|
@ -38,17 +40,27 @@ public class PredictEvaluator extends ComplexEvaluator implements Expressible {
|
||||||
if(2 != subEvaluators.size()){
|
if(2 != subEvaluators.size()){
|
||||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting two values (regression result and a number) but found %d",expression,subEvaluators.size()));
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting two values (regression result and a number) but found %d",expression,subEvaluators.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Number evaluate(Tuple tuple) throws IOException {
|
public Object evaluate(Tuple tuple) throws IOException {
|
||||||
|
|
||||||
StreamEvaluator r = subEvaluators.get(0);
|
StreamEvaluator r = subEvaluators.get(0);
|
||||||
StreamEvaluator d = subEvaluators.get(1);
|
StreamEvaluator d = subEvaluators.get(1);
|
||||||
|
|
||||||
RegressionEvaluator.RegressionTuple rt= (RegressionEvaluator.RegressionTuple)r.evaluate(tuple);
|
RegressionEvaluator.RegressionTuple rt= (RegressionEvaluator.RegressionTuple)r.evaluate(tuple);
|
||||||
Number n = (Number)d.evaluate(tuple);
|
|
||||||
return rt.predict(n.doubleValue());
|
Object o = d.evaluate(tuple);
|
||||||
|
if(o instanceof Number) {
|
||||||
|
Number n = (Number)o;
|
||||||
|
return rt.predict(n.doubleValue());
|
||||||
|
} else {
|
||||||
|
List<Number> list = (List<Number>)o;
|
||||||
|
List<Number> predications = new ArrayList();
|
||||||
|
for(Number n : list) {
|
||||||
|
predications.add(rt.predict(n.doubleValue()));
|
||||||
|
}
|
||||||
|
return predications;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -6270,7 +6270,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
||||||
String expr1 = "search("+COLLECTIONORALIAS+", q=\"col_s:a\", fl=\"price_f, order_i\", sort=\"order_i asc\")";
|
String expr1 = "search("+COLLECTIONORALIAS+", q=\"col_s:a\", fl=\"price_f, order_i\", sort=\"order_i asc\")";
|
||||||
String expr2 = "search("+COLLECTIONORALIAS+", q=\"col_s:b\", fl=\"price_f, order_i\", sort=\"order_i asc\")";
|
String expr2 = "search("+COLLECTIONORALIAS+", q=\"col_s:b\", fl=\"price_f, order_i\", sort=\"order_i asc\")";
|
||||||
|
|
||||||
String cexpr = "let(a="+expr1+", b="+expr2+", c=col(a, price_f), d=col(b, price_f), e=regress(c, d), tuple(regress=e, p=predict(e, 300)))";
|
String cexpr = "let(a="+expr1+", b="+expr2+", c=col(a, price_f), d=col(b, price_f), e=regress(c, d), tuple(regress=e, p=predict(e, 300), pl=predict(e, c)))";
|
||||||
|
|
||||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||||
paramsLoc.set("expr", cexpr);
|
paramsLoc.set("expr", cexpr);
|
||||||
|
@ -6293,6 +6293,8 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
||||||
assertTrue(rSquare == 1.0D);
|
assertTrue(rSquare == 1.0D);
|
||||||
double prediction = tuple.getDouble("p");
|
double prediction = tuple.getDouble("p");
|
||||||
assertTrue(prediction == 600.0D);
|
assertTrue(prediction == 600.0D);
|
||||||
|
List<Number> predictions = (List<Number>)tuple.get("pl");
|
||||||
|
assertList(predictions, 200.0, 400.0, 600.0, 200.0, 400.0, 800.0, 1200.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue