mirror of https://github.com/apache/lucene.git
SOLR-12671: Add robust flag to knnRegress Stream Evaluator
This commit is contained in:
parent
124be4e202
commit
52f9cee97b
|
@ -144,6 +144,10 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
|
|||
}
|
||||
|
||||
public int compareTo(Neighbor neighbor) {
|
||||
if(this.distance.compareTo(neighbor.getDistance()) == 0) {
|
||||
return row-neighbor.getRow();
|
||||
}
|
||||
|
||||
return this.distance.compareTo(neighbor.getDistance());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,15 +25,32 @@ import java.util.HashMap;
|
|||
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
||||
import org.apache.commons.math3.ml.distance.EuclideanDistance;
|
||||
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
|
||||
|
||||
public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
private boolean robust=false;
|
||||
private boolean scale=false;
|
||||
|
||||
public KnnRegressionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
|
||||
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
|
||||
|
||||
for(StreamExpressionNamedParameter namedParam : namedParams){
|
||||
if(namedParam.getName().equals("scale")){
|
||||
this.scale = Boolean.parseBoolean(namedParam.getParameter().toString().trim());
|
||||
} else if(namedParam.getName().equals("robust")) {
|
||||
this.robust = Boolean.parseBoolean(namedParam.getParameter().toString().trim());
|
||||
} else {
|
||||
throw new IOException("Unexpected named parameter:"+namedParam.getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -84,7 +101,7 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
|
|||
map.put("features", observations.getColumnCount());
|
||||
map.put("distance", distanceMeasure.getClass().getSimpleName());
|
||||
|
||||
return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map);
|
||||
return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map, scale, robust);
|
||||
}
|
||||
|
||||
|
||||
|
@ -95,17 +112,27 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
|
|||
private double[] outcomes;
|
||||
private int k;
|
||||
private DistanceMeasure distanceMeasure;
|
||||
private boolean scale;
|
||||
private boolean robust;
|
||||
|
||||
/**
 * Creates a regression tuple that stores the training data together with
 * the settings consulted when predicting.
 *
 * @param observations the observation matrix
 * @param outcomes the outcome values, looked up by neighbor index in {@code predict}
 * @param k the number of neighbors passed to the nearest-neighbor search
 * @param distanceMeasure the distance measure passed to the nearest-neighbor search
 * @param map the field map handed to the superclass constructor
 * @param scale the flag later reported by {@code getScale()}
 * @param robust when {@code true}, {@code predict} returns the median of the
 *        neighbors' outcomes instead of their average
 */
public KnnRegressionTuple(Matrix observations,
                          double[] outcomes,
                          int k,
                          DistanceMeasure distanceMeasure,
                          Map<?,?> map,
                          boolean scale,
                          boolean robust) {
  super(map);
  this.observations = observations;
  this.outcomes = outcomes;
  this.k = k;
  this.distanceMeasure = distanceMeasure;
  this.scale = scale;
  this.robust = robust;
}
|
||||
|
||||
public boolean getScale() {
|
||||
return this.scale;
|
||||
}
|
||||
|
||||
//MinMax Scale both the observations and the predictors
|
||||
|
@ -175,19 +202,33 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
|
|||
|
||||
public double predict(double[] values) {
|
||||
|
||||
Matrix knn = KnnEvaluator.search(scaledObservations, values, k, distanceMeasure);
|
||||
Matrix obs = scaledObservations != null ? scaledObservations : observations;
|
||||
Matrix knn = KnnEvaluator.search(obs, values, k, distanceMeasure);
|
||||
List<Number> indexes = (List<Number>)knn.getAttribute("indexes");
|
||||
|
||||
double sum = 0;
|
||||
if(robust) {
|
||||
//Get the median of the results.
|
||||
double[] vals = new double[indexes.size()];
|
||||
Percentile percentile = new Percentile();
|
||||
int i=0;
|
||||
for (Number n : indexes) {
|
||||
vals[i++]=outcomes[n.intValue()];
|
||||
}
|
||||
|
||||
//Collect the outcomes for the nearest neighbors
|
||||
for(Number n : indexes) {
|
||||
sum += outcomes[n.intValue()];
|
||||
//Return 50 percentile.
|
||||
return percentile.evaluate(vals, 50);
|
||||
} else {
|
||||
//Get the average of the results
|
||||
double sum = 0;
|
||||
|
||||
//Collect the outcomes for the nearest neighbors
|
||||
for (Number n : indexes) {
|
||||
sum += outcomes[n.intValue()];
|
||||
}
|
||||
|
||||
//Return the average of the outcomes as the prediction.
|
||||
return sum / ((double) indexes.size());
|
||||
}
|
||||
|
||||
//Return the average of the outcomes as the prediction.
|
||||
|
||||
return sum/((double)indexes.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -97,13 +97,17 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
|
|||
predictors[i] = list.get(i).doubleValue();
|
||||
}
|
||||
|
||||
predictors = regressedTuple.scale(predictors);
|
||||
if(regressedTuple.getScale()) {
|
||||
predictors = regressedTuple.scale(predictors);
|
||||
}
|
||||
|
||||
return regressedTuple.predict(predictors);
|
||||
} else if (second instanceof Matrix) {
|
||||
|
||||
Matrix m = (Matrix) second;
|
||||
m = regressedTuple.scale(m);
|
||||
if(regressedTuple.getScale()) {
|
||||
m = regressedTuple.scale(m);
|
||||
}
|
||||
double[][] data = m.getData();
|
||||
List<Number> predictions = new ArrayList();
|
||||
for (double[] predictors : data) {
|
||||
|
|
|
@ -3450,7 +3450,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
|
||||
"d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
|
||||
"e=transpose(matrix(a, b, c))," +
|
||||
"f=knnRegress(e, d, 1)," +
|
||||
"f=knnRegress(e, d, 1, scale=true)," +
|
||||
"g=predict(f, e))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
|
@ -3480,7 +3480,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
|
||||
"d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
|
||||
"e=transpose(matrix(a, b, c))," +
|
||||
"f=knnRegress(e, d, 1)," +
|
||||
"f=knnRegress(e, d, 1, scale=true)," +
|
||||
"g=predict(f, array(8, 5, 4)))";
|
||||
paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
|
@ -3494,12 +3494,14 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
Number prediction = (Number)tuples.get(0).get("g");
|
||||
assertEquals(prediction.doubleValue(), 85.09999847, 0);
|
||||
|
||||
cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 8.19999981), " +
|
||||
"b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 5.699999809, 4.5)," +
|
||||
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 4.900000095)," +
|
||||
//Test robust. Take the median rather than the average
|
||||
|
||||
cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 8.10000038, 8.19999981), " +
|
||||
"b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 5.599999905, 5.699999809, 4.5)," +
|
||||
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 4.79999924, 4.900000095)," +
|
||||
"d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
|
||||
"e=transpose(matrix(a, b, c))," +
|
||||
"f=knnRegress(e, d, 2)," +
|
||||
"f=knnRegress(e, d, 3, scale=true, robust=true)," +
|
||||
"g=predict(f, array(8, 5, 4)))";
|
||||
paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
|
@ -3511,7 +3513,27 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
prediction = (Number)tuples.get(0).get("g");
|
||||
assertEquals(prediction.doubleValue(), 87.20000076, 0);
|
||||
assertEquals(prediction.doubleValue(), 89.30000305, 0);
|
||||
|
||||
|
||||
//Test univariate regression with scaling off
|
||||
|
||||
cexpr = "let(echo=true, a=sequence(10, 0, 1), " +
|
||||
"b=transpose(matrix(a))," +
|
||||
"c=knnRegress(b, a, 3)," +
|
||||
"d=predict(c, array(3)))";
|
||||
paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
solrStream = new SolrStream(url, paramsLoc);
|
||||
context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
prediction = (Number)tuples.get(0).get("d");
|
||||
assertEquals(prediction.doubleValue(), 3, 0);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
Loading…
Reference in New Issue