From cb1db482523cf33b7927b5155d506202d8ddbd89 Mon Sep 17 00:00:00 2001 From: Joel Bernstein Date: Wed, 8 Aug 2018 21:05:02 -0400 Subject: [PATCH] SOLR-11863: Add knnRegress Stream Evaluator to support nearest neighbor regression --- .../org/apache/solr/client/solrj/io/Lang.java | 1 + .../client/solrj/io/eval/KnnEvaluator.java | 18 +- .../solrj/io/eval/KnnRegressionEvaluator.java | 194 ++++++++++++++++++ .../solrj/io/eval/MinMaxScaleEvaluator.java | 2 +- .../solrj/io/eval/PredictEvaluator.java | 30 ++- .../apache/solr/client/solrj/io/TestLang.java | 2 +- .../solrj/io/stream/MathExpressionTest.java | 71 +++++++ 7 files changed, 311 insertions(+), 7 deletions(-) create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java index a01a841e120..6f170c46340 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java @@ -246,6 +246,7 @@ public class Lang { .withFunctionName("zeros", ZerosEvaluator.class) .withFunctionName("getValue", GetValueEvaluator.class) .withFunctionName("setValue", SetValueEvaluator.class) + .withFunctionName("knnRegress", KnnRegressionEvaluator.class) // Boolean Stream Evaluators diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java index 251e092f7b3..81607cfdea0 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java @@ -67,8 +67,6 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW throw new IOException("The third parameter for knn should be k."); } - double[][] data = matrix.getData(); - DistanceMeasure distanceMeasure = null; 
if(values.length == 4) { @@ -77,6 +75,15 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW distanceMeasure = new EuclideanDistance(); } + return search(matrix, vec, k, distanceMeasure); + } + + public static Matrix search(Matrix observations, + double[] vec, + int k, + DistanceMeasure distanceMeasure) { + + double[][] data = observations.getData(); TreeSet neighbors = new TreeSet(); for(int i=0; i rowLabels = matrix.getRowLabels(); + List rowLabels = observations.getRowLabels(); List newRowLabels = new ArrayList(); + List indexes = new ArrayList(); List distances = new ArrayList(); int i=-1; @@ -102,6 +110,7 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW out[++i] = data[rowIndex]; distances.add(neighbor.getDistance()); + indexes.add(rowIndex); } Matrix knn = new Matrix(out); @@ -110,8 +119,9 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW knn.setRowLabels(newRowLabels); } - knn.setColumnLabels(matrix.getColumnLabels()); + knn.setColumnLabels(observations.getColumnLabels()); knn.setAttribute("distances", distances); + knn.setAttribute("indexes", indexes); return knn; } diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java new file mode 100644 index 00000000000..957936eefa8 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.HashMap; + +import org.apache.commons.math3.linear.Array2DRowRealMatrix; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker { + protected static final long serialVersionUID = 1L; + + public KnnRegressionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ + super(expression, factory); + } + + @Override + public Object doWork(Object ... values) throws IOException { + + if(values.length < 3) { + throw new IOException("knnRegress expects atleast three parameters: an observation matrix, an outcomes vector and k."); + } + + Matrix observations = null; + List outcomes = null; + int k = 5; + DistanceMeasure distanceMeasure = new EuclideanDistance(); + + if(values[0] instanceof Matrix) { + observations = (Matrix)values[0]; + } else { + throw new IOException("The first parameter for knnRegress should be the observation matrix."); + } + + if(values[1] instanceof List) { + outcomes = (List) values[1]; + } else { + throw new IOException("The second parameter for knnRegress should be outcome array. 
"); + } + + if(values[2] instanceof Number) { + k = ((Number) values[2]).intValue(); + } else { + throw new IOException("The third parameter for knnRegress should be k. "); + } + + if(values.length == 4) { + if(values[3] instanceof DistanceMeasure) { + distanceMeasure = (DistanceMeasure) values[3]; + } else { + throw new IOException("The fourth parameter for knnRegress should be a distance measure. "); + } + } + + double[] outcomeData = new double[outcomes.size()]; + for(int i=0; i map) { + super(map); + this.observations = observations; + this.outcomes = outcomes; + this.k = k; + this.distanceMeasure = distanceMeasure; + } + + //MinMax Scale both the observations and the predictors + + public double[] scale(double[] predictors) { + double[][] data = observations.getData(); + //We need to scale the columns of the data matrix along with the predictors + Array2DRowRealMatrix matrix = new Array2DRowRealMatrix(data); + Array2DRowRealMatrix transposed = (Array2DRowRealMatrix) matrix.transpose(); + double[][] featureRows = transposed.getDataRef(); + + double[] scaledPredictors = new double[predictors.length]; + + for(int i=0; i indexes = (List)knn.getAttribute("indexes"); + + double sum = 0; + + //Collect the outcomes for the nearest neighbors + for(Number n : indexes) { + sum += outcomes[n.intValue()]; + } + + //Return the average of the outcomes as the prediction. 
+ + return sum/((double)indexes.size()); + } + } +} + diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MinMaxScaleEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MinMaxScaleEvaluator.java index 60c63774fd5..399691022b8 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MinMaxScaleEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MinMaxScaleEvaluator.java @@ -76,7 +76,7 @@ public class MinMaxScaleEvaluator extends RecursiveObjectEvaluator implements Ma } } - private double[] scale(double[] values, double min, double max) { + public static double[] scale(double[] values, double min, double max) { double localMin = Double.MAX_VALUE; double localMax = Double.MIN_VALUE; diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java index 2444370455a..9385928f82b 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java @@ -43,7 +43,11 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa Object first = objects[0]; Object second = objects[1]; - if (!(first instanceof BivariateFunction) && !(first instanceof VectorFunction) && !(first instanceof RegressionEvaluator.RegressionTuple) && !(first instanceof OLSRegressionEvaluator.MultipleRegressionTuple)) { + if (!(first instanceof BivariateFunction) && + !(first instanceof VectorFunction) && + !(first instanceof RegressionEvaluator.RegressionTuple) && + !(first instanceof OLSRegressionEvaluator.MultipleRegressionTuple) && + !(first instanceof KnnRegressionEvaluator.KnnRegressionTuple)) { throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a RegressionTuple", toExpression(constructingFactory), 
first.getClass().getSimpleName())); } @@ -83,6 +87,30 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa return predictions; } + } else if (first instanceof KnnRegressionEvaluator.KnnRegressionTuple) { + KnnRegressionEvaluator.KnnRegressionTuple regressedTuple = (KnnRegressionEvaluator.KnnRegressionTuple) first; + if (second instanceof List) { + List list = (List) second; + double[] predictors = new double[list.size()]; + + for (int i = 0; i < list.size(); i++) { + predictors[i] = list.get(i).doubleValue(); + } + + predictors = regressedTuple.scale(predictors); + + return regressedTuple.predict(predictors); + } else if (second instanceof Matrix) { + + Matrix m = (Matrix) second; + m = regressedTuple.scale(m); + double[][] data = m.getData(); + List predictions = new ArrayList(); + for (double[] predictors : data) { + predictions.add(regressedTuple.predict(predictors)); + } + return predictions; + } } else if (first instanceof VectorFunction) { VectorFunction vectorFunction = (VectorFunction) first; UnivariateFunction univariateFunction = (UnivariateFunction)vectorFunction.getFunction(); diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java index 22b432f3d55..df56844e4dd 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java @@ -69,7 +69,7 @@ public class TestLang extends LuceneTestCase { TemporalEvaluatorDayOfQuarter.FUNCTION_NAME, "abs", "add", "div", "mult", "sub", "log", "pow", "mod", "ceil", "floor", "sin", "asin", "sinh", "cos", "acos", "cosh", "tan", "atan", "tanh", "round", "sqrt", "cbrt", "coalesce", "uuid", "if", "convert", "valueAt", "memset", "fft", "ifft", "euclidean","manhattan", - "earthMovers", "canberra", "chebyshev", "ones", "zeros", "setValue", "getValue"}; + "earthMovers", "canberra", "chebyshev", "ones", "zeros", "setValue", 
"getValue", "knnRegress"}; @Test public void testLang() { diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java index 98a52a6ee35..a9be57e6f20 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java @@ -3394,6 +3394,77 @@ public class MathExpressionTest extends SolrCloudTestCase { assertEquals(predictions.get(9).doubleValue(), 85.86401719768607, .0001); } + @Test + public void testKnnRegress() throws Exception { + String cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 10.19999981), " + + "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 7.699999809, 4.5)," + + "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," + + "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," + + "e=transpose(matrix(a, b, c))," + + "f=knnRegress(e, d, 1)," + + "g=predict(f, e))"; + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + List predictions = (List)tuples.get(0).get("g"); + assertEquals(predictions.size(), 10); + //k=1 should bring back only one prediction for the exact match in the training set + assertEquals(predictions.get(0).doubleValue(), 
85.09999847, 0); + assertEquals(predictions.get(1).doubleValue(), 106.3000031, 0); + assertEquals(predictions.get(2).doubleValue(), 50.20000076, 0); + assertEquals(predictions.get(3).doubleValue(), 130.6000061, 0); + assertEquals(predictions.get(4).doubleValue(), 54.79999924, 0); + assertEquals(predictions.get(5).doubleValue(), 30.29999924, 0); + assertEquals(predictions.get(6).doubleValue(), 79.40000153, 0); + assertEquals(predictions.get(7).doubleValue(), 91, 0); + assertEquals(predictions.get(8).doubleValue(), 135.3999939, 0); + assertEquals(predictions.get(9).doubleValue(), 89.30000305, 0); + + cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 10.19999981), " + + "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 7.699999809, 4.5)," + + "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," + + "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," + + "e=transpose(matrix(a, b, c))," + + "f=knnRegress(e, d, 1)," + + "g=predict(f, array(8, 5, 4)))"; + paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + solrStream = new SolrStream(url, paramsLoc); + context = new StreamContext(); + solrStream.setStreamContext(context); + tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + Number prediction = (Number)tuples.get(0).get("g"); + assertEquals(prediction.doubleValue(), 85.09999847, 0); + + cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 8.19999981), " + + "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 
3.700000048, 7.599999905, 5.699999809, 4.5)," + + "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 4.900000095)," + + "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," + + "e=transpose(matrix(a, b, c))," + + "f=knnRegress(e, d, 2)," + + "g=predict(f, array(8, 5, 4)))"; + paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + solrStream = new SolrStream(url, paramsLoc); + context = new StreamContext(); + solrStream.setStreamContext(context); + tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + prediction = (Number)tuples.get(0).get("g"); + assertEquals(prediction.doubleValue(), 87.20000076, 0); + } + @Test public void testPlot() throws Exception { String cexpr = "let(a=array(3,2,3), plot(type=scatter, x=a, y=array(5,6,3)))";