SOLR-12671: Add robust flag to knnRegress Stream Evaluator

This commit is contained in:
Joel Bernstein 2018-08-17 14:26:05 -04:00
parent 124be4e202
commit 52f9cee97b
4 changed files with 91 additions and 20 deletions

View File

@ -144,6 +144,10 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
}
public int compareTo(Neighbor neighbor) {
if(this.distance.compareTo(neighbor.getDistance()) == 0) {
return row-neighbor.getRow();
}
return this.distance.compareTo(neighbor.getDistance());
}
}

View File

@ -25,15 +25,32 @@ import java.util.HashMap;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
protected static final long serialVersionUID = 1L;
private boolean robust=false;
private boolean scale=false;
public KnnRegressionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
for(StreamExpressionNamedParameter namedParam : namedParams){
if(namedParam.getName().equals("scale")){
this.scale = Boolean.parseBoolean(namedParam.getParameter().toString().trim());
} else if(namedParam.getName().equals("robust")) {
this.robust = Boolean.parseBoolean(namedParam.getParameter().toString().trim());
} else {
throw new IOException("Unexpected named parameter:"+namedParam.getName());
}
}
}
@Override
@ -84,7 +101,7 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
map.put("features", observations.getColumnCount());
map.put("distance", distanceMeasure.getClass().getSimpleName());
return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map);
return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map, scale, robust);
}
@ -95,17 +112,27 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
private double[] outcomes;
private int k;
private DistanceMeasure distanceMeasure;
private boolean scale;
private boolean robust;
public KnnRegressionTuple(Matrix observations,
double[] outcomes,
int k,
DistanceMeasure distanceMeasure,
Map<?,?> map) {
Map<?,?> map,
boolean scale,
boolean robust) {
super(map);
this.observations = observations;
this.outcomes = outcomes;
this.k = k;
this.distanceMeasure = distanceMeasure;
this.scale = scale;
this.robust = robust;
}
public boolean getScale() {
return this.scale;
}
//MinMax Scale both the observations and the predictors
@ -175,19 +202,33 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
public double predict(double[] values) {
Matrix knn = KnnEvaluator.search(scaledObservations, values, k, distanceMeasure);
Matrix obs = scaledObservations != null ? scaledObservations : observations;
Matrix knn = KnnEvaluator.search(obs, values, k, distanceMeasure);
List<Number> indexes = (List<Number>)knn.getAttribute("indexes");
double sum = 0;
if(robust) {
//Get the median of the results.
double[] vals = new double[indexes.size()];
Percentile percentile = new Percentile();
int i=0;
for (Number n : indexes) {
vals[i++]=outcomes[n.intValue()];
}
//Collect the outcomes for the nearest neighbors
for(Number n : indexes) {
sum += outcomes[n.intValue()];
//Return 50 percentile.
return percentile.evaluate(vals, 50);
} else {
//Get the average of the results
double sum = 0;
//Collect the outcomes for the nearest neighbors
for (Number n : indexes) {
sum += outcomes[n.intValue()];
}
//Return the average of the outcomes as the prediction.
return sum / ((double) indexes.size());
}
//Return the average of the outcomes as the prediction.
return sum/((double)indexes.size());
}
}
}

View File

@ -97,13 +97,17 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
predictors[i] = list.get(i).doubleValue();
}
predictors = regressedTuple.scale(predictors);
if(regressedTuple.getScale()) {
predictors = regressedTuple.scale(predictors);
}
return regressedTuple.predict(predictors);
} else if (second instanceof Matrix) {
Matrix m = (Matrix) second;
m = regressedTuple.scale(m);
if(regressedTuple.getScale()) {
m = regressedTuple.scale(m);
}
double[][] data = m.getData();
List<Number> predictions = new ArrayList();
for (double[] predictors : data) {

View File

@ -3450,7 +3450,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
"d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
"e=transpose(matrix(a, b, c))," +
"f=knnRegress(e, d, 1)," +
"f=knnRegress(e, d, 1, scale=true)," +
"g=predict(f, e))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
@ -3480,7 +3480,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
"d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
"e=transpose(matrix(a, b, c))," +
"f=knnRegress(e, d, 1)," +
"f=knnRegress(e, d, 1, scale=true)," +
"g=predict(f, array(8, 5, 4)))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
@ -3494,12 +3494,14 @@ public class MathExpressionTest extends SolrCloudTestCase {
Number prediction = (Number)tuples.get(0).get("g");
assertEquals(prediction.doubleValue(), 85.09999847, 0);
cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 8.19999981), " +
"b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 5.699999809, 4.5)," +
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 4.900000095)," +
//Test robust. Take the median rather then average
cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 8.10000038, 8.19999981), " +
"b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 5.599999905, 5.699999809, 4.5)," +
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 4.79999924, 4.900000095)," +
"d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
"e=transpose(matrix(a, b, c))," +
"f=knnRegress(e, d, 2)," +
"f=knnRegress(e, d, 3, scale=true, robust=true)," +
"g=predict(f, array(8, 5, 4)))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
@ -3511,7 +3513,27 @@ public class MathExpressionTest extends SolrCloudTestCase {
tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
prediction = (Number)tuples.get(0).get("g");
assertEquals(prediction.doubleValue(), 87.20000076, 0);
assertEquals(prediction.doubleValue(), 89.30000305, 0);
//Test univariate regression with scaling off
cexpr = "let(echo=true, a=sequence(10, 0, 1), " +
"b=transpose(matrix(a))," +
"c=knnRegress(b, a, 3)," +
"d=predict(c, array(3)))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
solrStream = new SolrStream(url, paramsLoc);
context = new StreamContext();
solrStream.setStreamContext(context);
tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
prediction = (Number)tuples.get(0).get("d");
assertEquals(prediction.doubleValue(), 3, 0);
}
@Test