From 52f9cee97b4f293af26de0e7b4ec534cb6b11b10 Mon Sep 17 00:00:00 2001 From: Joel Bernstein Date: Fri, 17 Aug 2018 14:26:05 -0400 Subject: [PATCH] SOLR-12671: Add robust flag to knnRegress Stream Evaluator --- .../client/solrj/io/eval/KnnEvaluator.java | 4 ++ .../solrj/io/eval/KnnRegressionEvaluator.java | 63 +++++++++++++++---- .../solrj/io/eval/PredictEvaluator.java | 8 ++- .../solrj/io/stream/MathExpressionTest.java | 36 ++++++++--- 4 files changed, 91 insertions(+), 20 deletions(-) diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java index 81607cfdea0..17fb0110af2 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java @@ -144,6 +144,10 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW } public int compareTo(Neighbor neighbor) { + if(this.distance.compareTo(neighbor.getDistance()) == 0) { + return row-neighbor.getRow(); + } + return this.distance.compareTo(neighbor.getDistance()); } } diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java index 957936eefa8..e6f6d8022c5 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java @@ -25,15 +25,32 @@ import java.util.HashMap; import org.apache.commons.math3.linear.Array2DRowRealMatrix; import org.apache.commons.math3.ml.distance.DistanceMeasure; import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.apache.commons.math3.stat.descriptive.rank.Percentile; import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import 
org.apache.solr.client.solrj.io.stream.expr.StreamFactory; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter; public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker { protected static final long serialVersionUID = 1L; + private boolean robust=false; + private boolean scale=false; + public KnnRegressionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ super(expression, factory); + + List namedParams = factory.getNamedOperands(expression); + + for(StreamExpressionNamedParameter namedParam : namedParams){ + if(namedParam.getName().equals("scale")){ + this.scale = Boolean.parseBoolean(namedParam.getParameter().toString().trim()); + } else if(namedParam.getName().equals("robust")) { + this.robust = Boolean.parseBoolean(namedParam.getParameter().toString().trim()); + } else { + throw new IOException("Unexpected named parameter:"+namedParam.getName()); + } + } } @Override @@ -84,7 +101,7 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements map.put("features", observations.getColumnCount()); map.put("distance", distanceMeasure.getClass().getSimpleName()); - return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map); + return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map, scale, robust); } @@ -95,17 +112,27 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements private double[] outcomes; private int k; private DistanceMeasure distanceMeasure; + private boolean scale; + private boolean robust; public KnnRegressionTuple(Matrix observations, double[] outcomes, int k, DistanceMeasure distanceMeasure, - Map map) { + Map map, + boolean scale, + boolean robust) { super(map); this.observations = observations; this.outcomes = outcomes; this.k = k; this.distanceMeasure = distanceMeasure; + this.scale = scale; + this.robust = robust; + } + + public boolean getScale() { + 
return this.scale; } //MinMax Scale both the observations and the predictors @@ -175,19 +202,33 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements public double predict(double[] values) { - Matrix knn = KnnEvaluator.search(scaledObservations, values, k, distanceMeasure); + Matrix obs = scaledObservations != null ? scaledObservations : observations; + Matrix knn = KnnEvaluator.search(obs, values, k, distanceMeasure); List indexes = (List)knn.getAttribute("indexes"); - double sum = 0; + if(robust) { + //Get the median of the results. + double[] vals = new double[indexes.size()]; + Percentile percentile = new Percentile(); + int i=0; + for (Number n : indexes) { + vals[i++]=outcomes[n.intValue()]; + } - //Collect the outcomes for the nearest neighbors - for(Number n : indexes) { - sum += outcomes[n.intValue()]; + //Return 50 percentile. + return percentile.evaluate(vals, 50); + } else { + //Get the average of the results + double sum = 0; + + //Collect the outcomes for the nearest neighbors + for (Number n : indexes) { + sum += outcomes[n.intValue()]; + } + + //Return the average of the outcomes as the prediction. + return sum / ((double) indexes.size()); } - - //Return the average of the outcomes as the prediction. 
- - return sum/((double)indexes.size()); } } } diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java index 9385928f82b..c8e83ba63b9 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java @@ -97,13 +97,17 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa predictors[i] = list.get(i).doubleValue(); } - predictors = regressedTuple.scale(predictors); + if(regressedTuple.getScale()) { + predictors = regressedTuple.scale(predictors); + } return regressedTuple.predict(predictors); } else if (second instanceof Matrix) { Matrix m = (Matrix) second; - m = regressedTuple.scale(m); + if(regressedTuple.getScale()) { + m = regressedTuple.scale(m); + } double[][] data = m.getData(); List predictions = new ArrayList(); for (double[] predictors : data) { diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java index bfd4160d2f5..6565b7623bb 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java @@ -3450,7 +3450,7 @@ public class MathExpressionTest extends SolrCloudTestCase { "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," + "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," + "e=transpose(matrix(a, b, c))," + - "f=knnRegress(e, d, 1)," + + "f=knnRegress(e, d, 1, scale=true)," + "g=predict(f, e))"; ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); paramsLoc.set("expr", cexpr); @@ -3480,7 
+3480,7 @@ public class MathExpressionTest extends SolrCloudTestCase { "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," + "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," + "e=transpose(matrix(a, b, c))," + - "f=knnRegress(e, d, 1)," + + "f=knnRegress(e, d, 1, scale=true)," + "g=predict(f, array(8, 5, 4)))"; paramsLoc = new ModifiableSolrParams(); paramsLoc.set("expr", cexpr); @@ -3494,12 +3494,14 @@ public class MathExpressionTest extends SolrCloudTestCase { Number prediction = (Number)tuples.get(0).get("g"); assertEquals(prediction.doubleValue(), 85.09999847, 0); - cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 8.19999981), " + - "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 5.699999809, 4.5)," + - "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 4.900000095)," + + //Test robust. 
Take the median rather than average + + cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 8.10000038, 8.19999981), " + + "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 5.599999905, 5.699999809, 4.5)," + + "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 4.79999924, 4.900000095)," + "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," + "e=transpose(matrix(a, b, c))," + - "f=knnRegress(e, d, 2)," + + "f=knnRegress(e, d, 3, scale=true, robust=true)," + "g=predict(f, array(8, 5, 4)))"; paramsLoc = new ModifiableSolrParams(); paramsLoc.set("expr", cexpr); @@ -3511,7 +3513,27 @@ public class MathExpressionTest extends SolrCloudTestCase { tuples = getTuples(solrStream); assertTrue(tuples.size() == 1); prediction = (Number)tuples.get(0).get("g"); - assertEquals(prediction.doubleValue(), 87.20000076, 0); + assertEquals(prediction.doubleValue(), 89.30000305, 0); + + + //Test univariate regression with scaling off + + cexpr = "let(echo=true, a=sequence(10, 0, 1), " + + "b=transpose(matrix(a))," + + "c=knnRegress(b, a, 3)," + + "d=predict(c, array(3)))"; + paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + solrStream = new SolrStream(url, paramsLoc); + context = new StreamContext(); + solrStream.setStreamContext(context); + tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + prediction = (Number)tuples.get(0).get("d"); + assertEquals(prediction.doubleValue(), 3, 0); + } @Test