mirror of https://github.com/apache/lucene.git
SOLR-12671: Add robust flag to knnRegress Stream Evaluator
This commit is contained in:
parent
124be4e202
commit
52f9cee97b
|
@ -144,6 +144,10 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
|
||||||
}
|
}
|
||||||
|
|
||||||
public int compareTo(Neighbor neighbor) {
|
public int compareTo(Neighbor neighbor) {
|
||||||
|
if(this.distance.compareTo(neighbor.getDistance()) == 0) {
|
||||||
|
return row-neighbor.getRow();
|
||||||
|
}
|
||||||
|
|
||||||
return this.distance.compareTo(neighbor.getDistance());
|
return this.distance.compareTo(neighbor.getDistance());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,15 +25,32 @@ import java.util.HashMap;
|
||||||
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||||
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
||||||
import org.apache.commons.math3.ml.distance.EuclideanDistance;
|
import org.apache.commons.math3.ml.distance.EuclideanDistance;
|
||||||
|
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
|
||||||
import org.apache.solr.client.solrj.io.Tuple;
|
import org.apache.solr.client.solrj.io.Tuple;
|
||||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
|
||||||
|
|
||||||
public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
|
public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
|
||||||
protected static final long serialVersionUID = 1L;
|
protected static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
private boolean robust=false;
|
||||||
|
private boolean scale=false;
|
||||||
|
|
||||||
public KnnRegressionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
public KnnRegressionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||||
super(expression, factory);
|
super(expression, factory);
|
||||||
|
|
||||||
|
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
|
||||||
|
|
||||||
|
for(StreamExpressionNamedParameter namedParam : namedParams){
|
||||||
|
if(namedParam.getName().equals("scale")){
|
||||||
|
this.scale = Boolean.parseBoolean(namedParam.getParameter().toString().trim());
|
||||||
|
} else if(namedParam.getName().equals("robust")) {
|
||||||
|
this.robust = Boolean.parseBoolean(namedParam.getParameter().toString().trim());
|
||||||
|
} else {
|
||||||
|
throw new IOException("Unexpected named parameter:"+namedParam.getName());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -84,7 +101,7 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
|
||||||
map.put("features", observations.getColumnCount());
|
map.put("features", observations.getColumnCount());
|
||||||
map.put("distance", distanceMeasure.getClass().getSimpleName());
|
map.put("distance", distanceMeasure.getClass().getSimpleName());
|
||||||
|
|
||||||
return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map);
|
return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map, scale, robust);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -95,17 +112,27 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
|
||||||
private double[] outcomes;
|
private double[] outcomes;
|
||||||
private int k;
|
private int k;
|
||||||
private DistanceMeasure distanceMeasure;
|
private DistanceMeasure distanceMeasure;
|
||||||
|
private boolean scale;
|
||||||
|
private boolean robust;
|
||||||
|
|
||||||
public KnnRegressionTuple(Matrix observations,
|
public KnnRegressionTuple(Matrix observations,
|
||||||
double[] outcomes,
|
double[] outcomes,
|
||||||
int k,
|
int k,
|
||||||
DistanceMeasure distanceMeasure,
|
DistanceMeasure distanceMeasure,
|
||||||
Map<?,?> map) {
|
Map<?,?> map,
|
||||||
|
boolean scale,
|
||||||
|
boolean robust) {
|
||||||
super(map);
|
super(map);
|
||||||
this.observations = observations;
|
this.observations = observations;
|
||||||
this.outcomes = outcomes;
|
this.outcomes = outcomes;
|
||||||
this.k = k;
|
this.k = k;
|
||||||
this.distanceMeasure = distanceMeasure;
|
this.distanceMeasure = distanceMeasure;
|
||||||
|
this.scale = scale;
|
||||||
|
this.robust = robust;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean getScale() {
|
||||||
|
return this.scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
//MinMax Scale both the observations and the predictors
|
//MinMax Scale both the observations and the predictors
|
||||||
|
@ -175,19 +202,33 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
|
||||||
|
|
||||||
public double predict(double[] values) {
|
public double predict(double[] values) {
|
||||||
|
|
||||||
Matrix knn = KnnEvaluator.search(scaledObservations, values, k, distanceMeasure);
|
Matrix obs = scaledObservations != null ? scaledObservations : observations;
|
||||||
|
Matrix knn = KnnEvaluator.search(obs, values, k, distanceMeasure);
|
||||||
List<Number> indexes = (List<Number>)knn.getAttribute("indexes");
|
List<Number> indexes = (List<Number>)knn.getAttribute("indexes");
|
||||||
|
|
||||||
|
if(robust) {
|
||||||
|
//Get the median of the results.
|
||||||
|
double[] vals = new double[indexes.size()];
|
||||||
|
Percentile percentile = new Percentile();
|
||||||
|
int i=0;
|
||||||
|
for (Number n : indexes) {
|
||||||
|
vals[i++]=outcomes[n.intValue()];
|
||||||
|
}
|
||||||
|
|
||||||
|
//Return 50 percentile.
|
||||||
|
return percentile.evaluate(vals, 50);
|
||||||
|
} else {
|
||||||
|
//Get the average of the results
|
||||||
double sum = 0;
|
double sum = 0;
|
||||||
|
|
||||||
//Collect the outcomes for the nearest neighbors
|
//Collect the outcomes for the nearest neighbors
|
||||||
for(Number n : indexes) {
|
for (Number n : indexes) {
|
||||||
sum += outcomes[n.intValue()];
|
sum += outcomes[n.intValue()];
|
||||||
}
|
}
|
||||||
|
|
||||||
//Return the average of the outcomes as the prediction.
|
//Return the average of the outcomes as the prediction.
|
||||||
|
return sum / ((double) indexes.size());
|
||||||
return sum/((double)indexes.size());
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -97,13 +97,17 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
|
||||||
predictors[i] = list.get(i).doubleValue();
|
predictors[i] = list.get(i).doubleValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(regressedTuple.getScale()) {
|
||||||
predictors = regressedTuple.scale(predictors);
|
predictors = regressedTuple.scale(predictors);
|
||||||
|
}
|
||||||
|
|
||||||
return regressedTuple.predict(predictors);
|
return regressedTuple.predict(predictors);
|
||||||
} else if (second instanceof Matrix) {
|
} else if (second instanceof Matrix) {
|
||||||
|
|
||||||
Matrix m = (Matrix) second;
|
Matrix m = (Matrix) second;
|
||||||
|
if(regressedTuple.getScale()) {
|
||||||
m = regressedTuple.scale(m);
|
m = regressedTuple.scale(m);
|
||||||
|
}
|
||||||
double[][] data = m.getData();
|
double[][] data = m.getData();
|
||||||
List<Number> predictions = new ArrayList();
|
List<Number> predictions = new ArrayList();
|
||||||
for (double[] predictors : data) {
|
for (double[] predictors : data) {
|
||||||
|
|
|
@ -3450,7 +3450,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
||||||
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
|
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
|
||||||
"d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
|
"d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
|
||||||
"e=transpose(matrix(a, b, c))," +
|
"e=transpose(matrix(a, b, c))," +
|
||||||
"f=knnRegress(e, d, 1)," +
|
"f=knnRegress(e, d, 1, scale=true)," +
|
||||||
"g=predict(f, e))";
|
"g=predict(f, e))";
|
||||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||||
paramsLoc.set("expr", cexpr);
|
paramsLoc.set("expr", cexpr);
|
||||||
|
@ -3480,7 +3480,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
||||||
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
|
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
|
||||||
"d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
|
"d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
|
||||||
"e=transpose(matrix(a, b, c))," +
|
"e=transpose(matrix(a, b, c))," +
|
||||||
"f=knnRegress(e, d, 1)," +
|
"f=knnRegress(e, d, 1, scale=true)," +
|
||||||
"g=predict(f, array(8, 5, 4)))";
|
"g=predict(f, array(8, 5, 4)))";
|
||||||
paramsLoc = new ModifiableSolrParams();
|
paramsLoc = new ModifiableSolrParams();
|
||||||
paramsLoc.set("expr", cexpr);
|
paramsLoc.set("expr", cexpr);
|
||||||
|
@ -3494,12 +3494,14 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
||||||
Number prediction = (Number)tuples.get(0).get("g");
|
Number prediction = (Number)tuples.get(0).get("g");
|
||||||
assertEquals(prediction.doubleValue(), 85.09999847, 0);
|
assertEquals(prediction.doubleValue(), 85.09999847, 0);
|
||||||
|
|
||||||
cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 8.19999981), " +
|
//Test robust. Take the median rather then average
|
||||||
"b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 5.699999809, 4.5)," +
|
|
||||||
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 4.900000095)," +
|
cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 8.10000038, 8.19999981), " +
|
||||||
|
"b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 5.599999905, 5.699999809, 4.5)," +
|
||||||
|
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 4.79999924, 4.900000095)," +
|
||||||
"d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
|
"d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
|
||||||
"e=transpose(matrix(a, b, c))," +
|
"e=transpose(matrix(a, b, c))," +
|
||||||
"f=knnRegress(e, d, 2)," +
|
"f=knnRegress(e, d, 3, scale=true, robust=true)," +
|
||||||
"g=predict(f, array(8, 5, 4)))";
|
"g=predict(f, array(8, 5, 4)))";
|
||||||
paramsLoc = new ModifiableSolrParams();
|
paramsLoc = new ModifiableSolrParams();
|
||||||
paramsLoc.set("expr", cexpr);
|
paramsLoc.set("expr", cexpr);
|
||||||
|
@ -3511,7 +3513,27 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
||||||
tuples = getTuples(solrStream);
|
tuples = getTuples(solrStream);
|
||||||
assertTrue(tuples.size() == 1);
|
assertTrue(tuples.size() == 1);
|
||||||
prediction = (Number)tuples.get(0).get("g");
|
prediction = (Number)tuples.get(0).get("g");
|
||||||
assertEquals(prediction.doubleValue(), 87.20000076, 0);
|
assertEquals(prediction.doubleValue(), 89.30000305, 0);
|
||||||
|
|
||||||
|
|
||||||
|
//Test univariate regression with scaling off
|
||||||
|
|
||||||
|
cexpr = "let(echo=true, a=sequence(10, 0, 1), " +
|
||||||
|
"b=transpose(matrix(a))," +
|
||||||
|
"c=knnRegress(b, a, 3)," +
|
||||||
|
"d=predict(c, array(3)))";
|
||||||
|
paramsLoc = new ModifiableSolrParams();
|
||||||
|
paramsLoc.set("expr", cexpr);
|
||||||
|
paramsLoc.set("qt", "/stream");
|
||||||
|
url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||||
|
solrStream = new SolrStream(url, paramsLoc);
|
||||||
|
context = new StreamContext();
|
||||||
|
solrStream.setStreamContext(context);
|
||||||
|
tuples = getTuples(solrStream);
|
||||||
|
assertTrue(tuples.size() == 1);
|
||||||
|
prediction = (Number)tuples.get(0).get("d");
|
||||||
|
assertEquals(prediction.doubleValue(), 3, 0);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
Loading…
Reference in New Issue