SOLR-13589: Allow zplot to visualize clusters and convex hulls

This commit is contained in:
Joel Bernstein 2019-07-01 11:29:19 -04:00
parent 0e877aac34
commit 2f6a681b39
6 changed files with 104 additions and 7 deletions

View File

@ -290,6 +290,7 @@ public class Lang {
.withFunctionName("notNull", NotNullEvaluator.class) .withFunctionName("notNull", NotNullEvaluator.class)
.withFunctionName("isNull", IsNullEvaluator.class) .withFunctionName("isNull", IsNullEvaluator.class)
.withFunctionName("matches", MatchesEvaluator.class) .withFunctionName("matches", MatchesEvaluator.class)
.withFunctionName("projectToBorder", ProjectToBorderEvaluator.class)
// Boolean Stream Evaluators // Boolean Stream Evaluators

View File

@ -78,7 +78,11 @@ public class KmeansEvaluator extends RecursiveObjectEvaluator implements TwoValu
for(int i=0; i<data.length; i++) { for(int i=0; i<data.length; i++) {
double[] vec = data[i]; double[] vec = data[i];
points.add(new ClusterPoint(ids.get(i), vec)); if(ids != null) {
points.add(new ClusterPoint(ids.get(i), vec));
} else {
points.add(new ClusterPoint(Integer.toString(i), vec));
}
} }
Map fields = new HashMap(); Map fields = new HashMap();

View File

@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import org.apache.commons.math3.geometry.euclidean.twod.Euclidean2D;
import org.apache.commons.math3.geometry.euclidean.twod.hull.ConvexHull2D;
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math3.geometry.partitioning.BoundaryProjection;
import org.apache.commons.math3.geometry.partitioning.Region;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class ProjectToBorderEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
private static final long serialVersionUID = 1;
public ProjectToBorderEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
}
@Override
public Object doWork(Object value1, Object value2) throws IOException {
if(!(value1 instanceof ConvexHull2D)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting a ConvexHull2D",toExpression(constructingFactory), value1.getClass().getSimpleName()));
}
if(!(value2 instanceof Matrix)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting a Matrix",toExpression(constructingFactory), value2.getClass().getSimpleName()));
}
ConvexHull2D convexHull2D = (ConvexHull2D)value1;
Matrix matrix = (Matrix)value2;
double[][] data = matrix.getData();
Region<Euclidean2D> region = convexHull2D.createRegion();
double[][] borderPoints = new double[data.length][2];
int i = 0;
for(double[] row : data) {
BoundaryProjection<Euclidean2D> boundaryProjection = region.projectToBoundary(new Vector2D(row));
Vector2D point = (Vector2D)boundaryProjection.getProjected();
borderPoints[i][0] = point.getX();
borderPoints[i][1] = point.getY();
i++;
}
return new Matrix(borderPoints);
}
}

View File

