SOLR-11737: Add kmeans Stream Evaluator to support kmeans clustering

This commit is contained in:
Joel Bernstein 2018-01-15 14:50:17 -05:00
parent d99bfa4bdb
commit a08f71279c
14 changed files with 897 additions and 8 deletions

View File

@ -296,6 +296,15 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("getColumnLabels", GetColumnLabelsEvaluator.class)
.withFunctionName("getRowLabels", GetRowLabelsEvaluator.class)
.withFunctionName("getAttribute", GetAttributeEvaluator.class)
.withFunctionName("kmeans", KmeansEvaluator.class)
.withFunctionName("getCentroids", GetCentroidsEvaluator.class)
.withFunctionName("getCluster", GetClusterEvaluator.class)
.withFunctionName("topFeatures", TopFeaturesEvaluator.class)
.withFunctionName("featureSelect", FeatureSelectEvaluator.class)
.withFunctionName("rowAt", RowAtEvaluator.class)
.withFunctionName("colAt", ColumnAtEvaluator.class)
.withFunctionName("setColumnLabels", SetColumnLabelsEvaluator.class)
.withFunctionName("setRowLabels", SetRowLabelsEvaluator.class)
// Boolean Stream Evaluators

View File

@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import java.util.List;
import java.util.ArrayList;
public class ColumnAtEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
protected static final long serialVersionUID = 1L;
public ColumnAtEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
if(2 != containedEvaluators.size()){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 2 values but found %d",expression,containedEvaluators.size()));
}
}
@Override
public Object doWork(Object value1, Object value2) throws IOException {
if(value1 instanceof Matrix) {
Matrix matrix = (Matrix) value1;
Number index = (Number) value2;
double[][] data = matrix.getData();
List<Number> list = new ArrayList();
for(double[] row : data) {
list.add(row[index.intValue()]);
}
return list;
} else {
throw new IOException("The rowAt function expects a matrix as the first parameter");
}
}
}

View File

@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.HashSet;
import java.util.Locale;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import java.util.List;
import java.util.Set;
import java.util.ArrayList;
public class FeatureSelectEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
protected static final long serialVersionUID = 1L;
public FeatureSelectEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
if(2 != containedEvaluators.size()){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 2 values but found %d",expression,containedEvaluators.size()));
}
}
@Override
public Object doWork(Object value1, Object value2) throws IOException {
if(value1 instanceof Matrix) {
Matrix matrix = (Matrix) value1;
double[][] data = matrix.getData();
List<String> labels = matrix.getColumnLabels();
Set<String> features = new HashSet();
loadFeatures(value2, features);
List<String> newColumnLabels = new ArrayList();
for(String label : labels) {
if(features.contains(label)) {
newColumnLabels.add(label);
}
}
double[][] selectFeatures = new double[data.length][newColumnLabels.size()];
for(int i=0; i<data.length; i++) {
double[] currentRow = data[i];
double[] newRow = new double[newColumnLabels.size()];
int index = -1;
for(int l=0; l<currentRow.length; l++) {
String label = labels.get(l);
if(features.contains(label)) {
newRow[++index] = currentRow[l];
}
}
selectFeatures[i] = newRow;
}
Matrix newMatrix = new Matrix(selectFeatures);
newMatrix.setRowLabels(matrix.getRowLabels());
newMatrix.setColumnLabels(newColumnLabels);
return newMatrix;
} else {
throw new IOException("The featureSelect function expects a matrix as a parameter");
}
}
private void loadFeatures(Object o, Set<String> features) {
List list = (List)o;
for(Object v : list) {
if(v instanceof List) {
loadFeatures(v, features);
} else {
features.add((String)v);
}
}
}
}

View File

@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import java.util.List;
import org.apache.commons.math3.ml.clustering.CentroidCluster;
import org.apache.commons.math3.ml.clustering.Clusterable;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class GetCentroidsEvaluator extends RecursiveObjectEvaluator implements OneValueWorker {
private static final long serialVersionUID = 1;
public GetCentroidsEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
}
@Override
public Object doWork(Object value) throws IOException {
if(!(value instanceof KmeansEvaluator.ClusterTuple)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting a clustering result",toExpression(constructingFactory), value.getClass().getSimpleName()));
} else {
KmeansEvaluator.ClusterTuple clusterTuple = (KmeansEvaluator.ClusterTuple)value;
List<CentroidCluster<KmeansEvaluator.ClusterPoint>> clusters = clusterTuple.getClusters();
double[][] data = new double[clusters.size()][];
for(int i=0; i<clusters.size(); i++) {
CentroidCluster<KmeansEvaluator.ClusterPoint> centroidCluster = clusters.get(i);
Clusterable clusterable = centroidCluster.getCenter();
data[i] = clusterable.getPoint();
}
Matrix centroids = new Matrix(data);
centroids.setColumnLabels(clusterTuple.getColumnLabels());
return centroids;
}
}
}

