SOLR-11863: Add knnRegress Stream Evaluator to support nearest neighbor regression

This commit is contained in:
Joel Bernstein 2018-08-08 21:05:02 -04:00
parent e9f3a3ce1d
commit cb1db48252
7 changed files with 311 additions and 7 deletions

View File

@ -246,6 +246,7 @@ public class Lang {
.withFunctionName("zeros", ZerosEvaluator.class) .withFunctionName("zeros", ZerosEvaluator.class)
.withFunctionName("getValue", GetValueEvaluator.class) .withFunctionName("getValue", GetValueEvaluator.class)
.withFunctionName("setValue", SetValueEvaluator.class) .withFunctionName("setValue", SetValueEvaluator.class)
.withFunctionName("knnRegress", KnnRegressionEvaluator.class)
// Boolean Stream Evaluators // Boolean Stream Evaluators

View File

@ -67,8 +67,6 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
throw new IOException("The third parameter for knn should be k."); throw new IOException("The third parameter for knn should be k.");
} }
double[][] data = matrix.getData();
DistanceMeasure distanceMeasure = null; DistanceMeasure distanceMeasure = null;
if(values.length == 4) { if(values.length == 4) {
@ -77,6 +75,15 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
distanceMeasure = new EuclideanDistance(); distanceMeasure = new EuclideanDistance();
} }
return search(matrix, vec, k, distanceMeasure);
}
public static Matrix search(Matrix observations,
double[] vec,
int k,
DistanceMeasure distanceMeasure) {
double[][] data = observations.getData();
TreeSet<Neighbor> neighbors = new TreeSet(); TreeSet<Neighbor> neighbors = new TreeSet();
for(int i=0; i<data.length; i++) { for(int i=0; i<data.length; i++) {
double distance = distanceMeasure.compute(vec, data[i]); double distance = distanceMeasure.compute(vec, data[i]);
@ -87,8 +94,9 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
} }
double[][] out = new double[neighbors.size()][]; double[][] out = new double[neighbors.size()][];
List<String> rowLabels = matrix.getRowLabels(); List<String> rowLabels = observations.getRowLabels();
List<String> newRowLabels = new ArrayList(); List<String> newRowLabels = new ArrayList();
List<Number> indexes = new ArrayList();
List<Number> distances = new ArrayList(); List<Number> distances = new ArrayList();
int i=-1; int i=-1;
@ -102,6 +110,7 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
out[++i] = data[rowIndex]; out[++i] = data[rowIndex];
distances.add(neighbor.getDistance()); distances.add(neighbor.getDistance());
indexes.add(rowIndex);
} }
Matrix knn = new Matrix(out); Matrix knn = new Matrix(out);
@ -110,8 +119,9 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
knn.setRowLabels(newRowLabels); knn.setRowLabels(newRowLabels);
} }
knn.setColumnLabels(matrix.getColumnLabels()); knn.setColumnLabels(observations.getColumnLabels());
knn.setAttribute("distances", distances); knn.setAttribute("distances", distances);
knn.setAttribute("indexes", indexes);
return knn; return knn;
} }

View File

@ -0,0 +1,194 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
/**
 * Stream evaluator that builds a k-nearest-neighbor regression model.
 *
 * <p>Parameters, in order: an observation {@link Matrix}, a list of numeric outcomes,
 * k, and optionally a {@link DistanceMeasure} (defaults to {@link EuclideanDistance}).
 * The result is a {@link KnnRegressionTuple} that can scale predictors and produce
 * predictions from the outcomes of the nearest observations.
 */