@ -27,12 +27,15 @@ import java.util.Set;
import org.apache.commons.math3.distribution.IntegerDistribution; import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.commons.math3.distribution.RealDistribution; import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.geometry.Point;
import org.apache.commons.math3.ml.clustering.CentroidCluster;
import org.apache.commons.math3.random.EmpiricalDistribution; import org.apache.commons.math3.random.EmpiricalDistribution;
import org.apache.commons.math3.stat.Frequency; import org.apache.commons.math3.stat.Frequency;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics; import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.apache.commons.math3.util.Precision; import org.apache.commons.math3.util.Precision;
import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator; import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.eval.KmeansEvaluator;
import org.apache.solr.client.solrj.io.eval.StreamEvaluator; import org.apache.solr.client.solrj.io.eval.StreamEvaluator;
import org.apache.solr.client.solrj.io.stream.expr.Explanation; import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType; import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType;
@ -126,6 +129,7 @@ public class ZplotStream extends TupleStream implements Expressible {
int columns = 0; int columns = 0;
boolean table = false; boolean table = false;
boolean distribution = false; boolean distribution = false;
boolean clusters = false;
for(Map.Entry<String, Object> entry : entries) { for(Map.Entry<String, Object> entry : entries) {
++columns; ++columns;
@ -134,6 +138,8 @@ public class ZplotStream extends TupleStream implements Expressible {
table = true; table = true;
} else if(name.equals("dist")) { } else if(name.equals("dist")) {
distribution = true; distribution = true;
} else if(name.equals("clusters")) {
clusters = true;
} }
Object o = entry.getValue(); Object o = entry.getValue();
@ -181,7 +187,7 @@ public class ZplotStream extends TupleStream implements Expressible {
//Load the values into tuples //Load the values into tuples
List<Tuple> outTuples = new ArrayList(); List<Tuple> outTuples = new ArrayList();
if(!table && !distribution) { if(!table && !distribution && !clusters) {
//Handle the vectors //Handle the vectors
for (int i = 0; i < numTuples; i++) { for (int i = 0; i < numTuples; i++) {
Tuple tuple = new Tuple(new HashMap()); Tuple tuple = new Tuple(new HashMap());
@ -194,13 +200,28 @@ public class ZplotStream extends TupleStream implements Expressible {
} }
//Generate the x axis if the tuples contain y and not x //Generate the x axis if the tuples contain y and not x
if(outTuples.get(0).fields.containsKey("y") && !outTuples.get(0).fields.containsKey("x")) { if (outTuples.get(0).fields.containsKey("y") && !outTuples.get(0).fields.containsKey("x")) {
int x = 0; int x = 0;
for(Tuple tuple : outTuples) { for (Tuple tuple : outTuples) {
tuple.put("x", x++); tuple.put("x", x++);
} }
} }
} else if(clusters) {
Object o = evaluated.get("clusters");
KmeansEvaluator.ClusterTuple ct = (KmeansEvaluator.ClusterTuple)o;
List<CentroidCluster<KmeansEvaluator.ClusterPoint>> cs = ct.getClusters();
int clusterNum = 0;
for(CentroidCluster<KmeansEvaluator.ClusterPoint> c : cs) {
clusterNum++;
List<KmeansEvaluator.ClusterPoint> points = c.getPoints();
for(KmeansEvaluator.ClusterPoint p : points) {
Tuple tuple = new Tuple(new HashMap());
tuple.put("x", p.getPoint()[0]);
tuple.put("y", p.getPoint()[1]);
tuple.put("cluster", "cluster"+clusterNum);
outTuples.add(tuple);
}
}
} else if(distribution) { } else if(distribution) {
Object o = evaluated.get("dist"); Object o = evaluated.get("dist");
if(o instanceof RealDistribution) { if(o instanceof RealDistribution) {

View File

@ -76,7 +76,7 @@ public class TestLang extends SolrTestCase {
"getAmplitude", "getPhase", "getAngularFrequency", "enclosingDisk", "getCenter", "getRadius", "getAmplitude", "getPhase", "getAngularFrequency", "enclosingDisk", "getCenter", "getRadius",
"getSupportPoints", "pairSort", "log10", "plist", "recip", "pivot", "ltrim", "rtrim", "export", "getSupportPoints", "pairSort", "log10", "plist", "recip", "pivot", "ltrim", "rtrim", "export",
"zplot", "natural", "repeat", "movingMAD", "hashRollup", "noop", "var", "stddev", "recNum", "isNull", "zplot", "natural", "repeat", "movingMAD", "hashRollup", "noop", "var", "stddev", "recNum", "isNull",
"notNull", "matches"}; "notNull", "matches", "projectToBorder"};
@Test @Test
public void testLang() { public void testLang() {

View File

@ -407,7 +407,8 @@ public class MathExpressionTest extends SolrCloudTestCase {
" e=getVertices(d)," + " e=getVertices(d)," +
" f=getArea(d)," + " f=getArea(d)," +
" g=getBoundarySize(d)," + " g=getBoundarySize(d)," +
" h=getBaryCenter(d))"; " h=getBaryCenter(d)," +
" i=projectToBorder(d, matrix(array(99.11076410926444, 109.5441846957560))))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", expr); paramsLoc.set("expr", expr);
paramsLoc.set("qt", "/stream"); paramsLoc.set("qt", "/stream");
@ -466,6 +467,11 @@ public class MathExpressionTest extends SolrCloudTestCase {
assertEquals(baryCenter.get(0).doubleValue(), 101.3021125450865, 0.0); assertEquals(baryCenter.get(0).doubleValue(), 101.3021125450865, 0.0);
assertEquals(baryCenter.get(1).doubleValue(), 100.07343616615786, 0.0); assertEquals(baryCenter.get(1).doubleValue(), 100.07343616615786, 0.0);
List<List<Number>> borderPoints = (List<List<Number>>)tuples.get(0).get("i");
assertEquals(borderPoints.get(0).get(0).doubleValue(), 100.31316833934775, 0);
assertEquals(borderPoints.get(0).get(1).doubleValue(), 115.6639686234851, 0);
} }
@Test @Test