View File

@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import java.util.List;
import java.util.ArrayList;
import org.apache.commons.math3.ml.clustering.CentroidCluster;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class GetClusterEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
private static final long serialVersionUID = 1;
public GetClusterEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
}
@Override
public Object doWork(Object value1, Object value2) throws IOException {
if(!(value1 instanceof KmeansEvaluator.ClusterTuple)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting a cluster result.",toExpression(constructingFactory), value1.getClass().getSimpleName()));
} else {
KmeansEvaluator.ClusterTuple clusterTuple = (KmeansEvaluator.ClusterTuple)value1;
List<CentroidCluster<KmeansEvaluator.ClusterPoint>> clusters = clusterTuple.getClusters();
Number index = (Number)value2;
CentroidCluster cluster = clusters.get(index.intValue());
List points = cluster.getPoints();
List<String> rowLabels = new ArrayList();
double[][] data = new double[points.size()][];
for(int i=0; i<points.size(); i++) {
KmeansEvaluator.ClusterPoint p = (KmeansEvaluator.ClusterPoint)points.get(i);
data[i] = p.getPoint();
rowLabels.add(p.getId());
}
Matrix matrix = new Matrix(data);
matrix.setRowLabels(rowLabels);
matrix.setColumnLabels(clusterTuple.getColumnLabels());
return matrix;
}
}
}

View File

@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import org.apache.commons.math3.ml.clustering.CentroidCluster;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.commons.math3.ml.clustering.Clusterable;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class KmeansEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
protected static final long serialVersionUID = 1L;
public KmeansEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
}
@Override
public Object doWork(Object... values) throws IOException {
if(values.length < 2) {
throw new IOException("kmeans expects atleast two parameters a Matrix of observations and k");
}
Matrix matrix = null;
int k = 0;
int maxIterations = 1000;
if(values[0] instanceof Matrix) {
matrix = (Matrix)values[0];
} else {
throw new IOException("The first parameter for kmeans should be the observation matrix.");
}
if(values[1] instanceof Number) {
k = ((Number)values[1]).intValue();
} else {
throw new IOException("The second parameter for kmeans should be k.");
}
if(values.length == 3) {
maxIterations = ((Number)values[2]).intValue();
}
KMeansPlusPlusClusterer<ClusterPoint> kmeans = new KMeansPlusPlusClusterer(k, maxIterations);
List<ClusterPoint> points = new ArrayList();
double[][] data = matrix.getData();
List<String> ids = matrix.getRowLabels();
for(int i=0; i<data.length; i++) {
double[] vec = data[i];
points.add(new ClusterPoint(ids.get(i), vec));
}
Map fields = new HashMap();
fields.put("k", k);
fields.put("distance", "euclidean");
fields.put("maxIterations", maxIterations);
return new ClusterTuple(fields, kmeans.cluster(points), matrix.getColumnLabels());
}
public static class ClusterPoint implements Clusterable {
private double[] point;
private String id;
public ClusterPoint(String id, double[] point) {
this.id = id;
this.point = point;
}
public double[] getPoint() {
return this.point;
}
public String getId() {
return this.id;
}
}
public static class ClusterTuple extends Tuple {
private List<String> columnLabels;
private List<CentroidCluster<ClusterPoint>> clusters;
public ClusterTuple(Map fields,
List<CentroidCluster<ClusterPoint>> clusters,
List<String> columnLabels) {
super(fields);
this.clusters = clusters;
this.columnLabels = columnLabels;
}
public List<String> getColumnLabels() {
return this.columnLabels;
}
public List<CentroidCluster<ClusterPoint>> getClusters() {
return this.clusters;
}
}
}

View File

