SOLR-11599: Change normalize function to standardize and make it work with matrices

This commit is contained in:
Joel Bernstein 2017-11-02 21:35:23 -04:00
parent 19db1df81a
commit 0ebf5e0896
3 changed files with 47 additions and 60 deletions

View File

@ -217,7 +217,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("hist", HistogramEvaluator.class) .withFunctionName("hist", HistogramEvaluator.class)
.withFunctionName("length", LengthEvaluator.class) .withFunctionName("length", LengthEvaluator.class)
.withFunctionName("movingAvg", MovingAverageEvaluator.class) .withFunctionName("movingAvg", MovingAverageEvaluator.class)
.withFunctionName("normalize", NormalizeEvaluator.class) .withFunctionName("standardize", NormalizeEvaluator.class)
.withFunctionName("percentile", PercentileEvaluator.class) .withFunctionName("percentile", PercentileEvaluator.class)
.withFunctionName("predict", PredictEvaluator.class) .withFunctionName("predict", PredictEvaluator.class)
.withFunctionName("rank", RankEvaluator.class) .withFunctionName("rank", RankEvaluator.class)
@ -270,7 +270,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("loess", LoessEvaluator.class) .withFunctionName("loess", LoessEvaluator.class)
.withFunctionName("matrix", MatrixEvaluator.class) .withFunctionName("matrix", MatrixEvaluator.class)
.withFunctionName("transpose", TransposeEvaluator.class) .withFunctionName("transpose", TransposeEvaluator.class)
.withFunctionName("unit", UnitEvaluator.class) .withFunctionName("unitize", UnitEvaluator.class)
.withFunctionName("triangularDistribution", TriangularDistributionEvaluator.class) .withFunctionName("triangularDistribution", TriangularDistributionEvaluator.class)
.withFunctionName("precision", PrecisionEvaluator.class) .withFunctionName("precision", PrecisionEvaluator.class)
.withFunctionName("minMaxScale", MinMaxScaleEvaluator.class) .withFunctionName("minMaxScale", MinMaxScaleEvaluator.class)

View File

@ -27,7 +27,7 @@ import org.apache.commons.math3.stat.StatUtils;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class NormalizeEvaluator extends RecursiveNumericEvaluator implements OneValueWorker { public class NormalizeEvaluator extends RecursiveObjectEvaluator implements OneValueWorker {
protected static final long serialVersionUID = 1L; protected static final long serialVersionUID = 1L;
public NormalizeEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ public NormalizeEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
@ -45,8 +45,16 @@ public class NormalizeEvaluator extends RecursiveNumericEvaluator implements One
} }
else if(value instanceof List){ else if(value instanceof List){
return Arrays.stream(StatUtils.normalize(((List<?>)value).stream().mapToDouble(innerValue -> ((Number)innerValue).doubleValue()).toArray())).mapToObj(Double::new).collect(Collectors.toList()); return Arrays.stream(StatUtils.normalize(((List<?>)value).stream().mapToDouble(innerValue -> ((Number)innerValue).doubleValue()).toArray())).mapToObj(Double::new).collect(Collectors.toList());
} } else if (value instanceof Matrix) {
else{ Matrix matrix = (Matrix) value;
double[][] data = matrix.getData();
double[][] standardized = new double[data.length][];
for(int i=0; i<data.length; i++) {
double[] row = data[i];
standardized[i] = StatUtils.normalize(row);
}
return new Matrix(standardized);
} else {
return doWork(Arrays.asList((BigDecimal)value)); return doWork(Arrays.asList((BigDecimal)value));
} }
} }

View File

@ -6097,8 +6097,8 @@ public class StreamExpressionTest extends SolrCloudTestCase {
} }
@Test @Test
public void testUnit() throws Exception { public void testUnitize() throws Exception {
String cexpr = "let(echo=true, a=unit(matrix(array(1,2,3), array(4,5,6))), b=unit(array(4,5,6)))"; String cexpr = "let(echo=true, a=unitize(matrix(array(1,2,3), array(4,5,6))), b=unitize(array(4,5,6)))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr); paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream"); paramsLoc.set("qt", "/stream");
@ -6129,6 +6129,38 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertEquals(array2.get(2).doubleValue(), 0.6837634587578276, 0.0); assertEquals(array2.get(2).doubleValue(), 0.6837634587578276, 0.0);
} }
/**
 * Verifies the {@code standardize} stream evaluator for both a matrix and a plain array.
 *
 * <p>The expression standardizes a 2x3 matrix (rows {@code 1,2,3} and {@code 4,5,6}) into
 * variable {@code a} and the array {@code 4,5,6} into variable {@code b}. Every row, and the
 * array, is expected to come back as {@code -1, 0, 1}.
 *
 * @throws Exception if the streaming request fails
 */
