mirror of https://github.com/apache/lucene.git
SOLR-11863: Add knnRegress Stream Evaluator to support nearest neighbor regression
This commit is contained in:
parent e9f3a3ce1d
commit cb1db48252
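As a quick orientation, the new evaluator is used from a streaming expression roughly like this (a sketch modeled on the tests below; the arrays features1, features2, features3 and outcomes are illustrative):

    let(obs=transpose(matrix(features1, features2, features3)),
        model=knnRegress(obs, outcomes, 5),
        predictions=predict(model, obs))

knnRegress takes an observation matrix, an outcomes vector and k, plus an optional fourth parameter for a distance measure (Euclidean by default); predict accepts either a single predictor vector or a matrix of predictor rows.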
@@ -246,6 +246,7 @@ public class Lang {
        .withFunctionName("zeros", ZerosEvaluator.class)
        .withFunctionName("getValue", GetValueEvaluator.class)
        .withFunctionName("setValue", SetValueEvaluator.class)
+       .withFunctionName("knnRegress", KnnRegressionEvaluator.class)

        // Boolean Stream Evaluators

@@ -67,8 +67,6 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
        throw new IOException("The third parameter for knn should be k.");
      }

-     double[][] data = matrix.getData();
-
      DistanceMeasure distanceMeasure = null;

      if(values.length == 4) {
@@ -77,6 +75,15 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
        distanceMeasure = new EuclideanDistance();
      }

+     return search(matrix, vec, k, distanceMeasure);
+   }
+
+   public static Matrix search(Matrix observations,
+                               double[] vec,
+                               int k,
+                               DistanceMeasure distanceMeasure) {
+
+     double[][] data = observations.getData();
      TreeSet<Neighbor> neighbors = new TreeSet();
      for(int i=0; i<data.length; i++) {
        double distance = distanceMeasure.compute(vec, data[i]);
@@ -87,8 +94,9 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
      }

      double[][] out = new double[neighbors.size()][];
-     List<String> rowLabels = matrix.getRowLabels();
+     List<String> rowLabels = observations.getRowLabels();
      List<String> newRowLabels = new ArrayList();
+     List<Number> indexes = new ArrayList();
      List<Number> distances = new ArrayList();
      int i=-1;

@@ -102,6 +110,7 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW

        out[++i] = data[rowIndex];
        distances.add(neighbor.getDistance());
+       indexes.add(rowIndex);
      }

      Matrix knn = new Matrix(out);
@@ -110,8 +119,9 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
        knn.setRowLabels(newRowLabels);
      }

-     knn.setColumnLabels(matrix.getColumnLabels());
+     knn.setColumnLabels(observations.getColumnLabels());
      knn.setAttribute("distances", distances);
+     knn.setAttribute("indexes", indexes);
      return knn;
    }

@@ -0,0 +1,194 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.solr.client.solrj.io.eval;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.HashMap;

import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;

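/**
 * Stream evaluator for k-nearest-neighbor regression. doWork() validates the parameters and
 * packages the observation matrix, outcome vector, k and distance measure into a
 * KnnRegressionTuple; the predict evaluator then min-max scales the predictors, finds the k
 * nearest observations and returns the mean of their outcomes as the prediction.
 */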
public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
  protected static final long serialVersionUID = 1L;

  public KnnRegressionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
    super(expression, factory);
  }

  @Override
  public Object doWork(Object ... values) throws IOException {

    if(values.length < 3) {
      throw new IOException("knnRegress expects at least three parameters: an observation matrix, an outcomes vector and k.");
    }

    Matrix observations = null;
    List<Number> outcomes = null;
    int k = 5;
    DistanceMeasure distanceMeasure = new EuclideanDistance();

    if(values[0] instanceof Matrix) {
      observations = (Matrix)values[0];
    } else {
      throw new IOException("The first parameter for knnRegress should be the observation matrix.");
    }

    if(values[1] instanceof List) {
      outcomes = (List) values[1];
    } else {
      throw new IOException("The second parameter for knnRegress should be the outcomes array.");
    }

    if(values[2] instanceof Number) {
      k = ((Number) values[2]).intValue();
    } else {
      throw new IOException("The third parameter for knnRegress should be k.");
    }

    if(values.length == 4) {
      if(values[3] instanceof DistanceMeasure) {
        distanceMeasure = (DistanceMeasure) values[3];
      } else {
        throw new IOException("The fourth parameter for knnRegress should be a distance measure.");
      }
    }

    double[] outcomeData = new double[outcomes.size()];
    for(int i=0; i<outcomeData.length; i++) {
      outcomeData[i] = outcomes.get(i).doubleValue();
    }

    Map map = new HashMap();
    map.put("k", k);
    map.put("observations", observations.getRowCount());
    map.put("features", observations.getColumnCount());
    map.put("distance", distanceMeasure.getClass().getSimpleName());

    return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map);
  }


  public static class KnnRegressionTuple extends Tuple {

    private Matrix observations;
    private Matrix scaledObservations;
    private double[] outcomes;
    private int k;
    private DistanceMeasure distanceMeasure;

    public KnnRegressionTuple(Matrix observations,
                              double[] outcomes,
                              int k,
                              DistanceMeasure distanceMeasure,
                              Map<?,?> map) {
      super(map);
      this.observations = observations;
      this.outcomes = outcomes;
      this.k = k;
      this.distanceMeasure = distanceMeasure;
    }

    //MinMax scale both the observations and the predictors.

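    // Worked example of the column-wise scaling below: for a feature column [2, 4, 6] and a
    // predictor value of 8, the combined vector [2, 4, 6, 8] min-max scales to [0, 0.33, 0.67, 1],
    // so the stored column becomes [0, 0.33, 0.67] and the scaled predictor for that feature is 1.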
    public double[] scale(double[] predictors) {
      double[][] data = observations.getData();
      //Scale the columns of the observation matrix along with the predictors.
      Array2DRowRealMatrix matrix = new Array2DRowRealMatrix(data);
      Array2DRowRealMatrix transposed = (Array2DRowRealMatrix) matrix.transpose();
      double[][] featureRows = transposed.getDataRef();

      double[] scaledPredictors = new double[predictors.length];

      for(int i=0; i<featureRows.length; i++) {
        double[] featureRow = featureRows[i];
        double[] combinedFeatureRow = new double[featureRow.length+1];
        System.arraycopy(featureRow, 0, combinedFeatureRow, 0, featureRow.length);
        combinedFeatureRow[featureRow.length] = predictors[i]; // Append the predictor value as the last element so it is scaled with the feature column.
        double[] scaledFeatures = MinMaxScaleEvaluator.scale(combinedFeatureRow, 0, 1);
        scaledPredictors[i] = scaledFeatures[featureRow.length];
        System.arraycopy(scaledFeatures, 0, featureRow, 0, featureRow.length);
      }

      Array2DRowRealMatrix scaledFeatureMatrix = new Array2DRowRealMatrix(featureRows);
      Array2DRowRealMatrix scaledObservationsMatrix = (Array2DRowRealMatrix)scaledFeatureMatrix.transpose();
      this.scaledObservations = new Matrix(scaledObservationsMatrix.getDataRef());
      return scaledPredictors;
    }


    public Matrix scale(Matrix predictors) {
      double[][] observationData = observations.getData();
      //Scale the columns of the observation matrix along with the predictor columns.
      Array2DRowRealMatrix observationMatrix = new Array2DRowRealMatrix(observationData);
      Array2DRowRealMatrix observationTransposed = (Array2DRowRealMatrix) observationMatrix.transpose();
      double[][] observationFeatureRows = observationTransposed.getDataRef();

      double[][] predictorsData = predictors.getData();
      //Transpose the predictors the same way so their columns line up with the observation features.
      Array2DRowRealMatrix predictorMatrix = new Array2DRowRealMatrix(predictorsData);
      Array2DRowRealMatrix predictorTransposed = (Array2DRowRealMatrix) predictorMatrix.transpose();
      double[][] predictorFeatureRows = predictorTransposed.getDataRef();

      for(int i=0; i<observationFeatureRows.length; i++) {
        double[] observationFeatureRow = observationFeatureRows[i];
        double[] predictorFeatureRow = predictorFeatureRows[i];
        double[] combinedFeatureRow = new double[observationFeatureRow.length+predictorFeatureRow.length];
        System.arraycopy(observationFeatureRow, 0, combinedFeatureRow, 0, observationFeatureRow.length);
        System.arraycopy(predictorFeatureRow, 0, combinedFeatureRow, observationFeatureRow.length, predictorFeatureRow.length);

        double[] scaledFeatures = MinMaxScaleEvaluator.scale(combinedFeatureRow, 0, 1);
        System.arraycopy(scaledFeatures, 0, observationFeatureRow, 0, observationFeatureRow.length);
        System.arraycopy(scaledFeatures, observationFeatureRow.length, predictorFeatureRow, 0, predictorFeatureRow.length);
      }

      Array2DRowRealMatrix scaledFeatureMatrix = new Array2DRowRealMatrix(observationFeatureRows);
      Array2DRowRealMatrix scaledObservationsMatrix = (Array2DRowRealMatrix)scaledFeatureMatrix.transpose();
      this.scaledObservations = new Matrix(scaledObservationsMatrix.getDataRef());

      Array2DRowRealMatrix scaledPredictorMatrix = new Array2DRowRealMatrix(predictorFeatureRows);
      Array2DRowRealMatrix scaledTransposedPredictorMatrix = (Array2DRowRealMatrix)scaledPredictorMatrix.transpose();
      return new Matrix(scaledTransposedPredictorMatrix.getDataRef());
    }

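    // Note: predict() expects already-scaled values; the predict stream evaluator calls scale()
    // on the raw predictors before invoking this method.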
    public double predict(double[] values) {

      Matrix knn = KnnEvaluator.search(scaledObservations, values, k, distanceMeasure);
      List<Number> indexes = (List<Number>)knn.getAttribute("indexes");

      double sum = 0;

      //Collect the outcomes for the nearest neighbors.
      for(Number n : indexes) {
        sum += outcomes[n.intValue()];
      }

      //Return the average of the outcomes as the prediction.
      return sum/((double)indexes.size());
    }
  }
}

@@ -76,7 +76,7 @@ public class MinMaxScaleEvaluator extends RecursiveObjectEvaluator implements Ma
      }
    }

-   private double[] scale(double[] values, double min, double max) {
+   public static double[] scale(double[] values, double min, double max) {

      double localMin = Double.MAX_VALUE;
      double localMax = Double.MIN_VALUE;

@@ -43,7 +43,11 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
    Object first = objects[0];
    Object second = objects[1];

-   if (!(first instanceof BivariateFunction) && !(first instanceof VectorFunction) && !(first instanceof RegressionEvaluator.RegressionTuple) && !(first instanceof OLSRegressionEvaluator.MultipleRegressionTuple)) {
+   if (!(first instanceof BivariateFunction) &&
+       !(first instanceof VectorFunction) &&
+       !(first instanceof RegressionEvaluator.RegressionTuple) &&
+       !(first instanceof OLSRegressionEvaluator.MultipleRegressionTuple) &&
+       !(first instanceof KnnRegressionEvaluator.KnnRegressionTuple)) {
      throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a RegressionTuple", toExpression(constructingFactory), first.getClass().getSimpleName()));
    }

@@ -83,6 +87,30 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
        return predictions;
      }

+   } else if (first instanceof KnnRegressionEvaluator.KnnRegressionTuple) {
+     KnnRegressionEvaluator.KnnRegressionTuple regressedTuple = (KnnRegressionEvaluator.KnnRegressionTuple) first;
+     if (second instanceof List) {
+       List<Number> list = (List<Number>) second;
+       double[] predictors = new double[list.size()];
+
+       for (int i = 0; i < list.size(); i++) {
+         predictors[i] = list.get(i).doubleValue();
+       }
+
+       predictors = regressedTuple.scale(predictors);
+
+       return regressedTuple.predict(predictors);
+     } else if (second instanceof Matrix) {
+
+       Matrix m = (Matrix) second;
+       m = regressedTuple.scale(m);
+       double[][] data = m.getData();
+       List<Number> predictions = new ArrayList();
+       for (double[] predictors : data) {
+         predictions.add(regressedTuple.predict(predictors));
+       }
+       return predictions;
+     }
    } else if (first instanceof VectorFunction) {
      VectorFunction vectorFunction = (VectorFunction) first;
      UnivariateFunction univariateFunction = (UnivariateFunction)vectorFunction.getFunction();

@@ -69,7 +69,7 @@ public class TestLang extends LuceneTestCase {
      TemporalEvaluatorDayOfQuarter.FUNCTION_NAME, "abs", "add", "div", "mult", "sub", "log", "pow",
      "mod", "ceil", "floor", "sin", "asin", "sinh", "cos", "acos", "cosh", "tan", "atan", "tanh", "round", "sqrt",
      "cbrt", "coalesce", "uuid", "if", "convert", "valueAt", "memset", "fft", "ifft", "euclidean","manhattan",
-     "earthMovers", "canberra", "chebyshev", "ones", "zeros", "setValue", "getValue"};
+     "earthMovers", "canberra", "chebyshev", "ones", "zeros", "setValue", "getValue", "knnRegress"};

  @Test
  public void testLang() {

@@ -3394,6 +3394,77 @@ public class MathExpressionTest extends SolrCloudTestCase {
    assertEquals(predictions.get(9).doubleValue(), 85.86401719768607, .0001);
  }

  @Test
  public void testKnnRegress() throws Exception {
    String cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 10.19999981), " +
        "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 7.699999809, 4.5)," +
        "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
        "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
        "e=transpose(matrix(a, b, c))," +
        "f=knnRegress(e, d, 1)," +
        "g=predict(f, e))";
    ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
    paramsLoc.set("expr", cexpr);
    paramsLoc.set("qt", "/stream");
    String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
    TupleStream solrStream = new SolrStream(url, paramsLoc);
    StreamContext context = new StreamContext();
    solrStream.setStreamContext(context);
    List<Tuple> tuples = getTuples(solrStream);
    assertTrue(tuples.size() == 1);
    List<Number> predictions = (List<Number>)tuples.get(0).get("g");
    assertEquals(predictions.size(), 10);
    //With k=1 the prediction for each row is the outcome of its exact match in the training set.
    assertEquals(predictions.get(0).doubleValue(), 85.09999847, 0);
    assertEquals(predictions.get(1).doubleValue(), 106.3000031, 0);
    assertEquals(predictions.get(2).doubleValue(), 50.20000076, 0);
    assertEquals(predictions.get(3).doubleValue(), 130.6000061, 0);
    assertEquals(predictions.get(4).doubleValue(), 54.79999924, 0);
    assertEquals(predictions.get(5).doubleValue(), 30.29999924, 0);
    assertEquals(predictions.get(6).doubleValue(), 79.40000153, 0);
    assertEquals(predictions.get(7).doubleValue(), 91, 0);
    assertEquals(predictions.get(8).doubleValue(), 135.3999939, 0);
    assertEquals(predictions.get(9).doubleValue(), 89.30000305, 0);

    cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 10.19999981), " +
        "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 7.699999809, 4.5)," +
        "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
        "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
        "e=transpose(matrix(a, b, c))," +
        "f=knnRegress(e, d, 1)," +
        "g=predict(f, array(8, 5, 4)))";
    paramsLoc = new ModifiableSolrParams();
    paramsLoc.set("expr", cexpr);
    paramsLoc.set("qt", "/stream");
    url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
    solrStream = new SolrStream(url, paramsLoc);
    context = new StreamContext();
    solrStream.setStreamContext(context);
    tuples = getTuples(solrStream);
    assertTrue(tuples.size() == 1);
    Number prediction = (Number)tuples.get(0).get("g");
    assertEquals(prediction.doubleValue(), 85.09999847, 0);

    cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 8.19999981), " +
        "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 5.699999809, 4.5)," +
        "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 4.900000095)," +
        "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
        "e=transpose(matrix(a, b, c))," +
        "f=knnRegress(e, d, 2)," +
        "g=predict(f, array(8, 5, 4)))";
    paramsLoc = new ModifiableSolrParams();
    paramsLoc.set("expr", cexpr);
    paramsLoc.set("qt", "/stream");
    url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
    solrStream = new SolrStream(url, paramsLoc);
    context = new StreamContext();
    solrStream.setStreamContext(context);
    tuples = getTuples(solrStream);
    assertTrue(tuples.size() == 1);
    prediction = (Number)tuples.get(0).get("g");
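    //With k=2 the prediction is the mean of the two nearest outcomes, here (85.09999847 + 89.30000305) / 2.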
    assertEquals(prediction.doubleValue(), 87.20000076, 0);
  }

  @Test
  public void testPlot() throws Exception {
    String cexpr = "let(a=array(3,2,3), plot(type=scatter, x=a, y=array(5,6,3)))";