SOLR-13134: Allow the knnRegress Stream Evaluator to more easily perform bivariate regression

This commit is contained in:
Joel Bernstein 2019-01-12 15:25:45 -05:00
parent dcc9ffe186
commit 292e26bc2d
3 changed files with 75 additions and 25 deletions

View File

@ -64,11 +64,20 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
List<Number> outcomes = null;
int k = 5;
DistanceMeasure distanceMeasure = new EuclideanDistance();
boolean bivariate = false;
if(values[0] instanceof Matrix) {
observations = (Matrix)values[0];
} else if(values[0] instanceof List) {
bivariate = true;
List<Number> vec = (List<Number>)values[0];
double[][] data = new double[vec.size()][1];
for(int i=0; i<vec.size(); i++) {
data[i][0] = vec.get(i).doubleValue();
}
observations = new Matrix(data);
} else {
throw new IOException("The first parameter for knnRegress should be the observation matrix.");
throw new IOException("The first parameter for knnRegress should be the observation vector or matrix.");
}
if(values[1] instanceof List) {
@ -104,7 +113,7 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
map.put("robust", robust);
map.put("scale", scale);
return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map, scale, robust);
return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map, scale, robust, bivariate);
}
@ -117,6 +126,7 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
private DistanceMeasure distanceMeasure;
private boolean scale;
private boolean robust;
private boolean bivariate;
public KnnRegressionTuple(Matrix observations,
double[] outcomes,
@ -124,7 +134,8 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
DistanceMeasure distanceMeasure,
Map<?,?> map,
boolean scale,
boolean robust) {
boolean robust,
boolean bivariate) {
super(map);
this.observations = observations;
this.outcomes = outcomes;
@ -132,11 +143,15 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
this.distanceMeasure = distanceMeasure;
this.scale = scale;
this.robust = robust;
this.bivariate = bivariate;
}
public boolean getScale() {
return this.scale;
}
public boolean getBivariate() {
return this.bivariate;
}
//MinMax Scale both the observations and the predictors

View File

@ -89,31 +89,51 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
} else if (first instanceof KnnRegressionEvaluator.KnnRegressionTuple) {
KnnRegressionEvaluator.KnnRegressionTuple regressedTuple = (KnnRegressionEvaluator.KnnRegressionTuple) first;
if (second instanceof List) {
List<Number> list = (List<Number>) second;
double[] predictors = new double[list.size()];
for (int i = 0; i < list.size(); i++) {
predictors[i] = list.get(i).doubleValue();
if(regressedTuple.getBivariate()) {
//Handle bi-variate regression
if(second instanceof Number) {
double[] predictors = new double[1];
predictors[0] = ((Number)second).doubleValue();
return regressedTuple.predict(predictors);
} else if(second instanceof List) {
List<Number> vec = (List<Number>)second;
List<Number> predictions = new ArrayList();
for(Number num : vec) {
double[] predictors = new double[1];
predictors[0] = num.doubleValue();
predictions.add(regressedTuple.predict(predictors));
}
return predictions;
}
} else {
//Handle multi-variate regression
if (second instanceof List) {
List<Number> list = (List<Number>) second;
double[] predictors = new double[list.size()];
if(regressedTuple.getScale()) {
predictors = regressedTuple.scale(predictors);
}
for (int i = 0; i < list.size(); i++) {
predictors[i] = list.get(i).doubleValue();
}
return regressedTuple.predict(predictors);
} else if (second instanceof Matrix) {
if (regressedTuple.getScale()) {
predictors = regressedTuple.scale(predictors);
}
Matrix m = (Matrix) second;
if(regressedTuple.getScale()) {
m = regressedTuple.scale(m);
return regressedTuple.predict(predictors);
} else if (second instanceof Matrix) {
Matrix m = (Matrix) second;
if (regressedTuple.getScale()) {
m = regressedTuple.scale(m);
}
double[][] data = m.getData();
List<Number> predictions = new ArrayList();
for (double[] predictors : data) {
predictions.add(regressedTuple.predict(predictors));
}
return predictions;
}
double[][] data = m.getData();
List<Number> predictions = new ArrayList();
for (double[] predictors : data) {
predictions.add(regressedTuple.predict(predictors));
}
return predictions;
}
} else if (first instanceof VectorFunction) {
VectorFunction vectorFunction = (VectorFunction) first;

View File

@ -4136,9 +4136,8 @@ public class MathExpressionTest extends SolrCloudTestCase {
//Test univariate regression with scaling off
cexpr = "let(echo=true, a=sequence(10, 0, 1), " +
"b=transpose(matrix(a))," +
"c=knnRegress(b, a, 3)," +
"d=predict(c, array(3)))";
"c=knnRegress(a, a, 3)," +
"d=predict(c, 3))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
@ -4151,6 +4150,22 @@ public class MathExpressionTest extends SolrCloudTestCase {
prediction = (Number)tuples.get(0).get("d");
assertEquals(prediction.doubleValue(), 3, 0);
cexpr = "let(echo=true, a=sequence(10, 0, 1), " +
"c=knnRegress(a, a, 3)," +
"d=predict(c, array(3,4)))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
solrStream = new SolrStream(url, paramsLoc);
context = new StreamContext();
solrStream.setStreamContext(context);
tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
predictions = (List<Number>)tuples.get(0).get("d");
assertEquals(predictions.size(), 2);
assertEquals(predictions.get(0).doubleValue(), 3, 0);
assertEquals(predictions.get(1).doubleValue(), 4, 0);
}
@Test