@Test
public void testStandardize() throws Exception {
  String cexpr = "let(echo=true, a=standardize(matrix(array(1,2,3), array(4,5,6))), b=standardize(array(4,5,6)))";
  ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
  paramsLoc.set("expr", cexpr);
  paramsLoc.set("qt", "/stream");

  String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
  TupleStream solrStream = new SolrStream(url, paramsLoc);
  StreamContext context = new StreamContext();
  solrStream.setStreamContext(context);
  List<Tuple> tuples = getTuples(solrStream);
  assertEquals(1, tuples.size());

  // "a" holds the standardized matrix: one inner list per matrix row.
  @SuppressWarnings("unchecked")
  List<List<Number>> out = (List<List<Number>>)tuples.get(0).get("a");
  assertEquals(2, out.size());

  List<Number> array1 = out.get(0);
  assertEquals(3, array1.size());
  assertEquals(-1, array1.get(0).doubleValue(), 0.0);
  assertEquals(0, array1.get(1).doubleValue(), 0.0);
  assertEquals(1, array1.get(2).doubleValue(), 0.0);

  List<Number> array2 = out.get(1);
  assertEquals(3, array2.size());
  assertEquals(-1, array2.get(0).doubleValue(), 0.0);
  assertEquals(0, array2.get(1).doubleValue(), 0.0);
  assertEquals(1, array2.get(2).doubleValue(), 0.0);

  // "b" holds the standardized plain array; its values must be checked on array3 itself
  // (the original re-checked array2 here, leaving the array code path unverified).
  @SuppressWarnings("unchecked")
  List<Number> array3 = (List<Number>)tuples.get(0).get("b");
  assertEquals(3, array3.size());
  assertEquals(-1, array3.get(0).doubleValue(), 0.0);
  assertEquals(0, array3.get(1).doubleValue(), 0.0);
  assertEquals(1, array3.get(2).doubleValue(), 0.0);
}
@Test @Test
@ -7556,59 +7588,6 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertTrue(length == 7); assertTrue(length == 7);
} }
/**
 * Verifies that {@code normalize} applied to a column of indexed prices yields values whose
 * mean rounds to 0 and whose standard deviation rounds to 1.
 *
 * <p>Seven documents are indexed, their {@code price_f} values are pulled into a column with
 * {@code col}, and the normalized column is returned alongside the raw column in one tuple.
 *
 * @throws Exception if indexing or the streaming request fails
 */
@Test
public void testNormalize() throws Exception {
  UpdateRequest updateRequest = new UpdateRequest();
  updateRequest.add(id, "1", "price_f", "100.0", "col_s", "a", "order_i", "1");
  updateRequest.add(id, "2", "price_f", "200.0", "col_s", "a", "order_i", "2");
  updateRequest.add(id, "3", "price_f", "300.0", "col_s", "a", "order_i", "3");
  updateRequest.add(id, "4", "price_f", "100.0", "col_s", "a", "order_i", "4");
  updateRequest.add(id, "5", "price_f", "200.0", "col_s", "a", "order_i", "5");
  updateRequest.add(id, "6", "price_f", "400.0", "col_s", "a", "order_i", "6");
  updateRequest.add(id, "7", "price_f", "600.0", "col_s", "a", "order_i", "7");
  updateRequest.commit(cluster.getSolrClient(), COLLECTIONORALIAS);

  String expr1 = "search("+COLLECTIONORALIAS+", q=\"col_s:a\", fl=\"price_f, order_i\", sort=\"order_i asc\")";
  String cexpr = "let(a="+expr1+", c=col(a, price_f), tuple(n=normalize(c), c=c))";
  ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
  paramsLoc.set("expr", cexpr);
  paramsLoc.set("qt", "/stream");

  String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
  TupleStream solrStream = new SolrStream(url, paramsLoc);
  StreamContext context = new StreamContext();
  solrStream.setStreamContext(context);
  List<Tuple> tuples = getTuples(solrStream);
  assertEquals(1, tuples.size());

  Tuple tuple = tuples.get(0);
  @SuppressWarnings("unchecked")
  List<Double> col = (List<Double>)tuple.get("c");
  @SuppressWarnings("unchecked")
  List<Double> normalized = (List<Double>)tuple.get("n");
  assertEquals(col.size(), normalized.size());

  double total = 0.0D;
  for(double d : normalized) {
    total += d;
  }
  double mean = total/normalized.size();
  // Use a JUnit assertion rather than the `assert` keyword, which is skipped
  // when the JVM runs without -ea.
  assertEquals(0L, Math.round(mean));

  // Standard deviation of the normalized values around their mean, dividing by n.
  double sd = 0;
  for (int i = 0; i < normalized.size(); i++) {
    sd += Math.pow(normalized.get(i) - mean, 2) / normalized.size();
  }
  double standardDeviation = Math.sqrt(sd);
  assertEquals(1L, Math.round(standardDeviation));
}
@Test @Test
public void testListStream() throws Exception { public void testListStream() throws Exception {
UpdateRequest updateRequest = new UpdateRequest(); UpdateRequest updateRequest = new UpdateRequest();