SOLR-13298: Allow zplot to plot matrices

This commit is contained in:
Joel Bernstein 2019-10-06 22:17:45 -04:00
parent 888fe76a09
commit 7d4751e8b8
9 changed files with 312 additions and 56 deletions

View File

@ -17,6 +17,7 @@
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
@ -26,6 +27,7 @@ import org.apache.commons.math3.stat.correlation.PearsonsCorrelation;
import org.apache.commons.math3.stat.correlation.KendallsCorrelation;
import org.apache.commons.math3.stat.correlation.SpearmansCorrelation;
import org.apache.solr.client.solrj.io.stream.ZplotStream;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
@ -109,6 +111,9 @@ public class CorrelationEvaluator extends RecursiveObjectEvaluator implements Ma
double[][] corrMatrixData = corrMatrix.getData();
Matrix realMatrix = new Matrix(corrMatrixData);
realMatrix.setAttribute("corr", pearsonsCorrelation);
List<String> labels = getColumnLabels(matrix.getColumnLabels(), corrMatrixData.length);
realMatrix.setColumnLabels(labels);
realMatrix.setRowLabels(labels);
return realMatrix;
} else if (type.equals(CorrelationType.kendalls)) {
KendallsCorrelation kendallsCorrelation = new KendallsCorrelation(data);
@ -116,6 +121,9 @@ public class CorrelationEvaluator extends RecursiveObjectEvaluator implements Ma
double[][] corrMatrixData = corrMatrix.getData();
Matrix realMatrix = new Matrix(corrMatrixData);
realMatrix.setAttribute("corr", kendallsCorrelation);
List<String> labels = getColumnLabels(matrix.getColumnLabels(), corrMatrixData.length);
realMatrix.setColumnLabels(labels);
realMatrix.setRowLabels(labels);
return realMatrix;
} else if (type.equals(CorrelationType.spearmans)) {
SpearmansCorrelation spearmansCorrelation = new SpearmansCorrelation(new Array2DRowRealMatrix(data, false));
@ -123,6 +131,9 @@ public class CorrelationEvaluator extends RecursiveObjectEvaluator implements Ma
double[][] corrMatrixData = corrMatrix.getData();
Matrix realMatrix = new Matrix(corrMatrixData);
realMatrix.setAttribute("corr", spearmansCorrelation.getRankCorrelation());
List<String> labels = getColumnLabels(matrix.getColumnLabels(), corrMatrixData.length);
realMatrix.setColumnLabels(labels);
realMatrix.setRowLabels(labels);
return realMatrix;
} else {
return null;
@ -134,4 +145,18 @@ public class CorrelationEvaluator extends RecursiveObjectEvaluator implements Ma
throw new IOException("corr function operates on either two numeric arrays or a single matrix as parameters.");
}
}
/**
 * Resolves the labels to attach to a matrix derived from another matrix.
 *
 * <p>When {@code labels} is non-null it is returned as-is (the same list
 * instance, not a copy). Otherwise {@code length} labels of the form
 * {@code "col" + index} are generated, with the index passed through
 * {@link ZplotStream#pad(String, int)} using {@code length} as the sizing hint.
 *
 * @param labels existing column labels, or {@code null} if none were set
 * @param length number of labels to generate when {@code labels} is {@code null}
 * @return {@code labels} itself when non-null, otherwise a newly built list
 */
public static List<String> getColumnLabels(List<String> labels, int length) {
  if (labels != null) {
    return labels;
  }

  // Diamond instead of the raw ArrayList type: same runtime behavior, no unchecked warning.
  List<String> generated = new ArrayList<>();
  for (int i = 0; i < length; i++) {
    generated.add("col" + ZplotStream.pad(Integer.toString(i), length));
  }
  return generated;
}
}

View File

@ -49,7 +49,11 @@ public class CovarianceEvaluator extends RecursiveObjectEvaluator implements Man
Covariance covariance = new Covariance(data, true);
RealMatrix coMatrix = covariance.getCovarianceMatrix();
double[][] coData = coMatrix.getData();
return new Matrix(coData);
Matrix realMatrix = new Matrix(coData);
List<String> labels = CorrelationEvaluator.getColumnLabels(matrix.getColumnLabels(), coData.length);
realMatrix.setColumnLabels(labels);
realMatrix.setRowLabels(labels);
return realMatrix;
} else {
throw new IOException("The cov function expects either two numeric arrays or a matrix as parameters.");
}