public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
  protected static final long serialVersionUID = 1L;

  public KnnRegressionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
    super(expression, factory);
  }

  /**
   * Validates the parameters and packages them into a {@link KnnRegressionTuple}.
   *
   * @param values observation matrix, outcome list, k and an optional distance measure
   * @return a {@link KnnRegressionTuple} whose fields report k, the observation count,
   *         the feature count and the simple class name of the distance measure
   * @throws IOException if fewer than three parameters are supplied or a parameter has the wrong type
   */
  @Override
  public Object doWork(Object ... values) throws IOException {

    if(values.length < 3) {
      throw new IOException("knnRegress expects atleast three parameters: an observation matrix, an outcomes vector and k.");
    }

    if(!(values[0] instanceof Matrix)) {
      throw new IOException("The first parameter for knnRegress should be the observation matrix.");
    }
    Matrix observations = (Matrix) values[0];

    if(!(values[1] instanceof List)) {
      throw new IOException("The second parameter for knnRegress should be outcome array. ");
    }
    List<?> outcomes = (List<?>) values[1];

    if(!(values[2] instanceof Number)) {
      throw new IOException("The third parameter for knnRegress should be k. ");
    }
    int k = ((Number) values[2]).intValue();

    DistanceMeasure distanceMeasure = new EuclideanDistance();
    if(values.length == 4) {
      if(values[3] instanceof DistanceMeasure) {
        distanceMeasure = (DistanceMeasure) values[3];
      } else {
        // Only reject the fourth parameter when it is NOT a distance measure.
        throw new IOException("The fourth parameter for knnRegress should be a distance measure. ");
      }
    }

    double[] outcomeData = new double[outcomes.size()];
    for(int i=0; i<outcomeData.length; i++) {
      Object outcome = outcomes.get(i);
      if(!(outcome instanceof Number)) {
        // Report a non-numeric outcome as a parameter error rather than a cast failure.
        throw new IOException("The second parameter for knnRegress should be outcome array. ");
      }
      outcomeData[i] = ((Number) outcome).doubleValue();
    }

    Map<String, Object> map = new HashMap<>();
    map.put("k", k);
    map.put("observations", observations.getRowCount());
    map.put("features", observations.getColumnCount());
    map.put("distance", distanceMeasure.getClass().getSimpleName());

    return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map);
  }

  /**
   * Tuple holding the training data of a knn regression together with the
   * scaling and prediction logic that operates on it.
   *
   * <p>The {@code scale} methods store the scaled observations on this instance;
   * {@link #predict(double[])} then searches that most recently scaled copy.
   */
  public static class KnnRegressionTuple extends Tuple {

    private final Matrix observations;
    // Set by the scale methods; null until one of them has been called.
    private Matrix scaledObservations;
    private final double[] outcomes;
    private final int k;
    private final DistanceMeasure distanceMeasure;

    public KnnRegressionTuple(Matrix observations,
                              double[] outcomes,
                              int k,
                              DistanceMeasure distanceMeasure,
                              Map<?,?> map) {
      super(map);
      this.observations = observations;
      this.outcomes = outcomes;
      this.k = k;
      this.distanceMeasure = distanceMeasure;
    }

    /**
     * MinMax scales both the observations and a single predictor vector.
     *
     * <p>Each feature column of the observations is scaled together with the matching
     * predictor value so both share the same range. The scaled observations are kept
     * on this tuple for later use by {@link #predict(double[])}.
     *
     * @param predictors one value per feature column of the observation matrix
     * @return the scaled predictor values
     */
    public double[] scale(double[] predictors) {
      double[][] data = observations.getData();

      // Transpose so that every row holds all observed values of one feature.
      Array2DRowRealMatrix matrix = new Array2DRowRealMatrix(data);
      Array2DRowRealMatrix transposed = (Array2DRowRealMatrix) matrix.transpose();
      double[][] featureRows = transposed.getDataRef();

      double[] scaledPredictors = new double[predictors.length];

      for(int i=0; i<featureRows.length; i++) {
        double[] featureRow = featureRows[i];

        // Append the predictor's value for this feature so it is scaled with the observations.
        double[] combinedFeatureRow = new double[featureRow.length+1];
        System.arraycopy(featureRow, 0, combinedFeatureRow, 0, featureRow.length);
        combinedFeatureRow[featureRow.length] = predictors[i];

        double[] scaledFeatures = MinMaxScaleEvaluator.scale(combinedFeatureRow, 0, 1);

        // The last slot is the scaled predictor; the rest are the scaled observations.
        scaledPredictors[i] = scaledFeatures[featureRow.length];
        System.arraycopy(scaledFeatures, 0, featureRow, 0, featureRow.length);
      }

      // Transpose back to one row per observation.
      Array2DRowRealMatrix scaledFeatureMatrix = new Array2DRowRealMatrix(featureRows);
      Array2DRowRealMatrix scaledObservationsMatrix = (Array2DRowRealMatrix) scaledFeatureMatrix.transpose();
      this.scaledObservations = new Matrix(scaledObservationsMatrix.getDataRef());

      return scaledPredictors;
    }

    /**
     * MinMax scales both the observations and a matrix of predictor rows.
     *
     * <p>Each feature column of the observations is scaled together with the same
     * column of the predictors. The scaled observations are kept on this tuple for
     * later use by {@link #predict(double[])}.
     *
     * @param predictors matrix with one predictor vector per row
     * @return a new matrix holding the scaled predictors, one vector per row
     */
    public Matrix scale(Matrix predictors) {
      // Transpose both matrices so that every row holds the values of one feature.
      double[][] observationData = observations.getData();
      Array2DRowRealMatrix observationMatrix = new Array2DRowRealMatrix(observationData);
      Array2DRowRealMatrix observationTransposed = (Array2DRowRealMatrix) observationMatrix.transpose();
      double[][] observationFeatureRows = observationTransposed.getDataRef();

      double[][] predictorsData = predictors.getData();
      Array2DRowRealMatrix predictorMatrix = new Array2DRowRealMatrix(predictorsData);
      Array2DRowRealMatrix predictorTransposed = (Array2DRowRealMatrix) predictorMatrix.transpose();
      double[][] predictorFeatureRows = predictorTransposed.getDataRef();

      for(int i=0; i<observationFeatureRows.length; i++) {
        double[] observationFeatureRow = observationFeatureRows[i];
        double[] predictorFeatureRow = predictorFeatureRows[i];

        // Observations first, predictors second, scaled as one combined feature row.
        double[] combinedFeatureRow = new double[observationFeatureRow.length+predictorFeatureRow.length];
        System.arraycopy(observationFeatureRow, 0, combinedFeatureRow, 0, observationFeatureRow.length);
        System.arraycopy(predictorFeatureRow, 0, combinedFeatureRow, observationFeatureRow.length, predictorFeatureRow.length);

        double[] scaledFeatures = MinMaxScaleEvaluator.scale(combinedFeatureRow, 0, 1);

        // Split the scaled values back into their observation and predictor parts.
        System.arraycopy(scaledFeatures, 0, observationFeatureRow, 0, observationFeatureRow.length);
        System.arraycopy(scaledFeatures, observationFeatureRow.length, predictorFeatureRow, 0, predictorFeatureRow.length);
      }

      // Transpose back to one row per observation / predictor vector.
      Array2DRowRealMatrix scaledFeatureMatrix = new Array2DRowRealMatrix(observationFeatureRows);
      Array2DRowRealMatrix scaledObservationsMatrix = (Array2DRowRealMatrix) scaledFeatureMatrix.transpose();
      this.scaledObservations = new Matrix(scaledObservationsMatrix.getDataRef());

      Array2DRowRealMatrix scaledPredictorMatrix = new Array2DRowRealMatrix(predictorFeatureRows);
      Array2DRowRealMatrix scaledTransposedPredictorMatrix = (Array2DRowRealMatrix) scaledPredictorMatrix.transpose();
      return new Matrix(scaledTransposedPredictorMatrix.getDataRef());
    }

    /**
     * Predicts an outcome as the average outcome of the nearest neighbors of {@code values}.
     *
     * <p>The search runs against the observations stored by the last {@code scale} call.
     * If no {@code scale} call has been made yet, the unscaled observations are searched
     * instead of failing on a null matrix.
     *
     * @param values predictor vector, expected to be in the same scale as the searched observations
     * @return the mean outcome of the neighbors returned by the search (NaN if none are returned)
     */
    public double predict(double[] values) {
      Matrix searchSpace = scaledObservations != null ? scaledObservations : observations;
      Matrix knn = KnnEvaluator.search(searchSpace, values, k, distanceMeasure);

      @SuppressWarnings("unchecked")
      List<Number> indexes = (List<Number>)knn.getAttribute("indexes");

      // Collect the outcomes for the nearest neighbors.
      double sum = 0;
      for(Number n : indexes) {
        sum += outcomes[n.intValue()];
      }

      // Return the average of the outcomes as the prediction.
      return sum/((double)indexes.size());
    }
  }
}

