SOLR-13287: Allow zplot to visualize probability distributions in Apache Zeppelin

This commit is contained in:
Joel Bernstein 2019-03-05 09:18:47 -05:00
parent 7bfe7b265a
commit c34c56b7b2
4 changed files with 242 additions and 62 deletions

View File

@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
@ -24,27 +25,34 @@ import org.apache.commons.math3.random.EmpiricalDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class EmpiricalDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker {
public class EmpiricalDistributionEvaluator extends RecursiveNumericEvaluator implements ManyValueWorker {
protected static final long serialVersionUID = 1L;
private int bins = 99;
public EmpiricalDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
if(1 != containedEvaluators.size()){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly one value but found %d",expression,containedEvaluators.size()));
if(2 != containedEvaluators.size() && 1 != containedEvaluators.size()) {
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting one or two values but found %d",expression,containedEvaluators.size()));
}
}
@Override
public Object doWork(Object value) throws IOException {
public Object doWork(Object[] values) throws IOException {
if(!(value instanceof List<?>)){
throw new StreamEvaluatorException("List value expected but found type %s for value %s", value.getClass().getName(), value.toString());
if(!(values[0] instanceof List<?>)){
throw new StreamEvaluatorException("List value expected but found type %s for value %s", values[0].getClass().getName(), values[0].toString());
}
EmpiricalDistribution empiricalDistribution = new EmpiricalDistribution();
if(values.length == 2) {
Number n = (Number)values[1];
bins = n.intValue();
}
EmpiricalDistribution empiricalDistribution = new EmpiricalDistribution(bins);
double[] backingValues = ((List<?>)value).stream().mapToDouble(innerValue -> ((Number)innerValue).doubleValue()).sorted().toArray();
double[] backingValues = ((List<?>)values[0]).stream().mapToDouble(innerValue -> ((Number)innerValue).doubleValue()).sorted().toArray();
empiricalDistribution.load(backingValues);
return empiricalDistribution;

View File

@ -23,6 +23,7 @@ import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Iterator;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.SingleValueComparator;
@ -45,16 +46,19 @@ public class TupStream extends TupleStream implements Expressible {
private static final long serialVersionUID = 1;
private StreamContext streamContext;
private Map<String,String> stringParams = new HashMap<>();
private Map<String,StreamEvaluator> evaluatorParams = new HashMap<>();
private Map<String,TupleStream> streamParams = new HashMap<>();
private List<String> fieldNames = new ArrayList();
private Map<String, String> fieldLabels = new HashMap();
private Tuple tup = null;
private Tuple unnestedTuple = null;
private Iterator<Tuple> unnestedTuples = null;
private boolean finished;
public TupStream(StreamExpression expression, StreamFactory factory) throws IOException {
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
@ -146,13 +150,27 @@ public class TupStream extends TupleStream implements Expressible {
public Tuple read() throws IOException {
if(finished) {
Map<String,Object> m = new HashMap<>();
m.put("EOF", true);
return new Tuple(m);
if(unnestedTuples == null) {
if (finished) {
Map<String, Object> m = new HashMap<>();
m.put("EOF", true);
return new Tuple(m);
} else {
finished = true;
if(unnestedTuple != null) {
return unnestedTuple;
} else {
return tup;
}
}
} else {
finished = true;
return tup;
if(unnestedTuples.hasNext()) {
return unnestedTuples.next();
} else {
Map<String, Object> m = new HashMap<>();
m.put("EOF", true);
return new Tuple(m);
}
}
}
@ -202,6 +220,19 @@ public class TupStream extends TupleStream implements Expressible {
}
}
if(values.size() == 1) {
for(Object o :values.values()) {
if(o instanceof Tuple) {
unnestedTuple = (Tuple)o;
} else if(o instanceof List) {
List l = (List)o;
if(l.size() > 0 && l.get(0) instanceof Tuple) {
List<Tuple> tl = (List<Tuple>)l;
unnestedTuples = tl.iterator();
}
}
}
}
this.tup = new Tuple(values);
tup.fieldNames = fieldNames;
tup.fieldLabels = fieldLabels;

View File

@ -25,6 +25,12 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.EmpiricalDistribution;
import org.apache.commons.math3.stat.Frequency;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.apache.commons.math3.util.Precision;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.eval.StreamEvaluator;
@ -119,12 +125,15 @@ public class ZplotStream extends TupleStream implements Expressible {
int numTuples = -1;
int columns = 0;
boolean table = false;
boolean distribution = false;
for(Map.Entry<String, Object> entry : entries) {
++columns;
String name = entry.getKey();
if(name.equals("table")) {
table = true;
} else if(name.equals("dist")) {
distribution = true;
}
Object o = entry.getValue();
@ -145,6 +154,8 @@ public class ZplotStream extends TupleStream implements Expressible {
evaluated.put(name, l);
} else if (eo instanceof Tuple) {
evaluated.put(name, eo);
} else {
evaluated.put(name, eo);
}
} else {
Object eval = lets.get(o);
@ -164,13 +175,13 @@ public class ZplotStream extends TupleStream implements Expressible {
}
}
if(columns > 1 && table) {
throw new IOException("If the table parameter is set there can only be one parameter.");
if(columns > 1 && (table || distribution)) {
throw new IOException("If the table or dist parameter is set there can only be one parameter.");
}
//Load the values into tuples
List<Tuple> outTuples = new ArrayList();
if(!table) {
if(!table && !distribution) {
//Handle the vectors
for (int i = 0; i < numTuples; i++) {
Tuple tuple = new Tuple(new HashMap());
@ -181,7 +192,94 @@ public class ZplotStream extends TupleStream implements Expressible {
outTuples.add(tuple);
}
} else {
} else if(distribution) {
Object o = evaluated.get("dist");
if(o instanceof RealDistribution) {
RealDistribution realDistribution = (RealDistribution) o;
List<SummaryStatistics> binStats = null;
if(realDistribution instanceof EmpiricalDistribution) {
EmpiricalDistribution empiricalDistribution = (EmpiricalDistribution)realDistribution;
binStats = empiricalDistribution.getBinStats();
} else {
double[] samples = realDistribution.sample(500000);
EmpiricalDistribution empiricalDistribution = new EmpiricalDistribution(32);
empiricalDistribution.load(samples);
binStats = empiricalDistribution.getBinStats();
}
double[] x = new double[binStats.size()];
double[] y = new double[binStats.size()];
for (int i = 0; i < binStats.size(); i++) {
x[i] = binStats.get(i).getMean();
y[i] = realDistribution.density(x[i]);
}
for (int i = 0; i < x.length; i++) {
Tuple tuple = new Tuple(new HashMap());
if(!Double.isNaN(x[i])) {
tuple.put("x", Precision.round(x[i], 2));
if(y[i] == Double.NEGATIVE_INFINITY || y[i] == Double.POSITIVE_INFINITY) {
tuple.put("y", 0);
} else {
tuple.put("y", y[i]);
}
outTuples.add(tuple);
}
}
} else if(o instanceof IntegerDistribution) {
IntegerDistribution integerDistribution = (IntegerDistribution)o;
int[] samples = integerDistribution.sample(50000);
Frequency frequency = new Frequency();
for(int i : samples) {
frequency.addValue(i);
}
Iterator it = frequency.valuesIterator();
List<Long> values = new ArrayList();
while(it.hasNext()) {
values.add((Long)it.next());
}
System.out.println(values);
int[] x = new int[values.size()];
double[] y = new double[values.size()];
for(int i=0; i<values.size(); i++) {
x[i] = values.get(i).intValue();
y[i] = integerDistribution.probability(x[i]);
}
for (int i = 0; i < x.length; i++) {
Tuple tuple = new Tuple(new HashMap());
tuple.put("x", x[i]);
tuple.put("y", y[i]);
outTuples.add(tuple);
}
} else if(o instanceof List) {
System.out.print("Is list");
List list = (List)o;
if(list.get(0) instanceof Tuple) {
System.out.print("Are tuples");
List<Tuple> tlist = (List<Tuple>)o;
Tuple tuple = tlist.get(0);
if(tuple.fields.containsKey("N")) {
System.out.println("Is hist");
for(Tuple t : tlist) {
Tuple outtuple = new Tuple(new HashMap());
outtuple.put("x", Precision.round(((double)t.get("mean")), 2));
outtuple.put("y", t.get("prob"));
outTuples.add(outtuple);
}
} else if(tuple.fields.containsKey("count")) {
System.out.println("Is freq");
for(Tuple t : tlist) {
Tuple outtuple = new Tuple(new HashMap());
outtuple.put("x", t.get("value"));
outtuple.put("y", t.get("pct"));
outTuples.add(outtuple);
}
}
}
}
} else if(table){
//Handle the Tuple and List of Tuples
Object o = evaluated.get("table");
if(o instanceof List) {

View File

@ -369,11 +369,9 @@ public class MathExpressionTest extends SolrCloudTestCase {
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
List<Map> hist = (List<Map>)tuples.get(0).get("return-value");
assertTrue(hist.size() == 10);
for(int i=0; i<hist.size(); i++) {
Map stats = hist.get(i);
assertTrue(tuples.size() == 10);
for(int i=0; i<tuples.size(); i++) {
Tuple stats = tuples.get(i);
assertTrue(((Number)stats.get("N")).intValue() == 10);
assertTrue(((Number)stats.get("min")).intValue() == 10*i);
assertTrue(((Number)stats.get("var")).doubleValue() == 9.166666666666666);
@ -388,11 +386,10 @@ public class MathExpressionTest extends SolrCloudTestCase {
solrStream = new SolrStream(url, paramsLoc);
solrStream.setStreamContext(context);
tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
hist = (List<Map>)tuples.get(0).get("return-value");
assertTrue(hist.size() == 5);
for(int i=0; i<hist.size(); i++) {
Map stats = hist.get(i);
assertTrue(tuples.size() == 5);
for(int i=0; i<tuples.size(); i++) {
Tuple stats = tuples.get(i);
assertTrue(((Number)stats.get("N")).intValue() == 20);
assertTrue(((Number)stats.get("min")).intValue() == 20*i);
assertTrue(((Number)stats.get("var")).doubleValue() == 35);
@ -1476,6 +1473,56 @@ public class MathExpressionTest extends SolrCloudTestCase {
assertEquals(out.getDouble("x").doubleValue(), 4.0, 0.0);
assertEquals(out.getDouble("y").doubleValue(), 13.0, 0.0);
cexpr = "zplot(dist=binomialDistribution(10, .50))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
solrStream = new SolrStream(url, paramsLoc);
context = new StreamContext();
solrStream.setStreamContext(context);
tuples = getTuples(solrStream);
assertEquals(tuples.size(),11);
long x = tuples.get(5).getLong("x");
double y = tuples.get(5).getDouble("y");
assertEquals(x, 5);
assertEquals(y, 0.24609375000000003, 0);
//Due to random errors (bugs) in Apache Commons Math EmpiricalDistribution
//there are times when tuples are discarded because
//they contain NaN values. This will occur
//only at the very ends of the tails of the normal distribution or other
//real distributions and doesn't affect the visual quality of the curve very much.
//But it does affect the reliability of tests.
//For this reason the loop below re-runs the expression up to `limit` times, looking
//for the expected number of tuples (32) before asserting the mean.
int n = 0;
int limit = 15;
while(true) {
cexpr = "zplot(dist=normalDistribution(100, 10))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
solrStream = new SolrStream(url, paramsLoc);
context = new StreamContext();
solrStream.setStreamContext(context);
tuples = getTuples(solrStream);
//Assert the mean
if (tuples.size() == 32) {
double x1 = tuples.get(15).getDouble("x");
double y1 = tuples.get(15).getDouble("y");
assertEquals(x1, 100, 10);
assertEquals(y1, .039, .02);
break;
} else {
++n;
if(n == limit) {
throw new Exception("Reached iterations limit without correct tuple count.");
}
}
}
}
@ -1751,14 +1798,13 @@ public class MathExpressionTest extends SolrCloudTestCase {
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
List<Map<String, Number>> out = (List<Map<String, Number>>)tuples.get(0).get("f");
assertEquals(out.size(), 2);
Map<String, Number> bin0 = out.get(0);
double state0Pct = bin0.get("pct").doubleValue();
assertTrue(tuples.size() == 2);
Tuple bin0 = tuples.get(0);
double state0Pct = bin0.getDouble("pct");
assertEquals(state0Pct, .5, .015);
Map<String, Number> bin1 = out.get(1);
double state1Pct = bin1.get("pct").doubleValue();
Tuple bin1 = tuples.get(1);
double state1Pct = bin1.getDouble("pct");
assertEquals(state1Pct, .5, .015);
}
@ -2933,32 +2979,30 @@ public class MathExpressionTest extends SolrCloudTestCase {
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
List<Map<String,Number>> out = (List<Map<String, Number>>)tuples.get(0).get("return-value");
assertTrue(out.size() == 6);
Map<String, Number> bucket = out.get(0);
assertEquals(bucket.get("value").longValue(), 2);
assertEquals(bucket.get("count").longValue(), 2);
assertTrue(tuples.size() == 6);
Tuple bucket = tuples.get(0);
assertEquals(bucket.getLong("value").longValue(), 2);
assertEquals(bucket.getLong("count").longValue(), 2);
bucket = out.get(1);
assertEquals(bucket.get("value").longValue(), 4);
assertEquals(bucket.get("count").longValue(), 2);
bucket = tuples.get(1);
assertEquals(bucket.getLong("value").longValue(), 4);
assertEquals(bucket.getLong("count").longValue(), 2);
bucket = out.get(2);
assertEquals(bucket.get("value").longValue(), 6);
assertEquals(bucket.get("count").longValue(), 1);
bucket = tuples.get(2);
assertEquals(bucket.getLong("value").longValue(), 6);
assertEquals(bucket.getLong("count").longValue(), 1);
bucket = out.get(3);
assertEquals(bucket.get("value").longValue(), 8);
assertEquals(bucket.get("count").longValue(), 4);
bucket = tuples.get(3);
assertEquals(bucket.getLong("value").longValue(), 8);
assertEquals(bucket.getLong("count").longValue(), 4);
bucket = out.get(4);
assertEquals(bucket.get("value").longValue(), 10);
assertEquals(bucket.get("count").longValue(), 1);
bucket = tuples.get(4);
assertEquals(bucket.getLong("value").longValue(), 10);
assertEquals(bucket.getLong("count").longValue(), 1);
bucket = out.get(5);
assertEquals(bucket.get("value").longValue(), 12);
assertEquals(bucket.get("count").longValue(), 2);
bucket = tuples.get(5);
assertEquals(bucket.getLong("value").longValue(), 12);
assertEquals(bucket.getLong("count").longValue(), 2);
}
@Test
@ -4062,7 +4106,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Map out = (Map)tuples.get(0).get("return-value");
Tuple out = tuples.get(0);
assertEquals((double) out.get("p-value"), 0.788298D, .0001);
assertEquals((double) out.get("f-ratio"), 0.24169D, .0001);
}
@ -4347,7 +4391,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Map out = (Map)tuples.get(0).get("return-value");
Tuple out = tuples.get(0);
assertEquals((double) out.get("u-statistic"), 52.5, .1);
assertEquals((double) out.get("p-value"), 0.7284, .001);
}
@ -5142,7 +5186,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
String expr1 = "search("+COLLECTIONORALIAS+", q=\"col_s:a\", fl=\"price_f, order_i\", sort=\"order_i asc\")";
String cexpr = "let(a="+expr1+", b=col(a, price_f), tuple(stats=describe(b)))";
String cexpr = "let(a="+expr1+", b=col(a, price_f), stats=describe(b))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
@ -5155,8 +5199,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Tuple tuple = tuples.get(0);
Map stats = (Map)tuple.get("stats");
Tuple stats = tuples.get(0);
Number min = (Number)stats.get("min");
Number max = (Number)stats.get("max");
Number mean = (Number)stats.get("mean");