diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java index 8a83160d30d..949a040bade 100644 --- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java +++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java @@ -283,11 +283,11 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, .withFunctionName("spline", SplineEvaluator.class) .withFunctionName("ttest", TTestEvaluator.class) .withFunctionName("pairedTtest", PairedTTestEvaluator.class) - + .withFunctionName("multiVariateNormalDistribution", MultiVariateNormalDistributionEvaluator.class) // Boolean Stream Evaluators - .withFunctionName("and", AndEvaluator.class) + .withFunctionName("and", AndEvaluator.class) .withFunctionName("eor", ExclusiveOrEvaluator.class) .withFunctionName("eq", EqualToEvaluator.class) .withFunctionName("gt", GreaterThanEvaluator.class) diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MultiVariateNormalDistributionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MultiVariateNormalDistributionEvaluator.java new file mode 100644 index 00000000000..bc2fbcb80a5 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MultiVariateNormalDistributionEvaluator.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.client.solrj.io.eval;
+
+import java.io.IOException;
+import java.util.Locale;
+import java.util.List;
+
+import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+/**
+ * Two-argument evaluator that builds a commons-math
+ * {@link MultivariateNormalDistribution} from a list of means (first value)
+ * and a covariance {@link Matrix} (second value).
+ */
+public class MultiVariateNormalDistributionEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
+
+  private static final long serialVersionUID = 1;
+
+  public MultiVariateNormalDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
+    super(expression, factory);
+  }
+
+  /**
+   * Creates the distribution from the two evaluated operands.
+   *
+   * @param first  the means; must be a non-null {@link List} whose elements are all {@link Number}s
+   * @param second the covariances; must be a non-null {@link Matrix}
+   * @return a new {@link MultivariateNormalDistribution}
+   * @throws IOException if either operand is null or of the wrong type, or if a mean is not numeric
+   */
+  @Override
+  public Object doWork(Object first, Object second) throws IOException {
+    if (null == first) {
+      throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the first value", toExpression(constructingFactory)));
+    }
+    if (null == second) {
+      throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the second value", toExpression(constructingFactory)));
+    }
+    // Report wrong operand types as an invalid expression rather than letting a ClassCastException escape.
+    if (!(first instanceof List<?>)) {
+      throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a list of numbers", toExpression(constructingFactory), first.getClass().getSimpleName()));
+    }
+    if (!(second instanceof Matrix)) {
+      throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the second value, expecting a matrix", toExpression(constructingFactory), second.getClass().getSimpleName()));
+    }
+
+    List<?> means = (List<?>) first;
+    Matrix covar = (Matrix) second;
+
+    // Unbox the means into the primitive array the distribution constructor takes.
+    double[] m = new double[means.size()];
+    for (int i = 0; i < m.length; i++) {
+      Object mean = means.get(i);
+      // instanceof is also false for null, so null elements are rejected here as well.
+      if (!(mean instanceof Number)) {
+        throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found a non-numeric value at index %d of the first value", toExpression(constructingFactory), i));
+      }
+      m[i] = ((Number) mean).doubleValue();
+    }
+
+    // NOTE(review): per the commons-math3 Javadoc this constructor throws unchecked exceptions
+    // for a dimension mismatch or a singular / non-positive-definite covariance matrix;
+    // they are not translated here.
+    return new MultivariateNormalDistribution(m, covar.getData());
+  }
+}
\ No newline at end of file
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SampleEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SampleEvaluator.java
index
9b7aca50604..5ea29e6d41e 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SampleEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SampleEvaluator.java @@ -18,12 +18,16 @@ package org.apache.solr.client.solrj.io.eval; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Locale; import java.util.stream.Collectors; +import java.util.List; import org.apache.commons.math3.distribution.IntegerDistribution; +import org.apache.commons.math3.distribution.MultivariateRealDistribution; import org.apache.commons.math3.distribution.RealDistribution; +import org.apache.commons.math3.distribution.MultivariateNormalDistribution; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; @@ -43,7 +47,7 @@ public class SampleEvaluator extends RecursiveObjectEvaluator implements ManyVal Object first = objects[0]; - if(!(first instanceof RealDistribution) && !(first instanceof IntegerDistribution) && !(first instanceof MarkovChainEvaluator.MarkovChain)){ + if(!(first instanceof MultivariateRealDistribution) && !(first instanceof RealDistribution) && !(first instanceof IntegerDistribution) && !(first instanceof MarkovChainEvaluator.MarkovChain)){ throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a Markov Chain, Real or Integer Distribution",toExpression(constructingFactory), first.getClass().getSimpleName())); } @@ -61,11 +65,30 @@ public class SampleEvaluator extends RecursiveObjectEvaluator implements ManyVal } } else if (first instanceof RealDistribution) { RealDistribution realDistribution = (RealDistribution) first; - if(second != null) { + if (second != null) { return Arrays.stream(realDistribution.sample(((Number) second).intValue())).mapToObj(item -> item).collect(Collectors.toList()); } else { return realDistribution.sample(); } + }else 
if(first instanceof MultivariateNormalDistribution) { + if(second != null) { + MultivariateNormalDistribution multivariateNormalDistribution = (MultivariateNormalDistribution)first; + int size = ((Number)second).intValue(); + double[][] samples = new double[size][]; + for(int i=0; i sampleList = new ArrayList(sample.length); + for(int i=0; i tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + List> cov = (List>)tuples.get(0).get("h"); + assertEquals(cov.size(), 2); + List row1 = cov.get(0); + assertEquals(row1.size(), 2); + + double a = row1.get(0).doubleValue(); + double b = row1.get(1).doubleValue(); + assertEquals(a, 4.666666666666667, 2.5); + assertEquals(b, 56.66666666666667, 7); + + List row2 = cov.get(1); + + double c = row2.get(0).doubleValue(); + double d = row2.get(1).doubleValue(); + assertEquals(c, 56.66666666666667, 7); + assertEquals(d, 723.8095238095239, 50); + + List sample = (List)tuples.get(0).get("i"); + assertEquals(sample.size(), 2); + Number sample1 = sample.get(0); + Number sample2 = sample.get(1); + assertTrue(sample1.doubleValue() > -30 && sample1.doubleValue() < 30); + assertTrue(sample2.doubleValue() > 50 && sample2.doubleValue() < 250); + } + @Test public void testLoess() throws Exception {