View File

@ -76,7 +76,7 @@ public class MinMaxScaleEvaluator extends RecursiveObjectEvaluator implements Ma
} }
} }
private double[] scale(double[] values, double min, double max) { public static double[] scale(double[] values, double min, double max) {
double localMin = Double.MAX_VALUE; double localMin = Double.MAX_VALUE;
double localMax = Double.MIN_VALUE; double localMax = Double.MIN_VALUE;

View File

@ -43,7 +43,11 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
Object first = objects[0]; Object first = objects[0];
Object second = objects[1]; Object second = objects[1];
if (!(first instanceof BivariateFunction) && !(first instanceof VectorFunction) && !(first instanceof RegressionEvaluator.RegressionTuple) && !(first instanceof OLSRegressionEvaluator.MultipleRegressionTuple)) { if (!(first instanceof BivariateFunction) &&
!(first instanceof VectorFunction) &&
!(first instanceof RegressionEvaluator.RegressionTuple) &&
!(first instanceof OLSRegressionEvaluator.MultipleRegressionTuple) &&
!(first instanceof KnnRegressionEvaluator.KnnRegressionTuple)) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a RegressionTuple", toExpression(constructingFactory), first.getClass().getSimpleName())); throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a RegressionTuple", toExpression(constructingFactory), first.getClass().getSimpleName()));
} }
@ -83,6 +87,30 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
return predictions; return predictions;
} }
} else if (first instanceof KnnRegressionEvaluator.KnnRegressionTuple) {
KnnRegressionEvaluator.KnnRegressionTuple regressedTuple = (KnnRegressionEvaluator.KnnRegressionTuple) first;
if (second instanceof List) {
List<Number> list = (List<Number>) second;
double[] predictors = new double[list.size()];
for (int i = 0; i < list.size(); i++) {
predictors[i] = list.get(i).doubleValue();
}
predictors = regressedTuple.scale(predictors);
return regressedTuple.predict(predictors);
} else if (second instanceof Matrix) {
Matrix m = (Matrix) second;
m = regressedTuple.scale(m);
double[][] data = m.getData();
List<Number> predictions = new ArrayList();
for (double[] predictors : data) {
predictions.add(regressedTuple.predict(predictors));
}
return predictions;
}
} else if (first instanceof VectorFunction) { } else if (first instanceof VectorFunction) {
VectorFunction vectorFunction = (VectorFunction) first; VectorFunction vectorFunction = (VectorFunction) first;
UnivariateFunction univariateFunction = (UnivariateFunction)vectorFunction.getFunction(); UnivariateFunction univariateFunction = (UnivariateFunction)vectorFunction.getFunction();

View File

@ -69,7 +69,7 @@ public class TestLang extends LuceneTestCase {
TemporalEvaluatorDayOfQuarter.FUNCTION_NAME, "abs", "add", "div", "mult", "sub", "log", "pow", TemporalEvaluatorDayOfQuarter.FUNCTION_NAME, "abs", "add", "div", "mult", "sub", "log", "pow",
"mod", "ceil", "floor", "sin", "asin", "sinh", "cos", "acos", "cosh", "tan", "atan", "tanh", "round", "sqrt", "mod", "ceil", "floor", "sin", "asin", "sinh", "cos", "acos", "cosh", "tan", "atan", "tanh", "round", "sqrt",
"cbrt", "coalesce", "uuid", "if", "convert", "valueAt", "memset", "fft", "ifft", "euclidean","manhattan", "cbrt", "coalesce", "uuid", "if", "convert", "valueAt", "memset", "fft", "ifft", "euclidean","manhattan",
"earthMovers", "canberra", "chebyshev", "ones", "zeros", "setValue", "getValue"}; "earthMovers", "canberra", "chebyshev", "ones", "zeros", "setValue", "getValue", "knnRegress"};
@Test @Test
public void testLang() { public void testLang() {

View File

@ -3394,6 +3394,77 @@ public class MathExpressionTest extends SolrCloudTestCase {
assertEquals(predictions.get(9).doubleValue(), 85.86401719768607, .0001); assertEquals(predictions.get(9).doubleValue(), 85.86401719768607, .0001);
} }
/**
 * Exercises knnRegress together with predict: a matrix of predictors with k=1,
 * a single predictor vector with k=1, and a single predictor vector with k=2.
 */
@Test
public void testKnnRegress() throws Exception {
  // Model definition shared by the first two expressions (k=1).
  String model = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 10.19999981), " +
      "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 7.699999809, 4.5)," +
      "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
      "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
      "e=transpose(matrix(a, b, c))," +
      "f=knnRegress(e, d, 1),";

  // Predict for the training matrix itself.
  List<Tuple> tuples = evaluateKnnExpression(model + "g=predict(f, e))");
  assertEquals(1, tuples.size());
  @SuppressWarnings("unchecked")
  List<Number> predictions = (List<Number>)tuples.get(0).get("g");
  assertEquals(10, predictions.size());

  //k=1 should bring back only one prediction for the exact match in the training set
  double[] expectedOutcomes = {85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924,
      30.29999924, 79.40000153, 91, 135.3999939, 89.30000305};
  for (int i = 0; i < expectedOutcomes.length; i++) {
    assertEquals("prediction " + i, expectedOutcomes[i], predictions.get(i).doubleValue(), 0);
  }

  // Predict for a single vector with k=1.
  tuples = evaluateKnnExpression(model + "g=predict(f, array(8, 5, 4)))");
  assertEquals(1, tuples.size());
  Number prediction = (Number)tuples.get(0).get("g");
  assertEquals(85.09999847, prediction.doubleValue(), 0);

  // Slightly different training set and k=2: the prediction is the average of two outcomes.
  String cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 8.19999981), " +
      "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 5.699999809, 4.5)," +
      "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 4.900000095)," +
      "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
      "e=transpose(matrix(a, b, c))," +
      "f=knnRegress(e, d, 2)," +
      "g=predict(f, array(8, 5, 4)))";
  tuples = evaluateKnnExpression(cexpr);
  assertEquals(1, tuples.size());
  prediction = (Number)tuples.get(0).get("g");
  assertEquals(87.20000076, prediction.doubleValue(), 0);
}

// Sends the expression to the /stream handler of the first Jetty node and returns the resulting tuples.
private List<Tuple> evaluateKnnExpression(String expr) throws Exception {
  ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
  paramsLoc.set("expr", expr);
  paramsLoc.set("qt", "/stream");
  String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
  TupleStream solrStream = new SolrStream(url, paramsLoc);
  StreamContext context = new StreamContext();
  solrStream.setStreamContext(context);
  return getTuples(solrStream);
}
@Test @Test
public void testPlot() throws Exception { public void testPlot() throws Exception {
String cexpr = "let(a=array(3,2,3), plot(type=scatter, x=a, y=array(5,6,3)))"; String cexpr = "let(a=array(3,2,3), plot(type=scatter, x=a, y=array(5,6,3)))";