mirror of https://github.com/apache/lucene.git
SOLR-13134: Allow the knnRegress Stream Evaluator to more easily perform bivariate regression
This commit is contained in:
parent
dcc9ffe186
commit
292e26bc2d
|
@ -64,11 +64,20 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
|
|||
List<Number> outcomes = null;
|
||||
int k = 5;
|
||||
DistanceMeasure distanceMeasure = new EuclideanDistance();
|
||||
boolean bivariate = false;
|
||||
|
||||
if(values[0] instanceof Matrix) {
|
||||
observations = (Matrix)values[0];
|
||||
} else if(values[0] instanceof List) {
|
||||
bivariate = true;
|
||||
List<Number> vec = (List<Number>)values[0];
|
||||
double[][] data = new double[vec.size()][1];
|
||||
for(int i=0; i<vec.size(); i++) {
|
||||
data[i][0] = vec.get(i).doubleValue();
|
||||
}
|
||||
observations = new Matrix(data);
|
||||
} else {
|
||||
throw new IOException("The first parameter for knnRegress should be the observation matrix.");
|
||||
throw new IOException("The first parameter for knnRegress should be the observation vector or matrix.");
|
||||
}
|
||||
|
||||
if(values[1] instanceof List) {
|
||||
|
@ -104,7 +113,7 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
|
|||
map.put("robust", robust);
|
||||
map.put("scale", scale);
|
||||
|
||||
return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map, scale, robust);
|
||||
return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map, scale, robust, bivariate);
|
||||
}
|
||||
|
||||
|
||||
|
@ -117,6 +126,7 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
|
|||
private DistanceMeasure distanceMeasure;
|
||||
private boolean scale;
|
||||
private boolean robust;
|
||||
private boolean bivariate;
|
||||
|
||||
public KnnRegressionTuple(Matrix observations,
|
||||
double[] outcomes,
|
||||
|
@ -124,7 +134,8 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
|
|||
DistanceMeasure distanceMeasure,
|
||||
Map<?,?> map,
|
||||
boolean scale,
|
||||
boolean robust) {
|
||||
boolean robust,
|
||||
boolean bivariate) {
|
||||
super(map);
|
||||
this.observations = observations;
|
||||
this.outcomes = outcomes;
|
||||
|
@ -132,11 +143,15 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
|
|||
this.distanceMeasure = distanceMeasure;
|
||||
this.scale = scale;
|
||||
this.robust = robust;
|
||||
this.bivariate = bivariate;
|
||||
}
|
||||
|
||||
public boolean getScale() {
|
||||
return this.scale;
|
||||
}
|
||||
public boolean getBivariate() {
|
||||
return this.bivariate;
|
||||
}
|
||||
|
||||
//MinMax Scale both the observations and the predictors
|
||||
|
||||
|
|
|
@ -89,31 +89,51 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
|
|||
|
||||
} else if (first instanceof KnnRegressionEvaluator.KnnRegressionTuple) {
|
||||
KnnRegressionEvaluator.KnnRegressionTuple regressedTuple = (KnnRegressionEvaluator.KnnRegressionTuple) first;
|
||||
if (second instanceof List) {
|
||||
List<Number> list = (List<Number>) second;
|
||||
double[] predictors = new double[list.size()];
|
||||
|
||||
for (int i = 0; i < list.size(); i++) {
|
||||
predictors[i] = list.get(i).doubleValue();
|
||||
if(regressedTuple.getBivariate()) {
|
||||
//Handle bi-variate regression
|
||||
if(second instanceof Number) {
|
||||
double[] predictors = new double[1];
|
||||
predictors[0] = ((Number)second).doubleValue();
|
||||
return regressedTuple.predict(predictors);
|
||||
} else if(second instanceof List) {
|
||||
List<Number> vec = (List<Number>)second;
|
||||
List<Number> predictions = new ArrayList();
|
||||
for(Number num : vec) {
|
||||
double[] predictors = new double[1];
|
||||
predictors[0] = num.doubleValue();
|
||||
predictions.add(regressedTuple.predict(predictors));
|
||||
}
|
||||
return predictions;
|
||||
}
|
||||
} else {
|
||||
//Handle multi-variate regression
|
||||
if (second instanceof List) {
|
||||
List<Number> list = (List<Number>) second;
|
||||
double[] predictors = new double[list.size()];
|
||||
|
||||
if(regressedTuple.getScale()) {
|
||||
predictors = regressedTuple.scale(predictors);
|
||||
}
|
||||
for (int i = 0; i < list.size(); i++) {
|
||||
predictors[i] = list.get(i).doubleValue();
|
||||
}
|
||||
|
||||
return regressedTuple.predict(predictors);
|
||||
} else if (second instanceof Matrix) {
|
||||
if (regressedTuple.getScale()) {
|
||||
predictors = regressedTuple.scale(predictors);
|
||||
}
|
||||
|
||||
Matrix m = (Matrix) second;
|
||||
if(regressedTuple.getScale()) {
|
||||
m = regressedTuple.scale(m);
|
||||
return regressedTuple.predict(predictors);
|
||||
} else if (second instanceof Matrix) {
|
||||
|
||||
Matrix m = (Matrix) second;
|
||||
if (regressedTuple.getScale()) {
|
||||
m = regressedTuple.scale(m);
|
||||
}
|
||||
double[][] data = m.getData();
|
||||
List<Number> predictions = new ArrayList();
|
||||
for (double[] predictors : data) {
|
||||
predictions.add(regressedTuple.predict(predictors));
|
||||
}
|
||||
return predictions;
|
||||
}
|
||||
double[][] data = m.getData();
|
||||
List<Number> predictions = new ArrayList();
|
||||
for (double[] predictors : data) {
|
||||
predictions.add(regressedTuple.predict(predictors));
|
||||
}
|
||||
return predictions;
|
||||
}
|
||||
} else if (first instanceof VectorFunction) {
|
||||
VectorFunction vectorFunction = (VectorFunction) first;
|
||||
|
|
|
@ -4136,9 +4136,8 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
//Test univariate regression with scaling off
|
||||
|
||||
cexpr = "let(echo=true, a=sequence(10, 0, 1), " +
|
||||
"b=transpose(matrix(a))," +
|
||||
"c=knnRegress(b, a, 3)," +
|
||||
"d=predict(c, array(3)))";
|
||||
"c=knnRegress(a, a, 3)," +
|
||||
"d=predict(c, 3))";
|
||||
paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
|
@ -4151,6 +4150,22 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
prediction = (Number)tuples.get(0).get("d");
|
||||
assertEquals(prediction.doubleValue(), 3, 0);
|
||||
|
||||
cexpr = "let(echo=true, a=sequence(10, 0, 1), " +
|
||||
"c=knnRegress(a, a, 3)," +
|
||||
"d=predict(c, array(3,4)))";
|
||||
paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
solrStream = new SolrStream(url, paramsLoc);
|
||||
context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
predictions = (List<Number>)tuples.get(0).get("d");
|
||||
assertEquals(predictions.size(), 2);
|
||||
assertEquals(predictions.get(0).doubleValue(), 3, 0);
|
||||
assertEquals(predictions.get(1).doubleValue(), 4, 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
Loading…
Reference in New Issue