View File

@ -123,6 +123,10 @@ public class DistanceEvaluator extends RecursiveObjectEvaluator implements ManyV
distanceMatrix[i][j] = dist;
}
}
return new Matrix(distanceMatrix);
Matrix m = new Matrix(distanceMatrix);
List<String> labels = CorrelationEvaluator.getColumnLabels(matrix.getColumnLabels(), data.length);
m.setColumnLabels(labels);
m.setRowLabels(labels);
return m;
}
}

View File

@ -27,6 +27,7 @@ import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.ml.clustering.CentroidCluster;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.commons.math3.ml.clustering.FuzzyKMeansClusterer;
import org.apache.solr.client.solrj.io.stream.ZplotStream;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
@ -100,6 +101,11 @@ public class FuzzyKmeansEvaluator extends RecursiveObjectEvaluator implements Tw
double[][] mmData = realMatrix.getData();
Matrix mmMatrix = new Matrix(mmData);
mmMatrix.setRowLabels(matrix.getRowLabels());
List<String> clusterCols = new ArrayList();
for(int i=0; i<clusters.size(); i++) {
clusterCols.add("cluster"+ ZplotStream.pad(Integer.toString(i), clusters.size()));
}
mmMatrix.setRowLabels(clusterCols);
return new KmeansEvaluator.ClusterTuple(fields, clusters, matrix.getColumnLabels(),mmMatrix);
}
}

View File

@ -53,7 +53,10 @@ public class NormalizeEvaluator extends RecursiveObjectEvaluator implements OneV
double[] row = data[i];
standardized[i] = StatUtils.normalize(row);
}
return new Matrix(standardized);
Matrix m = new Matrix(standardized);
m.setRowLabels(matrix.getRowLabels());
m.setColumnLabels(matrix.getColumnLabels());
return m;
} else {
return doWork(Arrays.asList((BigDecimal)value));
}

View File

