diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java index af8a7f02006..ed9034a9b22 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java @@ -19,6 +19,8 @@ package org.apache.solr.client.solrj.io.eval; import java.io.IOException; import java.util.Locale; +import java.util.List; +import java.util.ArrayList; import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.stream.expr.Explanation; @@ -38,17 +40,27 @@ public class PredictEvaluator extends ComplexEvaluator implements Expressible { if(2 != subEvaluators.size()){ throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting two values (regression result and a number) but found %d",expression,subEvaluators.size())); } - } - public Number evaluate(Tuple tuple) throws IOException { + public Object evaluate(Tuple tuple) throws IOException { StreamEvaluator r = subEvaluators.get(0); StreamEvaluator d = subEvaluators.get(1); RegressionEvaluator.RegressionTuple rt= (RegressionEvaluator.RegressionTuple)r.evaluate(tuple); - Number n = (Number)d.evaluate(tuple); - return rt.predict(n.doubleValue()); + + Object o = d.evaluate(tuple); + if(o instanceof Number) { + Number n = (Number)o; + return rt.predict(n.doubleValue()); + } else { + List list = (List)o; + List predications = new ArrayList(); + for(Number n : list) { + predications.add(rt.predict(n.doubleValue())); + } + return predications; + } } @Override diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java index 699a5629d1d..93e52871d6a 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java @@ -6270,7 +6270,7 @@ public class StreamExpressionTest extends SolrCloudTestCase { String expr1 = "search("+COLLECTIONORALIAS+", q=\"col_s:a\", fl=\"price_f, order_i\", sort=\"order_i asc\")"; String expr2 = "search("+COLLECTIONORALIAS+", q=\"col_s:b\", fl=\"price_f, order_i\", sort=\"order_i asc\")"; - String cexpr = "let(a="+expr1+", b="+expr2+", c=col(a, price_f), d=col(b, price_f), e=regress(c, d), tuple(regress=e, p=predict(e, 300)))"; + String cexpr = "let(a="+expr1+", b="+expr2+", c=col(a, price_f), d=col(b, price_f), e=regress(c, d), tuple(regress=e, p=predict(e, 300), pl=predict(e, c)))"; ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); paramsLoc.set("expr", cexpr); @@ -6293,6 +6293,8 @@ public class StreamExpressionTest extends SolrCloudTestCase { assertTrue(rSquare == 1.0D); double prediction = tuple.getDouble("p"); assertTrue(prediction == 600.0D); + List predictions = (List)tuple.get("pl"); + assertList(predictions, 200.0, 400.0, 600.0, 200.0, 400.0, 800.0, 1200.0); }