diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CovarianceEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CovarianceEvaluator.java index 5a568a583e2..810ab330edd 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CovarianceEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CovarianceEvaluator.java @@ -19,13 +19,13 @@ package org.apache.solr.client.solrj.io.eval; import java.io.IOException; import java.math.BigDecimal; import java.util.List; -import java.util.Locale; +import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.stat.correlation.Covariance; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; -public class CovarianceEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker { +public class CovarianceEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker { protected static final long serialVersionUID = 1L; public CovarianceEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ @@ -33,25 +33,26 @@ public class CovarianceEvaluator extends RecursiveNumericEvaluator implements Tw } @Override - public Object doWork(Object first, Object second) throws IOException{ - if(null == first){ - throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory))); - } - if(null == second){ - throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory))); - } - if(!(first instanceof List)){ - throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName())); - } - if(!(second instanceof List)){ - throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName())); - } + public Object doWork(Object ... values) throws IOException{ - Covariance covariance = new Covariance(); - - return covariance.covariance( - ((List)first).stream().mapToDouble(value -> ((BigDecimal)value).doubleValue()).toArray(), - ((List)second).stream().mapToDouble(value -> ((BigDecimal)value).doubleValue()).toArray() - ); + if(values.length == 2) { + Object first = values[0]; + Object second = values[1]; + Covariance covariance = new Covariance(); + + return covariance.covariance( + ((List) first).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray(), + ((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray() + ); + } else if(values.length == 1) { + Matrix matrix = (Matrix) values[0]; + double[][] data = matrix.getData(); + Covariance covariance = new Covariance(data, true); + RealMatrix coMatrix = covariance.getCovarianceMatrix(); + double[][] coData = coMatrix.getData(); + return new Matrix(coData); + } else { + throw new IOException("The cov function expects either two numeric arrays or a matrix as parameters."); + } } } diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java index 955eafd10a1..9c3111626b7 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java @@ -7056,6 +7056,38 @@ public class StreamExpressionTest extends SolrCloudTestCase { assertEquals(maxa.doubleValue(), 30, .5); } + @Test + public void testCovMatrix() throws Exception { + String cexpr = "let(a=array(1,2,3), b=array(2,4,6), c=array(4, 8, 12), d=transpose(matrix(a, b, c)), f=cov(d))"; + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + List> cm = (List>)tuples.get(0).get("f"); + assertEquals(cm.size(), 3); + List row1 = cm.get(0); + assertEquals(row1.size(), 3); + assertEquals(row1.get(0).longValue(), 1); + assertEquals(row1.get(1).longValue(), 2); + assertEquals(row1.get(2).longValue(), 4); + + List row2 = cm.get(1); + assertEquals(row2.size(), 3); + assertEquals(row2.get(0).longValue(), 2); + assertEquals(row2.get(1).longValue(), 4); + assertEquals(row2.get(2).longValue(), 8); + + List row3 = cm.get(2); + assertEquals(row3.size(), 3); + assertEquals(row3.get(0).longValue(), 4); + assertEquals(row3.get(1).longValue(), 8); + assertEquals(row3.get(2).longValue(), 16); + }