mirror of https://github.com/apache/lucene.git
SOLR-13287: Allow zplot to visualize probability distributions in Apache Zeppelin
This commit is contained in:
parent
7bfe7b265a
commit
c34c56b7b2
|
@ -14,6 +14,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -24,27 +25,34 @@ import org.apache.commons.math3.random.EmpiricalDistribution;
|
|||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class EmpiricalDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker {
|
||||
public class EmpiricalDistributionEvaluator extends RecursiveNumericEvaluator implements ManyValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
private int bins = 99;
|
||||
|
||||
public EmpiricalDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
|
||||
if(1 != containedEvaluators.size()){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly one value but found %d",expression,containedEvaluators.size()));
|
||||
if(2 != containedEvaluators.size() && 1 != containedEvaluators.size()) {
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting one or two values but found %d",expression,containedEvaluators.size()));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object value) throws IOException {
|
||||
public Object doWork(Object[] values) throws IOException {
|
||||
|
||||
if(!(value instanceof List<?>)){
|
||||
throw new StreamEvaluatorException("List value expected but found type %s for value %s", value.getClass().getName(), value.toString());
|
||||
if(!(values[0] instanceof List<?>)){
|
||||
throw new StreamEvaluatorException("List value expected but found type %s for value %s", values[0].getClass().getName(), values[0].toString());
|
||||
}
|
||||
|
||||
EmpiricalDistribution empiricalDistribution = new EmpiricalDistribution();
|
||||
if(values.length == 2) {
|
||||
Number n = (Number)values[1];
|
||||
bins = n.intValue();
|
||||
}
|
||||
|
||||
EmpiricalDistribution empiricalDistribution = new EmpiricalDistribution(bins);
|
||||
|
||||
double[] backingValues = ((List<?>)value).stream().mapToDouble(innerValue -> ((Number)innerValue).doubleValue()).sorted().toArray();
|
||||
double[] backingValues = ((List<?>)values[0]).stream().mapToDouble(innerValue -> ((Number)innerValue).doubleValue()).sorted().toArray();
|
||||
|
||||
empiricalDistribution.load(backingValues);
|
||||
|
||||
return empiricalDistribution;
|
||||
|
|
|
@ -23,6 +23,7 @@ import java.util.List;
|
|||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Iterator;
|
||||
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.comp.SingleValueComparator;
|
||||
|
@ -45,16 +46,19 @@ public class TupStream extends TupleStream implements Expressible {
|
|||
|
||||
private static final long serialVersionUID = 1;
|
||||
private StreamContext streamContext;
|
||||
|
||||
|
||||
private Map<String,String> stringParams = new HashMap<>();
|
||||
private Map<String,StreamEvaluator> evaluatorParams = new HashMap<>();
|
||||
private Map<String,TupleStream> streamParams = new HashMap<>();
|
||||
private List<String> fieldNames = new ArrayList();
|
||||
private Map<String, String> fieldLabels = new HashMap();
|
||||
private Tuple tup = null;
|
||||
private Tuple unnestedTuple = null;
|
||||
private Iterator<Tuple> unnestedTuples = null;
|
||||
|
||||
private boolean finished;
|
||||
|
||||
|
||||
public TupStream(StreamExpression expression, StreamFactory factory) throws IOException {
|
||||
|
||||
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
|
||||
|
@ -146,13 +150,27 @@ public class TupStream extends TupleStream implements Expressible {
|
|||
|
||||
public Tuple read() throws IOException {
|
||||
|
||||
if(finished) {
|
||||
Map<String,Object> m = new HashMap<>();
|
||||
m.put("EOF", true);
|
||||
return new Tuple(m);
|
||||
if(unnestedTuples == null) {
|
||||
if (finished) {
|
||||
Map<String, Object> m = new HashMap<>();
|
||||
m.put("EOF", true);
|
||||
return new Tuple(m);
|
||||
} else {
|
||||
finished = true;
|
||||
if(unnestedTuple != null) {
|
||||
return unnestedTuple;
|
||||
} else {
|
||||
return tup;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
finished = true;
|
||||
return tup;
|
||||
if(unnestedTuples.hasNext()) {
|
||||
return unnestedTuples.next();
|
||||
} else {
|
||||
Map<String, Object> m = new HashMap<>();
|
||||
m.put("EOF", true);
|
||||
return new Tuple(m);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -202,6 +220,19 @@ public class TupStream extends TupleStream implements Expressible {
|
|||
}
|
||||
}
|
||||
|
||||
if(values.size() == 1) {
|
||||
for(Object o :values.values()) {
|
||||
if(o instanceof Tuple) {
|
||||
unnestedTuple = (Tuple)o;
|
||||
} else if(o instanceof List) {
|
||||
List l = (List)o;
|
||||
if(l.size() > 0 && l.get(0) instanceof Tuple) {
|
||||
List<Tuple> tl = (List<Tuple>)l;
|
||||
unnestedTuples = tl.iterator();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
this.tup = new Tuple(values);
|
||||
tup.fieldNames = fieldNames;
|
||||
tup.fieldLabels = fieldLabels;
|
||||
|
|
|
@ -25,6 +25,12 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import org.apache.commons.math3.distribution.IntegerDistribution;
|
||||
import org.apache.commons.math3.distribution.RealDistribution;
|
||||
import org.apache.commons.math3.random.EmpiricalDistribution;
|
||||
import org.apache.commons.math3.stat.Frequency;
|
||||
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
|
||||
import org.apache.commons.math3.util.Precision;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.comp.StreamComparator;
|
||||
import org.apache.solr.client.solrj.io.eval.StreamEvaluator;
|
||||
|
@ -119,12 +125,15 @@ public class ZplotStream extends TupleStream implements Expressible {
|
|||
int numTuples = -1;
|
||||
int columns = 0;
|
||||
boolean table = false;
|
||||
boolean distribution = false;
|
||||
for(Map.Entry<String, Object> entry : entries) {
|
||||
++columns;
|
||||
|
||||
String name = entry.getKey();
|
||||
if(name.equals("table")) {
|
||||
table = true;
|
||||
} else if(name.equals("dist")) {
|
||||
distribution = true;
|
||||
}
|
||||
|
||||
Object o = entry.getValue();
|
||||
|
@ -145,6 +154,8 @@ public class ZplotStream extends TupleStream implements Expressible {
|
|||
evaluated.put(name, l);
|
||||
} else if (eo instanceof Tuple) {
|
||||
evaluated.put(name, eo);
|
||||
} else {
|
||||
evaluated.put(name, eo);
|
||||
}
|
||||
} else {
|
||||
Object eval = lets.get(o);
|
||||
|
@ -164,13 +175,13 @@ public class ZplotStream extends TupleStream implements Expressible {
|
|||
}
|
||||
}
|
||||
|
||||
if(columns > 1 && table) {
|
||||
throw new IOException("If the table parameter is set there can only be one parameter.");
|
||||
if(columns > 1 && (table || distribution)) {
|
||||
throw new IOException("If the table or dist parameter is set there can only be one parameter.");
|
||||
}
|
||||
//Load the values into tuples
|
||||
|
||||
List<Tuple> outTuples = new ArrayList();
|
||||
if(!table) {
|
||||
if(!table && !distribution) {
|
||||
//Handle the vectors
|
||||
for (int i = 0; i < numTuples; i++) {
|
||||
Tuple tuple = new Tuple(new HashMap());
|
||||
|
@ -181,7 +192,94 @@ public class ZplotStream extends TupleStream implements Expressible {
|
|||
|
||||
outTuples.add(tuple);
|
||||
}
|
||||
} else {
|
||||
} else if(distribution) {
|
||||
Object o = evaluated.get("dist");
|
||||
if(o instanceof RealDistribution) {
|
||||
RealDistribution realDistribution = (RealDistribution) o;
|
||||
List<SummaryStatistics> binStats = null;
|
||||
if(realDistribution instanceof EmpiricalDistribution) {
|
||||
EmpiricalDistribution empiricalDistribution = (EmpiricalDistribution)realDistribution;
|
||||
binStats = empiricalDistribution.getBinStats();
|
||||
} else {
|
||||
double[] samples = realDistribution.sample(500000);
|
||||
EmpiricalDistribution empiricalDistribution = new EmpiricalDistribution(32);
|
||||
empiricalDistribution.load(samples);
|
||||
binStats = empiricalDistribution.getBinStats();
|
||||
}
|
||||
double[] x = new double[binStats.size()];
|
||||
double[] y = new double[binStats.size()];
|
||||
for (int i = 0; i < binStats.size(); i++) {
|
||||
x[i] = binStats.get(i).getMean();
|
||||
y[i] = realDistribution.density(x[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < x.length; i++) {
|
||||
Tuple tuple = new Tuple(new HashMap());
|
||||
if(!Double.isNaN(x[i])) {
|
||||
tuple.put("x", Precision.round(x[i], 2));
|
||||
if(y[i] == Double.NEGATIVE_INFINITY || y[i] == Double.POSITIVE_INFINITY) {
|
||||
tuple.put("y", 0);
|
||||
|
||||
} else {
|
||||
tuple.put("y", y[i]);
|
||||
}
|
||||
outTuples.add(tuple);
|
||||
}
|
||||
}
|
||||
} else if(o instanceof IntegerDistribution) {
|
||||
IntegerDistribution integerDistribution = (IntegerDistribution)o;
|
||||
int[] samples = integerDistribution.sample(50000);
|
||||
Frequency frequency = new Frequency();
|
||||
for(int i : samples) {
|
||||
frequency.addValue(i);
|
||||
}
|
||||
|
||||
Iterator it = frequency.valuesIterator();
|
||||
List<Long> values = new ArrayList();
|
||||
while(it.hasNext()) {
|
||||
values.add((Long)it.next());
|
||||
}
|
||||
System.out.println(values);
|
||||
int[] x = new int[values.size()];
|
||||
double[] y = new double[values.size()];
|
||||
for(int i=0; i<values.size(); i++) {
|
||||
x[i] = values.get(i).intValue();
|
||||
y[i] = integerDistribution.probability(x[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < x.length; i++) {
|
||||
Tuple tuple = new Tuple(new HashMap());
|
||||
tuple.put("x", x[i]);
|
||||
tuple.put("y", y[i]);
|
||||
outTuples.add(tuple);
|
||||
}
|
||||
} else if(o instanceof List) {
|
||||
System.out.print("Is list");
|
||||
List list = (List)o;
|
||||
if(list.get(0) instanceof Tuple) {
|
||||
System.out.print("Are tuples");
|
||||
List<Tuple> tlist = (List<Tuple>)o;
|
||||
Tuple tuple = tlist.get(0);
|
||||
if(tuple.fields.containsKey("N")) {
|
||||
System.out.println("Is hist");
|
||||
for(Tuple t : tlist) {
|
||||
Tuple outtuple = new Tuple(new HashMap());
|
||||
outtuple.put("x", Precision.round(((double)t.get("mean")), 2));
|
||||
outtuple.put("y", t.get("prob"));
|
||||
outTuples.add(outtuple);
|
||||
}
|
||||
} else if(tuple.fields.containsKey("count")) {
|
||||
System.out.println("Is freq");
|
||||
for(Tuple t : tlist) {
|
||||
Tuple outtuple = new Tuple(new HashMap());
|
||||
outtuple.put("x", t.get("value"));
|
||||
outtuple.put("y", t.get("pct"));
|
||||
outTuples.add(outtuple);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if(table){
|
||||
//Handle the Tuple and List of Tuples
|
||||
Object o = evaluated.get("table");
|
||||
if(o instanceof List) {
|
||||
|
|
|
@ -369,11 +369,9 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
List<Map> hist = (List<Map>)tuples.get(0).get("return-value");
|
||||
assertTrue(hist.size() == 10);
|
||||
for(int i=0; i<hist.size(); i++) {
|
||||
Map stats = hist.get(i);
|
||||
assertTrue(tuples.size() == 10);
|
||||
for(int i=0; i<tuples.size(); i++) {
|
||||
Tuple stats = tuples.get(i);
|
||||
assertTrue(((Number)stats.get("N")).intValue() == 10);
|
||||
assertTrue(((Number)stats.get("min")).intValue() == 10*i);
|
||||
assertTrue(((Number)stats.get("var")).doubleValue() == 9.166666666666666);
|
||||
|
@ -388,11 +386,10 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
solrStream = new SolrStream(url, paramsLoc);
|
||||
solrStream.setStreamContext(context);
|
||||
tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
hist = (List<Map>)tuples.get(0).get("return-value");
|
||||
assertTrue(hist.size() == 5);
|
||||
for(int i=0; i<hist.size(); i++) {
|
||||
Map stats = hist.get(i);
|
||||
assertTrue(tuples.size() == 5);
|
||||
|
||||
for(int i=0; i<tuples.size(); i++) {
|
||||
Tuple stats = tuples.get(i);
|
||||
assertTrue(((Number)stats.get("N")).intValue() == 20);
|
||||
assertTrue(((Number)stats.get("min")).intValue() == 20*i);
|
||||
assertTrue(((Number)stats.get("var")).doubleValue() == 35);
|
||||
|
@ -1476,6 +1473,56 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
assertEquals(out.getDouble("x").doubleValue(), 4.0, 0.0);
|
||||
assertEquals(out.getDouble("y").doubleValue(), 13.0, 0.0);
|
||||
|
||||
cexpr = "zplot(dist=binomialDistribution(10, .50))";
|
||||
|
||||
paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
solrStream = new SolrStream(url, paramsLoc);
|
||||
context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
tuples = getTuples(solrStream);
|
||||
assertEquals(tuples.size(),11);
|
||||
long x = tuples.get(5).getLong("x");
|
||||
double y = tuples.get(5).getDouble("y");
|
||||
|
||||
assertEquals(x, 5);
|
||||
assertEquals(y, 0.24609375000000003, 0);
|
||||
|
||||
//Due to random errors (bugs) in Apache Commons Math EmpiricalDistribution
|
||||
//there are times when tuples are discarded because
|
||||
//they contain values with NaN values. This will occur
|
||||
//only on the very end of the tails of the normal distribution or other
|
||||
//real distributions and doesn't effect the visual quality of the curve very much.
|
||||
//But it does effect the reliability of tests.
|
||||
//For this reason the loop below is in place to run the test N times looking
|
||||
//for the correct number of tuples before asserting the mean.
|
||||
|
||||
int n = 0;
|
||||
int limit = 15;
|
||||
while(true) {
|
||||
cexpr = "zplot(dist=normalDistribution(100, 10))";
|
||||
paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
solrStream = new SolrStream(url, paramsLoc);
|
||||
context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
tuples = getTuples(solrStream);
|
||||
//Assert the mean
|
||||
if (tuples.size() == 32) {
|
||||
double x1 = tuples.get(15).getDouble("x");
|
||||
double y1 = tuples.get(15).getDouble("y");
|
||||
assertEquals(x1, 100, 10);
|
||||
assertEquals(y1, .039, .02);
|
||||
break;
|
||||
} else {
|
||||
++n;
|
||||
if(n == limit) {
|
||||
throw new Exception("Reached iterations limit without correct tuple count.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -1751,14 +1798,13 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
List<Map<String, Number>> out = (List<Map<String, Number>>)tuples.get(0).get("f");
|
||||
assertEquals(out.size(), 2);
|
||||
Map<String, Number> bin0 = out.get(0);
|
||||
double state0Pct = bin0.get("pct").doubleValue();
|
||||
assertTrue(tuples.size() == 2);
|
||||
|
||||
Tuple bin0 = tuples.get(0);
|
||||
double state0Pct = bin0.getDouble("pct");
|
||||
assertEquals(state0Pct, .5, .015);
|
||||
Map<String, Number> bin1 = out.get(1);
|
||||
double state1Pct = bin1.get("pct").doubleValue();
|
||||
Tuple bin1 = tuples.get(1);
|
||||
double state1Pct = bin1.getDouble("pct");
|
||||
assertEquals(state1Pct, .5, .015);
|
||||
}
|
||||
|
||||
|
@ -2933,32 +2979,30 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
List<Map<String,Number>> out = (List<Map<String, Number>>)tuples.get(0).get("return-value");
|
||||
assertTrue(out.size() == 6);
|
||||
Map<String, Number> bucket = out.get(0);
|
||||
assertEquals(bucket.get("value").longValue(), 2);
|
||||
assertEquals(bucket.get("count").longValue(), 2);
|
||||
assertTrue(tuples.size() == 6);
|
||||
Tuple bucket = tuples.get(0);
|
||||
assertEquals(bucket.getLong("value").longValue(), 2);
|
||||
assertEquals(bucket.getLong("count").longValue(), 2);
|
||||
|
||||
bucket = out.get(1);
|
||||
assertEquals(bucket.get("value").longValue(), 4);
|
||||
assertEquals(bucket.get("count").longValue(), 2);
|
||||
bucket = tuples.get(1);
|
||||
assertEquals(bucket.getLong("value").longValue(), 4);
|
||||
assertEquals(bucket.getLong("count").longValue(), 2);
|
||||
|
||||
bucket = out.get(2);
|
||||
assertEquals(bucket.get("value").longValue(), 6);
|
||||
assertEquals(bucket.get("count").longValue(), 1);
|
||||
bucket = tuples.get(2);
|
||||
assertEquals(bucket.getLong("value").longValue(), 6);
|
||||
assertEquals(bucket.getLong("count").longValue(), 1);
|
||||
|
||||
bucket = out.get(3);
|
||||
assertEquals(bucket.get("value").longValue(), 8);
|
||||
assertEquals(bucket.get("count").longValue(), 4);
|
||||
bucket = tuples.get(3);
|
||||
assertEquals(bucket.getLong("value").longValue(), 8);
|
||||
assertEquals(bucket.getLong("count").longValue(), 4);
|
||||
|
||||
bucket = out.get(4);
|
||||
assertEquals(bucket.get("value").longValue(), 10);
|
||||
assertEquals(bucket.get("count").longValue(), 1);
|
||||
bucket = tuples.get(4);
|
||||
assertEquals(bucket.getLong("value").longValue(), 10);
|
||||
assertEquals(bucket.getLong("count").longValue(), 1);
|
||||
|
||||
bucket = out.get(5);
|
||||
assertEquals(bucket.get("value").longValue(), 12);
|
||||
assertEquals(bucket.get("count").longValue(), 2);
|
||||
bucket = tuples.get(5);
|
||||
assertEquals(bucket.getLong("value").longValue(), 12);
|
||||
assertEquals(bucket.getLong("count").longValue(), 2);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -4062,7 +4106,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
Map out = (Map)tuples.get(0).get("return-value");
|
||||
Tuple out = tuples.get(0);
|
||||
assertEquals((double) out.get("p-value"), 0.788298D, .0001);
|
||||
assertEquals((double) out.get("f-ratio"), 0.24169D, .0001);
|
||||
}
|
||||
|
@ -4347,7 +4391,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
Map out = (Map)tuples.get(0).get("return-value");
|
||||
Tuple out = tuples.get(0);
|
||||
assertEquals((double) out.get("u-statistic"), 52.5, .1);
|
||||
assertEquals((double) out.get("p-value"), 0.7284, .001);
|
||||
}
|
||||
|
@ -5142,7 +5186,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
|
||||
String expr1 = "search("+COLLECTIONORALIAS+", q=\"col_s:a\", fl=\"price_f, order_i\", sort=\"order_i asc\")";
|
||||
|
||||
String cexpr = "let(a="+expr1+", b=col(a, price_f), tuple(stats=describe(b)))";
|
||||
String cexpr = "let(a="+expr1+", b=col(a, price_f), stats=describe(b))";
|
||||
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
|
@ -5155,8 +5199,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
Tuple tuple = tuples.get(0);
|
||||
Map stats = (Map)tuple.get("stats");
|
||||
Tuple stats = tuples.get(0);
|
||||
Number min = (Number)stats.get("min");
|
||||
Number max = (Number)stats.get("max");
|
||||
Number mean = (Number)stats.get("mean");
|
||||
|
|
Loading…
Reference in New Issue