@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import java.util.List;
import java.util.ArrayList;
public class RowAtEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
protected static final long serialVersionUID = 1L;
public RowAtEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
if(2 != containedEvaluators.size()){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 2 values but found %d",expression,containedEvaluators.size()));
}
}
@Override
public Object doWork(Object value1, Object value2) throws IOException {
if(value1 instanceof Matrix) {
Matrix matrix = (Matrix) value1;
Number index = (Number) value2;
double[] row = matrix.getData()[index.intValue()];
List<Number> list = new ArrayList();
for(double d : row) {
list.add(d);
}
return list;
} else {
throw new IOException("The rowAt function expects a matrix as the first parameter");
}
}
}

View File

@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import java.util.List;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class SetColumnLabelsEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
private static final long serialVersionUID = 1;
public SetColumnLabelsEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
}
@Override
public Object doWork(Object value1, Object value2) throws IOException {
if(!(value1 instanceof Matrix)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting a Matrix",toExpression(constructingFactory), value1.getClass().getSimpleName()));
} else if(!(value2 instanceof List)) {
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting an array of labels.",toExpression(constructingFactory), value2.getClass().getSimpleName()));
} else {
Matrix matrix = (Matrix)value1;
List<String> colLabels = (List<String>)value2;
matrix.setColumnLabels(colLabels);
return matrix;
}
}
}

View File

@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import java.util.List;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class SetRowLabelsEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
private static final long serialVersionUID = 1;
public SetRowLabelsEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
}
@Override
public Object doWork(Object value1, Object value2) throws IOException {
if(!(value1 instanceof Matrix)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting a Matrix",toExpression(constructingFactory), value1.getClass().getSimpleName()));
} else if(!(value2 instanceof List)) {
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting an array of labels.",toExpression(constructingFactory), value2.getClass().getSimpleName()));
} else {
Matrix matrix = (Matrix)value1;
List<String> rowlabels = (List<String>)value2;
matrix.setRowLabels(rowlabels);
return matrix;
}
}
}

View File

@ -38,6 +38,7 @@ public class TermVectorsEvaluator extends RecursiveObjectEvaluator implements Ma
private int minTermLength = 3;
private double minDocFreq = .05; // 5% of the docs min
private double maxDocFreq = .5; // 50% of the docs max
private String[] excludes;
public TermVectorsEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
@ -57,6 +58,8 @@ public class TermVectorsEvaluator extends RecursiveObjectEvaluator implements Ma
if (maxDocFreq < 0 || maxDocFreq > 1) {
throw new IOException("Doc frequency percentage must be between 0 and 1");
}
} else if(namedParam.getName().equals("exclude")) {
this.excludes = namedParam.getParameter().toString().split(",");
} else {
throw new IOException("Unexpected named parameter:" + namedParam.getName());
}
@ -100,6 +103,7 @@ public class TermVectorsEvaluator extends RecursiveObjectEvaluator implements Ma
String id = tuple.getString("id");
rowLabels.add(id);
OUTER:
for (String term : terms) {
if (term.length() < minTermLength) {
@ -107,6 +111,14 @@ public class TermVectorsEvaluator extends RecursiveObjectEvaluator implements Ma
continue;
}
if(excludes != null) {
for (String exclude : excludes) {
if (term.indexOf(exclude) > -1) {
continue OUTER;
}
}
}
if (!docTerms.contains(term)) {
docTerms.add(term);
if (docFreqs.containsKey(term)) {
@ -134,7 +146,6 @@ public class TermVectorsEvaluator extends RecursiveObjectEvaluator implements Ma
it.remove();
}
}
int totalTerms = docFreqs.size();
Set<String> keys = docFreqs.keySet();
features.addAll(keys);

View File

@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import java.util.List;
import java.util.ArrayList;
import java.util.TreeSet;
public class TopFeaturesEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
protected static final long serialVersionUID = 1L;
public TopFeaturesEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
if(2 != containedEvaluators.size()){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 2 values but found %d",expression,containedEvaluators.size()));
}
}
@Override
public Object doWork(Object value1, Object value2) throws IOException {
int k = ((Number)value2).intValue();
if(value1 instanceof Matrix) {
Matrix matrix = (Matrix) value1;
List<String> features = matrix.getColumnLabels();
if(features == null) {
throw new IOException("Matrix column labels cannot be null for topFeatures function.");
}
double[][] data = matrix.getData();
List<List<String>> topFeatures = new ArrayList();
for(int i=0; i<data.length; i++) {
double[] row = data[i];
List<String> featuresRow = new ArrayList();
List<Integer> indexes = getMaxIndexes(row, k);
for(int index : indexes) {
featuresRow.add(features.get(index));
}
topFeatures.add(featuresRow);
}
return topFeatures;
} else {
throw new IOException("The topFeatures function expects a matrix as the first parameter");
}
}
private List<Integer> getMaxIndexes(double[] values, int k) {
TreeSet<Pair> set = new TreeSet();
for(int i=0; i<values.length; i++) {
set.add(new Pair(i, values[i]));
if(set.size() > k) {
set.pollFirst();
}
}
List<Integer> top = new ArrayList(k);
while(set.size() > 0) {
top.add(set.pollLast().getIndex());
}
return top;
}
public static class Pair implements Comparable<Pair> {
private int index;
private Double value;
public Pair(int index, Number value) {
this.index = index;
this.value = value.doubleValue();
}
public int compareTo(Pair pair) {
return value.compareTo(pair.value);
}
public int getIndex() {
return this.index;
}
public Number getValue() {
return value;
}
}
}