@ -62,7 +62,10 @@ public class NormalizeSumEvaluator extends RecursiveObjectEvaluator implements M
unitData[i] = unitRow;
}
return new Matrix(unitData);
Matrix m = new Matrix(unitData);
m.setRowLabels(matrix.getRowLabels());
m.setColumnLabels(matrix.getColumnLabels());
return m;
} else if(value instanceof List) {
List<Number> vals = (List<Number>)value;
double[] doubles = new double[vals.size()];

View File

@ -55,7 +55,7 @@ public class UnitEvaluator extends RecursiveObjectEvaluator implements OneValueW
Matrix m = new Matrix(unitData);
m.setRowLabels(matrix.getRowLabels());
m.setColumnLabels(matrix.getRowLabels());
m.setColumnLabels(matrix.getColumnLabels());
return m;
} else if(value instanceof List) {
List<Number> values = (List<Number>)value;

View File

@ -36,6 +36,7 @@ import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.eval.KmeansEvaluator;
import org.apache.solr.client.solrj.io.eval.StreamEvaluator;
import org.apache.solr.client.solrj.io.eval.Matrix;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
@ -129,6 +130,7 @@ public class ZplotStream extends TupleStream implements Expressible {
boolean table = false;
boolean distribution = false;
boolean clusters = false;
boolean heat = false;
for(Map.Entry<String, Object> entry : entries) {
++columns;
@ -139,6 +141,9 @@ public class ZplotStream extends TupleStream implements Expressible {
distribution = true;
} else if(name.equals("clusters")) {
clusters = true;
} else if(name.equals("heat")) {
heat = true;
}
Object o = entry.getValue();
@ -176,6 +181,8 @@ public class ZplotStream extends TupleStream implements Expressible {
evaluated.put(name, l);
} else if(eval instanceof Tuple) {
evaluated.put(name, eval);
} else if(eval instanceof Matrix) {
evaluated.put(name, eval);
}
}
}
@ -186,7 +193,7 @@ public class ZplotStream extends TupleStream implements Expressible {
//Load the values into tuples
List<Tuple> outTuples = new ArrayList();
if(!table && !distribution && !clusters) {
if(!table && !distribution && !clusters && !heat) {
//Handle the vectors
for (int i = 0; i < numTuples; i++) {
Tuple tuple = new Tuple(new HashMap());
@ -304,20 +311,96 @@ public class ZplotStream extends TupleStream implements Expressible {
}
}
}
} else if(table){
} else if(table) {
//Handle the Tuple and List of Tuples
Object o = evaluated.get("table");
if(o instanceof List) {
List<Tuple> tuples = (List<Tuple>)o;
outTuples.addAll(tuples);
} else if(o instanceof Tuple) {
outTuples.add((Tuple)o);
if (o instanceof Matrix) {
Matrix m = (Matrix) o;
List<String> rowLabels = m.getRowLabels();
List<String> colLabels = m.getColumnLabels();
double[][] data = m.getData();
for (int i = 0; i < data.length; i++) {
String rowLabel = null;
if (rowLabels != null) {
rowLabel = rowLabels.get(i);
} else {
rowLabel = Integer.toString(i);
}
Tuple tuple = new Tuple(new HashMap());
tuple.put("rowLabel", rowLabel);
double[] row = data[i];
for (int j = 0; j < row.length; j++) {
String colLabel = null;
if (colLabels != null) {
colLabel = colLabels.get(j);
} else {
colLabel = "col" + Integer.toString(j);
}
tuple.put(colLabel, data[i][j]);
}
outTuples.add(tuple);
}
}
} else if (heat) {
//Handle the Tuple and List of Tuples
Object o = evaluated.get("heat");
if(o instanceof Matrix) {
Matrix m = (Matrix) o;
List<String> rowLabels = m.getRowLabels();
List<String> colLabels = m.getColumnLabels();
double[][] data = m.getData();
for (int i = 0; i < data.length; i++) {
String rowLabel = null;
if (rowLabels != null) {
rowLabel = rowLabels.get(i);
} else {
rowLabel = "row"+pad(Integer.toString(i), data.length);
}
double[] row = data[i];
for (int j = 0; j < row.length; j++) {
Tuple tuple = new Tuple(new HashMap());
tuple.put("y", rowLabel);
String colLabel = null;
if (colLabels != null) {
colLabel = colLabels.get(j);
} else {
colLabel = "col" + pad(Integer.toString(j), row.length);
}
tuple.put("x", colLabel);
tuple.put("z", data[i][j]);
outTuples.add(tuple);
}
}
}
}
this.out = outTuples.iterator();
}
/**
 * Left-pads an index string with zeros so that all indices of a sequence of
 * {@code length} elements (0 .. length-1) have the same width.
 *
 * <p>For {@code length} below 11 the value is returned untouched. Otherwise the
 * target width is the number of decimal digits in {@code length - 1}, the
 * largest index. This yields widths 2, 3, 4 and 5 for lengths up to 100,
 * 1000, 10000 and 100000 respectively, and keeps growing for longer sequences
 * instead of stopping at 5, so labels stay uniformly wide.
 *
 * @param v      the value to pad, typically {@code Integer.toString(index)}
 * @param length the number of elements in the sequence being labelled
 * @return {@code v}, prefixed with as many {@code '0'} characters as needed;
 *         values already at least as wide as the target are returned unchanged
 */
public static String pad(String v, int length) {
  if (length < 11) {
    return v;
  }
  // length >= 11 here, so length - 1 is positive and its digit count is the widest index.
  int width = Integer.toString(length - 1).length();
  return prepend(v, width);
}

/**
 * Prefixes {@code v} with {@code '0'} characters until it is {@code length}
 * characters long; returns {@code v} unchanged if it is already that long.
 */
private static String prepend(String v, int length) {
  int missing = length - v.length();
  if (missing <= 0) {
    return v;
  }
  StringBuilder padded = new StringBuilder(length);
  for (int i = 0; i < missing; i++) {
    padded.append('0');
  }
  return padded.append(v).toString();
}
/** Return the stream sort - ie, the order in which records are returned */
public StreamComparator getStreamSort(){
return null;

View File

@ -1562,58 +1562,24 @@ public class MathExpressionTest extends SolrCloudTestCase {
@Test
public void testZplot() throws Exception {
String cexpr = "let(c=tuple(a=add(1,2), b=add(2,3))," +
" zplot(table=c))";
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
String cexpr = "let(a=array(1,2,3,4)," +
" b=array(10,11,12,13),"+
" zplot(x=a, y=b))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Tuple out = tuples.get(0);
assertEquals(out.getDouble("a").doubleValue(), 3.0, 0.0);
assertEquals(out.getDouble("b").doubleValue(), 5.0, 0.0);
cexpr = "let(c=list(tuple(a=add(1,2), b=add(2,3)), tuple(a=add(1,3), b=add(2,4)))," +
" zplot(table=c))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
solrStream = new SolrStream(url, paramsLoc);
context = new StreamContext();
solrStream.setStreamContext(context);
tuples = getTuples(solrStream);
assertTrue(tuples.size() == 2);
out = tuples.get(0);
assertEquals(out.getDouble("a").doubleValue(), 3.0, 0.0);
assertEquals(out.getDouble("b").doubleValue(), 5.0, 0.0);
out = tuples.get(1);
assertEquals(out.getDouble("a").doubleValue(), 4.0, 0.0);
assertEquals(out.getDouble("b").doubleValue(), 6.0, 0.0);
cexpr = "let(a=array(1,2,3,4)," +
" b=array(10,11,12,13),"+
" zplot(x=a, y=b))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
solrStream = new SolrStream(url, paramsLoc);
context = new StreamContext();
solrStream.setStreamContext(context);
tuples = getTuples(solrStream);
assertTrue(tuples.size() == 4);
out = tuples.get(0);
Tuple out = tuples.get(0);
assertEquals(out.getDouble("x").doubleValue(), 1.0, 0.0);
assertEquals(out.getDouble("y").doubleValue(), 10.0, 0.0);
@ -1744,6 +1710,152 @@ public class MathExpressionTest extends SolrCloudTestCase {
assertTrue(clusters.contains("cluster3"));
assertTrue(clusters.contains("cluster4"));
assertTrue(clusters.contains("cluster5"));
cexpr = "let(a=matrix(array(0,1,2,3,4,5,6,7,8,9,10,11), array(10,11,12,13,14,15,16,17,18,19,20,21))," +
" zplot(heat=a))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
solrStream = new SolrStream(url, paramsLoc);
context = new StreamContext();
solrStream.setStreamContext(context);
tuples = getTuples(solrStream);
assertTrue(tuples.size() == 24);
Tuple tuple = tuples.get(0);
String xLabel = tuple.getString("x");
String yLabel = tuple.getString("y");
Number z = tuple.getLong("z");
assertEquals(xLabel, "col00");
assertEquals(yLabel, "row0");
assertEquals(z.longValue(), 0L);
tuple = tuples.get(1);
xLabel = tuple.getString("x");
yLabel = tuple.getString("y");
z = tuple.getLong("z");
assertEquals(xLabel, "col01");
assertEquals(yLabel, "row0");
assertEquals(z.longValue(), 1L);
tuple = tuples.get(2);
xLabel = tuple.getString("x");
yLabel = tuple.getString("y");
z = tuple.getLong("z");
assertEquals(xLabel, "col02");
assertEquals(yLabel, "row0");
assertEquals(z.longValue(), 2L);
tuple = tuples.get(12);
xLabel = tuple.getString("x");
yLabel = tuple.getString("y");
z = tuple.getLong("z");
assertEquals(xLabel, "col00");
assertEquals(yLabel, "row1");
assertEquals(z.longValue(), 10L);
cexpr = "let(a=transpose(matrix(array(0, 1, 2, 3, 4, 5, 6, 7,8,9,10,11), " +
" array(10,11,12,13,14,15,16,17,18,19,20,21)))," +
" zplot(heat=a))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
solrStream = new SolrStream(url, paramsLoc);
context = new StreamContext();
solrStream.setStreamContext(context);
tuples = getTuples(solrStream);
assertTrue(tuples.size() == 24);
tuple = tuples.get(0);
xLabel = tuple.getString("x");
yLabel = tuple.getString("y");
z = tuple.getLong("z");
assertEquals(xLabel, "col0");
assertEquals(yLabel, "row00");
assertEquals(z.longValue(), 0L);
tuple = tuples.get(1);
xLabel = tuple.getString("x");
yLabel = tuple.getString("y");
z = tuple.getLong("z");
assertEquals(xLabel, "col1");
assertEquals(yLabel, "row00");
assertEquals(z.longValue(), 10L);
tuple = tuples.get(2);
xLabel = tuple.getString("x");
yLabel = tuple.getString("y");
z = tuple.getLong("z");
assertEquals(xLabel, "col0");
assertEquals(yLabel, "row01");
assertEquals(z.longValue(), 1L);
tuple = tuples.get(12);
xLabel = tuple.getString("x");
yLabel = tuple.getString("y");
z = tuple.getLong("z");
assertEquals(xLabel, "col0");
assertEquals(yLabel, "row06");
assertEquals(z.longValue(), 6L);
cexpr = "let(a=matrix(array(0, 1, 2, 3, 4, 5, 6, 7,8,9,10,11), " +
" array(10,11,12,13,14,15,16,17,18,19,20,21))," +
" b=setRowLabels(a, array(\"blah1\", \"blah2\")),"+
" c=setColumnLabels(b, array(\"rah1\", \"rah2\", \"rah3\", \"rah4\", \"rah5\", \"rah6\", \"rah7\", \"rah8\", \"rah9\", \"rah10\", \"rah11\", \"rah12\")),"+
" zplot(heat=c))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
solrStream = new SolrStream(url, paramsLoc);
context = new StreamContext();
solrStream.setStreamContext(context);
tuples = getTuples(solrStream);
assertTrue(tuples.size() == 24);
tuple = tuples.get(0);
xLabel = tuple.getString("x");
yLabel = tuple.getString("y");
z = tuple.getLong("z");
assertEquals(xLabel, "rah1");
assertEquals(yLabel, "blah1");
assertEquals(z.longValue(), 0L);
tuple = tuples.get(1);
xLabel = tuple.getString("x");
yLabel = tuple.getString("y");
z = tuple.getLong("z");
assertEquals(xLabel, "rah2");
assertEquals(yLabel, "blah1");
assertEquals(z.longValue(), 1L);
tuple = tuples.get(2);
xLabel = tuple.getString("x");
yLabel = tuple.getString("y");
z = tuple.getLong("z");
assertEquals(xLabel, "rah3");
assertEquals(yLabel, "blah1");
assertEquals(z.longValue(), 2L);
tuple = tuples.get(12);
xLabel = tuple.getString("x");
yLabel = tuple.getString("y");
z = tuple.getLong("z");
assertEquals(xLabel, "rah1");
assertEquals(yLabel, "blah2");
assertEquals(z.longValue(), 10L);
}
@ -5137,7 +5249,9 @@ public class MathExpressionTest extends SolrCloudTestCase {
"f=corr(d), " +
"g=corr(d, type=kendalls), " +
"h=corr(d, type=spearmans)," +
"i=corrPValues(f))";
"i=corrPValues(f)," +
" j=getRowLabels(f)," +
" k=getColumnLabels(f))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
@ -5226,6 +5340,20 @@ public class MathExpressionTest extends SolrCloudTestCase {
assertEquals(row3.get(0).doubleValue(), 0.28548201004998375, 0);
assertEquals(row3.get(1).doubleValue(), 0.28548201004998375, 0);
assertEquals(row3.get(2).doubleValue(), 0, 0);
List<String> rowLabels = (List<String>)tuples.get(0).get("j");
assertEquals(rowLabels.size(), 3);
assertEquals(rowLabels.get(0), "col0");
assertEquals(rowLabels.get(1), "col1");
assertEquals(rowLabels.get(2), "col2");
List<String> colLabels = (List<String>)tuples.get(0).get("k");
assertEquals(colLabels.size(), 3);
assertEquals(colLabels.get(0), "col0");
assertEquals(colLabels.get(1), "col1");
assertEquals(colLabels.get(2), "col2");
}
@Test