View File

@ -53,7 +53,10 @@ public class UnitEvaluator extends RecursiveObjectEvaluator implements OneValueW
unitData[i] = unitRow;
}
return new Matrix(unitData);
Matrix m = new Matrix(unitData);
m.setRowLabels(matrix.getRowLabels());
m.setColumnLabels(matrix.getRowLabels());
return m;
} else if(value instanceof List) {
List<Number> values = (List<Number>)value;
double[] doubles = new double[values.size()];

View File

@ -22,6 +22,8 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.HashSet;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
@ -50,14 +52,25 @@ public class LetStream extends TupleStream implements Expressible {
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
//Get all the named params
boolean echo = false;
Set<String> echo = null;
boolean echoAll = false;
String currentName = null;
for(StreamExpressionParameter np : namedParams) {
String name = ((StreamExpressionNamedParameter)np).getName();
currentName = name;
if(name.equals("echo")) {
echo = true;
echo = new HashSet();
String echoString = ((StreamExpressionNamedParameter) np).getParameter().toString().trim();
if(echoString.equalsIgnoreCase("true")) {
echoAll = true;
} else {
String[] echoVars = echoString.split(",");
for (String echoVar : echoVars) {
echo.add(echoVar.trim());
}
}
continue;
}
@ -75,14 +88,21 @@ public class LetStream extends TupleStream implements Expressible {
stream = factory.constructStream(streamExpressions.get(0));
} else {
StreamExpression tupleExpression = new StreamExpression("tuple");
if(!echo) {
if(!echoAll && echo == null) {
tupleExpression.addParameter(new StreamExpressionNamedParameter(currentName, currentName));
} else {
Set<String> names = letParams.keySet();
for(String name : names) {
if(echoAll) {
tupleExpression.addParameter(new StreamExpressionNamedParameter(name, name));
} else {
if(echo.contains(name)) {
tupleExpression.addParameter(new StreamExpressionNamedParameter(name, name));
}
}
}
}
stream = factory.constructStream(tupleExpression);
}
}

View File

@ -6175,7 +6175,14 @@ public class StreamExpressionTest extends SolrCloudTestCase {
@Test
public void testMatrix() throws Exception {
String cexpr = "matrix(array(1, 2, 3), rev(array(4,5,6)))";
String cexpr = "let(echo=true," +
" a=setColumnLabels(matrix(array(1, 2, 3), " +
" rev(array(4,5,6)))," +
" array(col1, col2, col3))," +
" b=rowAt(a, 1)," +
" c=colAt(a, 2)," +
" d=getColumnLabels(a)," +
" e=topFeatures(a, 1))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
@ -6185,7 +6192,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
List<List<Number>> out = (List<List<Number>>)tuples.get(0).get("return-value");
List<List<Number>> out = (List<List<Number>>)tuples.get(0).get("a");
List<Number> array1 = out.get(0);
assertEquals(array1.size(), 3);
@ -6198,6 +6205,31 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertEquals(array2.get(0).doubleValue(), 6.0, 0.0);
assertEquals(array2.get(1).doubleValue(), 5.0, 0.0);
assertEquals(array2.get(2).doubleValue(), 4.0, 0.0);
List<Number> row = (List<Number>)tuples.get(0).get("b");
assertEquals(row.size(), 3);
assertEquals(array2.get(0).doubleValue(), 6.0, 0.0);
assertEquals(array2.get(1).doubleValue(), 5.0, 0.0);
assertEquals(array2.get(2).doubleValue(), 4.0, 0.0);
List<Number> col = (List<Number>)tuples.get(0).get("c");
assertEquals(col.size(), 2);
assertEquals(col.get(0).doubleValue(), 3.0, 0.0);
assertEquals(col.get(1).doubleValue(), 4.0, 0.0);
List<String> colLabels = (List<String>)tuples.get(0).get("d");
assertEquals(colLabels.size(), 3);
assertEquals(colLabels.get(0), "col1");
assertEquals(colLabels.get(1), "col2");
assertEquals(colLabels.get(2), "col3");
List<List<String>> features = (List<List<String>>)tuples.get(0).get("e");
assertEquals(features.size(), 2);
assertEquals(features.get(0).size(), 1);
assertEquals(features.get(1).size(), 1);
assertEquals(features.get(0).get(0), "col3");
assertEquals(features.get(1).get(0), "col1");
}
@ -6784,6 +6816,78 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertEquals(docFreqs.get("world").intValue(), 1);
//Test exclude. This should drop off the term jim
cexpr = "let(echo=true," +
" a=select(list(tuple(id=\"1\", text=\"hello world\"), " +
" tuple(id=\"2\", text=\"hello steve\"), " +
" tuple(id=\"3\", text=\"hello jim jim\"), " +
" tuple(id=\"4\", text=\"hello jack\")), id, analyze(text, test_t) as terms)," +
" b=termVectors(a, exclude=jim, minDocFreq=0, maxDocFreq=1)," +
" c=getRowLabels(b)," +
" d=getColumnLabels(b)," +
" e=getAttribute(b, docFreqs))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
solrStream = new SolrStream(url, paramsLoc);
context = new StreamContext();
solrStream.setStreamContext(context);
tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
termVectors = (List<List<Number>>)tuples.get(0).get("b");
assertEquals(termVectors.size(), 4);
termVector = termVectors.get(0);
assertEquals(termVector.size(), 4);
assertEquals(termVector.get(0).doubleValue(), 1.0, 0.0);
assertEquals(termVector.get(1).doubleValue(), 0.0, 0.0);
assertEquals(termVector.get(2).doubleValue(), 0.0, 0.0);
assertEquals(termVector.get(3).doubleValue(), 1.916290731874155, 0.0);
termVector = termVectors.get(1);
assertEquals(termVector.size(), 4);
assertEquals(termVector.get(0).doubleValue(), 1.0, 0.0);
assertEquals(termVector.get(1).doubleValue(), 0.0, 0.0);
assertEquals(termVector.get(2).doubleValue(), 1.916290731874155, 0.0);
assertEquals(termVector.get(3).doubleValue(), 0.0, 0.0);
termVector = termVectors.get(2);
assertEquals(termVector.size(), 4);
assertEquals(termVector.get(0).doubleValue(), 1.0, 0.0);
assertEquals(termVector.get(1).doubleValue(), 0.0, 0.0);
assertEquals(termVector.get(2).doubleValue(), 0.0, 0.0);
assertEquals(termVector.get(3).doubleValue(), 0.0, 0.0);
termVector = termVectors.get(3);
assertEquals(termVector.size(), 4);
assertEquals(termVector.get(0).doubleValue(), 1.0, 0.0);
assertEquals(termVector.get(1).doubleValue(), 1.916290731874155, 0.0);
assertEquals(termVector.get(2).doubleValue(), 0.0, 0.0);
assertEquals(termVector.get(3).doubleValue(), 0.0, 0.0);
rowLabels = (List<String>)tuples.get(0).get("c");
assertEquals(rowLabels.size(), 4);
assertEquals(rowLabels.get(0), "1");
assertEquals(rowLabels.get(1), "2");
assertEquals(rowLabels.get(2), "3");
assertEquals(rowLabels.get(3), "4");
columnLabels = (List<String>)tuples.get(0).get("d");
assertEquals(columnLabels.size(), 4);
assertEquals(columnLabels.get(0), "hello");
assertEquals(columnLabels.get(1), "jack");
assertEquals(columnLabels.get(2), "steve");
assertEquals(columnLabels.get(3), "world");
docFreqs = (Map<String, Number>)tuples.get(0).get("e");
assertEquals(docFreqs.size(), 4);
assertEquals(docFreqs.get("hello").intValue(), 4);
assertEquals(docFreqs.get("jack").intValue(), 1);
assertEquals(docFreqs.get("steve").intValue(), 1);
assertEquals(docFreqs.get("world").intValue(), 1);
//Test minDocFreq attribute at .5. This should eliminate all but the term hello
cexpr = "let(echo=true," +
@ -6884,6 +6988,84 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertTrue(out.get(5).intValue() == 6);
}
@Test
public void testKmeans() throws Exception {
String cexpr = "let(echo=true," +
" a=array(1,1,1,0,0,0)," +
" b=array(1,1,1,0,0,0)," +
" c=array(0,0,0,1,1,1)," +
" d=array(0,0,0,1,1,1)," +
" e=setRowLabels(matrix(a,b,c,d), " +
" array(doc1, doc2, doc3, doc4))," +
" f=kmeans(e, 2)," +
" g=getCluster(f, 0)," +
" h=getCluster(f, 1)," +
" i=getCentroids(f)," +
" j=getRowLabels(g)," +
" k=getRowLabels(h))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
List<List<Number>> cluster1 = (List<List<Number>>)tuples.get(0).get("g");
List<List<Number>> cluster2 = (List<List<Number>>)tuples.get(0).get("h");
List<List<Number>> centroids = (List<List<Number>>)tuples.get(0).get("i");
List<String> labels1 = (List<String>)tuples.get(0).get("j");
List<String> labels2 = (List<String>)tuples.get(0).get("k");
assertEquals(cluster1.size(), 2);
assertEquals(cluster2.size(), 2);
assertEquals(centroids.size(), 2);
//Assert that the docs are not in both clusters
assertTrue(!(labels1.contains("doc1") && labels2.contains("doc1")));
assertTrue(!(labels1.contains("doc2") && labels2.contains("doc2")));
assertTrue(!(labels1.contains("doc3") && labels2.contains("doc3")));
assertTrue(!(labels1.contains("doc4") && labels2.contains("doc4")));
//Assert that (doc1 and doc2) or (doc3 and doc4) are in labels1
assertTrue((labels1.contains("doc1") && labels1.contains("doc2")) ||
((labels1.contains("doc3") && labels1.contains("doc4"))));
//Assert that (doc1 and doc2) or (doc3 and doc4) are in labels2
assertTrue((labels2.contains("doc1") && labels2.contains("doc2")) ||
((labels2.contains("doc3") && labels2.contains("doc4"))));
if(labels1.contains("doc1")) {
assertEquals(centroids.get(0).get(0).doubleValue(), 1.0, 0.0);
assertEquals(centroids.get(0).get(1).doubleValue(), 1.0, 0.0);
assertEquals(centroids.get(0).get(2).doubleValue(), 1.0, 0.0);
assertEquals(centroids.get(0).get(3).doubleValue(), 0.0, 0.0);
assertEquals(centroids.get(0).get(4).doubleValue(), 0.0, 0.0);
assertEquals(centroids.get(0).get(5).doubleValue(), 0.0, 0.0);
assertEquals(centroids.get(1).get(0).doubleValue(), 0.0, 0.0);
assertEquals(centroids.get(1).get(1).doubleValue(), 0.0, 0.0);
assertEquals(centroids.get(1).get(2).doubleValue(), 0.0, 0.0);
assertEquals(centroids.get(1).get(3).doubleValue(), 1.0, 0.0);
assertEquals(centroids.get(1).get(4).doubleValue(), 1.0, 0.0);
assertEquals(centroids.get(1).get(5).doubleValue(), 1.0, 0.0);
} else {
assertEquals(centroids.get(0).get(0).doubleValue(), 0.0, 0.0);
assertEquals(centroids.get(0).get(1).doubleValue(), 0.0, 0.0);
assertEquals(centroids.get(0).get(2).doubleValue(), 0.0, 0.0);
assertEquals(centroids.get(0).get(3).doubleValue(), 1.0, 0.0);
assertEquals(centroids.get(0).get(4).doubleValue(), 1.0, 0.0);
assertEquals(centroids.get(0).get(5).doubleValue(), 1.0, 0.0);
assertEquals(centroids.get(1).get(0).doubleValue(), 1.0, 0.0);
assertEquals(centroids.get(1).get(1).doubleValue(), 1.0, 0.0);
assertEquals(centroids.get(1).get(2).doubleValue(), 1.0, 0.0);
assertEquals(centroids.get(1).get(3).doubleValue(), 0.0, 0.0);
assertEquals(centroids.get(1).get(4).doubleValue(), 0.0, 0.0);
assertEquals(centroids.get(1).get(5).doubleValue(), 0.0, 0.0);
}
}
@Test
public void testEBEMultiply() throws Exception {