From a08f71279ca79113f41a0ae1f954931195ebba41 Mon Sep 17 00:00:00 2001 From: Joel Bernstein Date: Mon, 15 Jan 2018 14:50:17 -0500 Subject: [PATCH] SOLR-11737: Add kmeans Stream Evaluator to support kmeans clustering --- .../apache/solr/handler/StreamHandler.java | 9 + .../solrj/io/eval/ColumnAtEvaluator.java | 55 ++++++ .../solrj/io/eval/FeatureSelectEvaluator.java | 93 +++++++++ .../solrj/io/eval/GetCentroidsEvaluator.java | 55 ++++++ .../solrj/io/eval/GetClusterEvaluator.java | 64 ++++++ .../client/solrj/io/eval/KmeansEvaluator.java | 135 +++++++++++++ .../client/solrj/io/eval/RowAtEvaluator.java | 56 ++++++ .../io/eval/SetColumnLabelsEvaluator.java | 47 +++++ .../solrj/io/eval/SetRowLabelsEvaluator.java | 47 +++++ .../solrj/io/eval/TermVectorsEvaluator.java | 13 +- .../solrj/io/eval/TopFeaturesEvaluator.java | 112 +++++++++++ .../client/solrj/io/eval/UnitEvaluator.java | 5 +- .../client/solrj/io/stream/LetStream.java | 28 ++- .../solrj/io/stream/StreamExpressionTest.java | 186 +++++++++++++++++- 14 files changed, 897 insertions(+), 8 deletions(-) create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ColumnAtEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/FeatureSelectEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetCentroidsEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetClusterEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KmeansEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/RowAtEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SetColumnLabelsEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SetRowLabelsEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TopFeaturesEvaluator.java diff --git 
a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java index ee3a17baa7d..206136c1a83 100644 --- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java +++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java @@ -296,6 +296,15 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, .withFunctionName("getColumnLabels", GetColumnLabelsEvaluator.class) .withFunctionName("getRowLabels", GetRowLabelsEvaluator.class) .withFunctionName("getAttribute", GetAttributeEvaluator.class) + .withFunctionName("kmeans", KmeansEvaluator.class) + .withFunctionName("getCentroids", GetCentroidsEvaluator.class) + .withFunctionName("getCluster", GetClusterEvaluator.class) + .withFunctionName("topFeatures", TopFeaturesEvaluator.class) + .withFunctionName("featureSelect", FeatureSelectEvaluator.class) + .withFunctionName("rowAt", RowAtEvaluator.class) + .withFunctionName("colAt", ColumnAtEvaluator.class) + .withFunctionName("setColumnLabels", SetColumnLabelsEvaluator.class) + .withFunctionName("setRowLabels", SetRowLabelsEvaluator.class) // Boolean Stream Evaluators diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ColumnAtEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ColumnAtEvaluator.java new file mode 100644 index 00000000000..5714096c559 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ColumnAtEvaluator.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; + +import java.util.Locale; + +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; +import java.util.List; +import java.util.ArrayList; + +public class ColumnAtEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker { + protected static final long serialVersionUID = 1L; + + public ColumnAtEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ + super(expression, factory); + + if(2 != containedEvaluators.size()){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 2 values but found %d",expression,containedEvaluators.size())); + } + } + + @Override + public Object doWork(Object value1, Object value2) throws IOException { // value1: Matrix, value2: Number column index; returns that column's values, one per row, as a List + + if(value1 instanceof Matrix) { + Matrix matrix = (Matrix) value1; + Number index = (Number) value2; + double[][] data = matrix.getData(); + List list = new ArrayList(); + for(double[] row : data) { + list.add(row[index.intValue()]); + } + return list; + } else { + throw new IOException("The colAt function expects a matrix as the first parameter"); + } + } +} diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/FeatureSelectEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/FeatureSelectEvaluator.java new file mode 100644 index 00000000000..b3c06d824c5 --- /dev/null +++ 
b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/FeatureSelectEvaluator.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Locale; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; +import java.util.List; +import java.util.Set; +import java.util.ArrayList; + +public class FeatureSelectEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker { + protected static final long serialVersionUID = 1L; + + public FeatureSelectEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ + super(expression, factory); + + if(2 != containedEvaluators.size()){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 2 values but found %d",expression,containedEvaluators.size())); + } + } + + @Override + public Object doWork(Object value1, Object value2) throws IOException { + + if(value1 instanceof Matrix) { + Matrix matrix = (Matrix) value1; + double[][] data = matrix.getData(); + + List labels = matrix.getColumnLabels(); + 
Set features = new HashSet(); + loadFeatures(value2, features); + + List newColumnLabels = new ArrayList(); + + for(String label : labels) { + if(features.contains(label)) { + newColumnLabels.add(label); + } + } + + double[][] selectFeatures = new double[data.length][newColumnLabels.size()]; + + for(int i=0; i features) { + List list = (List)o; + for(Object v : list) { + if(v instanceof List) { + loadFeatures(v, features); + } else { + features.add((String)v); + } + } + } +} diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetCentroidsEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetCentroidsEvaluator.java new file mode 100644 index 00000000000..e55263de349 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetCentroidsEvaluator.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.Locale; + +import java.util.List; + +import org.apache.commons.math3.ml.clustering.CentroidCluster; +import org.apache.commons.math3.ml.clustering.Clusterable; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class GetCentroidsEvaluator extends RecursiveObjectEvaluator implements OneValueWorker { + private static final long serialVersionUID = 1; + + public GetCentroidsEvaluator(StreamExpression expression, StreamFactory factory) throws IOException { + super(expression, factory); + } + + @Override + public Object doWork(Object value) throws IOException { + if(!(value instanceof KmeansEvaluator.ClusterTuple)){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting a clustering result",toExpression(constructingFactory), value.getClass().getSimpleName())); + } else { + KmeansEvaluator.ClusterTuple clusterTuple = (KmeansEvaluator.ClusterTuple)value; + List> clusters = clusterTuple.getClusters(); + double[][] data = new double[clusters.size()][]; + for(int i=0; i centroidCluster = clusters.get(i); + Clusterable clusterable = centroidCluster.getCenter(); + data[i] = clusterable.getPoint(); + } + Matrix centroids = new Matrix(data); + centroids.setColumnLabels(clusterTuple.getColumnLabels()); + return centroids; + } + } +} \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetClusterEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetClusterEvaluator.java new file mode 100644 index 00000000000..903670d9614 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetClusterEvaluator.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.Locale; + +import java.util.List; +import java.util.ArrayList; + +import org.apache.commons.math3.ml.clustering.CentroidCluster; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class GetClusterEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker { + private static final long serialVersionUID = 1; + + public GetClusterEvaluator(StreamExpression expression, StreamFactory factory) throws IOException { + super(expression, factory); + } + + @Override + public Object doWork(Object value1, Object value2) throws IOException { + if(!(value1 instanceof KmeansEvaluator.ClusterTuple)){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting a cluster result.",toExpression(constructingFactory), value1.getClass().getSimpleName())); + } else { + + KmeansEvaluator.ClusterTuple clusterTuple = (KmeansEvaluator.ClusterTuple)value1; + List> clusters = clusterTuple.getClusters(); + + Number index = (Number)value2; + CentroidCluster cluster = clusters.get(index.intValue()); + List points = cluster.getPoints(); + List rowLabels = new 
ArrayList(); + double[][] data = new double[points.size()][]; + + for(int i=0; i kmeans = new KMeansPlusPlusClusterer(k, maxIterations); + List points = new ArrayList(); + double[][] data = matrix.getData(); + + List ids = matrix.getRowLabels(); + + for(int i=0; i columnLabels; + private List> clusters; + + public ClusterTuple(Map fields, + List> clusters, + List columnLabels) { + super(fields); + this.clusters = clusters; + this.columnLabels = columnLabels; + } + + public List getColumnLabels() { + return this.columnLabels; + } + + public List> getClusters() { + return this.clusters; + } + + + + + } +} + diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/RowAtEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/RowAtEvaluator.java new file mode 100644 index 00000000000..982cfbb744d --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/RowAtEvaluator.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; + +import java.util.Locale; + +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; +import java.util.List; +import java.util.ArrayList; + +public class RowAtEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker { + protected static final long serialVersionUID = 1L; + + public RowAtEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ + super(expression, factory); + + if(2 != containedEvaluators.size()){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 2 values but found %d",expression,containedEvaluators.size())); + } + } + + @Override + public Object doWork(Object value1, Object value2) throws IOException { + + if(value1 instanceof Matrix) { + Matrix matrix = (Matrix) value1; + Number index = (Number) value2; + double[] row = matrix.getData()[index.intValue()]; + List list = new ArrayList(); + for(double d : row) { + list.add(d); + } + + return list; + } else { + throw new IOException("The rowAt function expects a matrix as the first parameter"); + } + } +} diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SetColumnLabelsEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SetColumnLabelsEvaluator.java new file mode 100644 index 00000000000..1d589aff71f --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SetColumnLabelsEvaluator.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.Locale; +import java.util.List; + +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class SetColumnLabelsEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker { + private static final long serialVersionUID = 1; + + public SetColumnLabelsEvaluator(StreamExpression expression, StreamFactory factory) throws IOException { + super(expression, factory); + } + + @Override + public Object doWork(Object value1, Object value2) throws IOException { + if(!(value1 instanceof Matrix)){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting a Matrix",toExpression(constructingFactory), value1.getClass().getSimpleName())); + } else if(!(value2 instanceof List)) { + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting an array of labels.",toExpression(constructingFactory), value2.getClass().getSimpleName())); + } else { + Matrix matrix = (Matrix)value1; + List colLabels = (List)value2; + matrix.setColumnLabels(colLabels); + return matrix; + } + } +} \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SetRowLabelsEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SetRowLabelsEvaluator.java new file mode 100644 index 00000000000..66a59c8b122 --- /dev/null +++ 
b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SetRowLabelsEvaluator.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.Locale; +import java.util.List; + +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class SetRowLabelsEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker { + private static final long serialVersionUID = 1; + + public SetRowLabelsEvaluator(StreamExpression expression, StreamFactory factory) throws IOException { + super(expression, factory); + } + + @Override + public Object doWork(Object value1, Object value2) throws IOException { + if(!(value1 instanceof Matrix)){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting a Matrix",toExpression(constructingFactory), value1.getClass().getSimpleName())); + } else if(!(value2 instanceof List)) { + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting an array of labels.",toExpression(constructingFactory), 
value2.getClass().getSimpleName())); + } else { + Matrix matrix = (Matrix)value1; + List rowlabels = (List)value2; + matrix.setRowLabels(rowlabels); + return matrix; + } + } +} \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TermVectorsEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TermVectorsEvaluator.java index 8bf050df929..7c097124054 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TermVectorsEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TermVectorsEvaluator.java @@ -38,6 +38,7 @@ public class TermVectorsEvaluator extends RecursiveObjectEvaluator implements Ma private int minTermLength = 3; private double minDocFreq = .05; // 5% of the docs min private double maxDocFreq = .5; // 50% of the docs max + private String[] excludes; public TermVectorsEvaluator(StreamExpression expression, StreamFactory factory) throws IOException { super(expression, factory); @@ -57,6 +58,8 @@ public class TermVectorsEvaluator extends RecursiveObjectEvaluator implements Ma if (maxDocFreq < 0 || maxDocFreq > 1) { throw new IOException("Doc frequency percentage must be between 0 and 1"); } + } else if(namedParam.getName().equals("exclude")) { + this.excludes = namedParam.getParameter().toString().split(","); } else { throw new IOException("Unexpected named parameter:" + namedParam.getName()); } @@ -100,6 +103,7 @@ public class TermVectorsEvaluator extends RecursiveObjectEvaluator implements Ma String id = tuple.getString("id"); rowLabels.add(id); + OUTER: for (String term : terms) { if (term.length() < minTermLength) { @@ -107,6 +111,14 @@ public class TermVectorsEvaluator extends RecursiveObjectEvaluator implements Ma continue; } + if(excludes != null) { + for (String exclude : excludes) { + if (term.indexOf(exclude) > -1) { + continue OUTER; + } + } + } + if (!docTerms.contains(term)) { docTerms.add(term); if (docFreqs.containsKey(term)) { @@ -134,7 +146,6 
@@ public class TermVectorsEvaluator extends RecursiveObjectEvaluator implements Ma it.remove(); } } - int totalTerms = docFreqs.size(); Set keys = docFreqs.keySet(); features.addAll(keys); diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TopFeaturesEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TopFeaturesEvaluator.java new file mode 100644 index 00000000000..e2100b1fbde --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TopFeaturesEvaluator.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; + +import java.util.Locale; + +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; +import java.util.List; +import java.util.ArrayList; +import java.util.TreeSet; + +public class TopFeaturesEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker { + protected static final long serialVersionUID = 1L; + + public TopFeaturesEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ + super(expression, factory); + + if(2 != containedEvaluators.size()){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 2 values but found %d",expression,containedEvaluators.size())); + } + } + + @Override + public Object doWork(Object value1, Object value2) throws IOException { + + int k = ((Number)value2).intValue(); + + if(value1 instanceof Matrix) { + + Matrix matrix = (Matrix) value1; + List features = matrix.getColumnLabels(); + + if(features == null) { + throw new IOException("Matrix column labels cannot be null for topFeatures function."); + } + + double[][] data = matrix.getData(); + List> topFeatures = new ArrayList(); + + for(int i=0; i featuresRow = new ArrayList(); + List indexes = getMaxIndexes(row, k); + for(int index : indexes) { + featuresRow.add(features.get(index)); + } + topFeatures.add(featuresRow); + } + + return topFeatures; + } else { + throw new IOException("The topFeatures function expects a matrix as the first parameter"); + } + } + + private List getMaxIndexes(double[] values, int k) { + TreeSet set = new TreeSet(); + for(int i=0; i k) { + set.pollFirst(); + } + } + + List top = new ArrayList(k); + while(set.size() > 0) { + top.add(set.pollLast().getIndex()); + } + + return top; + } + + public static class Pair implements Comparable { + + private int index; + private Double value; + + public Pair(int index, Number value) 
{ + this.index = index; + this.value = value.doubleValue(); + } + + public int compareTo(Pair pair) { + int c = value.compareTo(pair.value); return c != 0 ? c : Integer.compare(index, pair.index); // tie-break on index so equal values are not collapsed as duplicates by the TreeSet + } + + public int getIndex() { + return this.index; + } + + public Number getValue() { + return value; + } + } +} diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/UnitEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/UnitEvaluator.java index 8be990d55f0..16d72ae367f 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/UnitEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/UnitEvaluator.java @@ -53,7 +53,10 @@ public class UnitEvaluator extends RecursiveObjectEvaluator implements OneValueW unitData[i] = unitRow; } - return new Matrix(unitData); + Matrix m = new Matrix(unitData); + m.setRowLabels(matrix.getRowLabels()); + m.setColumnLabels(matrix.getColumnLabels()); // column labels come from the source matrix's column labels, not its row labels + return m; } else if(value instanceof List) { List values = (List)value; double[] doubles = new double[values.size()]; diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/LetStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/LetStream.java index ce883ad3969..8bb12a530b2 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/LetStream.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/LetStream.java @@ -22,6 +22,8 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.HashSet; + import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.comp.StreamComparator; @@ -50,14 +52,25 @@ public class LetStream extends TupleStream implements Expressible { List namedParams = factory.getNamedOperands(expression); //Get all the named params - boolean echo = false; + Set echo = null; + boolean echoAll = false; String currentName = null; for(StreamExpressionParameter np : namedParams) { String name = 
((StreamExpressionNamedParameter)np).getName(); currentName = name; if(name.equals("echo")) { - echo = true; + echo = new HashSet(); + String echoString = ((StreamExpressionNamedParameter) np).getParameter().toString().trim(); + if(echoString.equalsIgnoreCase("true")) { + echoAll = true; + } else { + String[] echoVars = echoString.split(","); + for (String echoVar : echoVars) { + echo.add(echoVar.trim()); + } + } + continue; } @@ -75,14 +88,21 @@ public class LetStream extends TupleStream implements Expressible { stream = factory.constructStream(streamExpressions.get(0)); } else { StreamExpression tupleExpression = new StreamExpression("tuple"); - if(!echo) { + if(!echoAll && echo == null) { tupleExpression.addParameter(new StreamExpressionNamedParameter(currentName, currentName)); } else { Set names = letParams.keySet(); for(String name : names) { - tupleExpression.addParameter(new StreamExpressionNamedParameter(name, name)); + if(echoAll) { + tupleExpression.addParameter(new StreamExpressionNamedParameter(name, name)); + } else { + if(echo.contains(name)) { + tupleExpression.addParameter(new StreamExpressionNamedParameter(name, name)); + } + } } } + stream = factory.constructStream(tupleExpression); } } diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java index 2a9df01a4e0..6f1e61f9aaf 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java @@ -6175,7 +6175,14 @@ public class StreamExpressionTest extends SolrCloudTestCase { @Test public void testMatrix() throws Exception { - String cexpr = "matrix(array(1, 2, 3), rev(array(4,5,6)))"; + String cexpr = "let(echo=true," + + " a=setColumnLabels(matrix(array(1, 2, 3), " + + " rev(array(4,5,6)))," + + " array(col1, col2, col3))," + + " b=rowAt(a, 1)," + + " 
c=colAt(a, 2)," + + " d=getColumnLabels(a)," + + " e=topFeatures(a, 1))"; ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); paramsLoc.set("expr", cexpr); paramsLoc.set("qt", "/stream"); @@ -6185,7 +6192,7 @@ public class StreamExpressionTest extends SolrCloudTestCase { solrStream.setStreamContext(context); List tuples = getTuples(solrStream); assertTrue(tuples.size() == 1); - List> out = (List>)tuples.get(0).get("return-value"); + List> out = (List>)tuples.get(0).get("a"); List array1 = out.get(0); assertEquals(array1.size(), 3); @@ -6198,6 +6205,31 @@ public class StreamExpressionTest extends SolrCloudTestCase { assertEquals(array2.get(0).doubleValue(), 6.0, 0.0); assertEquals(array2.get(1).doubleValue(), 5.0, 0.0); assertEquals(array2.get(2).doubleValue(), 4.0, 0.0); + + List row = (List)tuples.get(0).get("b"); // b=rowAt(a, 1) + + assertEquals(row.size(), 3); + assertEquals(row.get(0).doubleValue(), 6.0, 0.0); + assertEquals(row.get(1).doubleValue(), 5.0, 0.0); + assertEquals(row.get(2).doubleValue(), 4.0, 0.0); + + List col = (List)tuples.get(0).get("c"); + assertEquals(col.size(), 2); + assertEquals(col.get(0).doubleValue(), 3.0, 0.0); + assertEquals(col.get(1).doubleValue(), 4.0, 0.0); + + List colLabels = (List)tuples.get(0).get("d"); + assertEquals(colLabels.size(), 3); + assertEquals(colLabels.get(0), "col1"); + assertEquals(colLabels.get(1), "col2"); + assertEquals(colLabels.get(2), "col3"); + + List> features = (List>)tuples.get(0).get("e"); + assertEquals(features.size(), 2); + assertEquals(features.get(0).size(), 1); + assertEquals(features.get(1).size(), 1); + assertEquals(features.get(0).get(0), "col3"); + assertEquals(features.get(1).get(0), "col1"); } @@ -6784,6 +6816,78 @@ public class StreamExpressionTest extends SolrCloudTestCase { assertEquals(docFreqs.get("world").intValue(), 1); + //Test exclude. 
This should drop off the term jim + + cexpr = "let(echo=true," + + " a=select(list(tuple(id=\"1\", text=\"hello world\"), " + + " tuple(id=\"2\", text=\"hello steve\"), " + + " tuple(id=\"3\", text=\"hello jim jim\"), " + + " tuple(id=\"4\", text=\"hello jack\")), id, analyze(text, test_t) as terms)," + + " b=termVectors(a, exclude=jim, minDocFreq=0, maxDocFreq=1)," + + " c=getRowLabels(b)," + + " d=getColumnLabels(b)," + + " e=getAttribute(b, docFreqs))"; + + paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + solrStream = new SolrStream(url, paramsLoc); + context = new StreamContext(); + solrStream.setStreamContext(context); + tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + termVectors = (List>)tuples.get(0).get("b"); + assertEquals(termVectors.size(), 4); + termVector = termVectors.get(0); + assertEquals(termVector.size(), 4); + assertEquals(termVector.get(0).doubleValue(), 1.0, 0.0); + assertEquals(termVector.get(1).doubleValue(), 0.0, 0.0); + assertEquals(termVector.get(2).doubleValue(), 0.0, 0.0); + assertEquals(termVector.get(3).doubleValue(), 1.916290731874155, 0.0); + + termVector = termVectors.get(1); + assertEquals(termVector.size(), 4); + assertEquals(termVector.get(0).doubleValue(), 1.0, 0.0); + assertEquals(termVector.get(1).doubleValue(), 0.0, 0.0); + assertEquals(termVector.get(2).doubleValue(), 1.916290731874155, 0.0); + assertEquals(termVector.get(3).doubleValue(), 0.0, 0.0); + + termVector = termVectors.get(2); + assertEquals(termVector.size(), 4); + assertEquals(termVector.get(0).doubleValue(), 1.0, 0.0); + assertEquals(termVector.get(1).doubleValue(), 0.0, 0.0); + assertEquals(termVector.get(2).doubleValue(), 0.0, 0.0); + assertEquals(termVector.get(3).doubleValue(), 0.0, 0.0); + + termVector = termVectors.get(3); + assertEquals(termVector.size(), 4); + assertEquals(termVector.get(0).doubleValue(), 1.0, 0.0); + assertEquals(termVector.get(1).doubleValue(), 
1.916290731874155, 0.0); + assertEquals(termVector.get(2).doubleValue(), 0.0, 0.0); + assertEquals(termVector.get(3).doubleValue(), 0.0, 0.0); + + rowLabels = (List)tuples.get(0).get("c"); + assertEquals(rowLabels.size(), 4); + assertEquals(rowLabels.get(0), "1"); + assertEquals(rowLabels.get(1), "2"); + assertEquals(rowLabels.get(2), "3"); + assertEquals(rowLabels.get(3), "4"); + + columnLabels = (List)tuples.get(0).get("d"); + assertEquals(columnLabels.size(), 4); + assertEquals(columnLabels.get(0), "hello"); + assertEquals(columnLabels.get(1), "jack"); + assertEquals(columnLabels.get(2), "steve"); + assertEquals(columnLabels.get(3), "world"); + + docFreqs = (Map)tuples.get(0).get("e"); + + assertEquals(docFreqs.size(), 4); + assertEquals(docFreqs.get("hello").intValue(), 4); + assertEquals(docFreqs.get("jack").intValue(), 1); + assertEquals(docFreqs.get("steve").intValue(), 1); + assertEquals(docFreqs.get("world").intValue(), 1); + //Test minDocFreq attribute at .5. This should eliminate all but the term hello cexpr = "let(echo=true," + @@ -6884,6 +6988,84 @@ public class StreamExpressionTest extends SolrCloudTestCase { assertTrue(out.get(5).intValue() == 6); } + @Test + public void testKmeans() throws Exception { + String cexpr = "let(echo=true," + + " a=array(1,1,1,0,0,0)," + + " b=array(1,1,1,0,0,0)," + + " c=array(0,0,0,1,1,1)," + + " d=array(0,0,0,1,1,1)," + + " e=setRowLabels(matrix(a,b,c,d), " + + " array(doc1, doc2, doc3, doc4))," + + " f=kmeans(e, 2)," + + " g=getCluster(f, 0)," + + " h=getCluster(f, 1)," + + " i=getCentroids(f)," + + " j=getRowLabels(g)," + + " k=getRowLabels(h))"; + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + 
List tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + List> cluster1 = (List>)tuples.get(0).get("g"); + List> cluster2 = (List>)tuples.get(0).get("h"); + List> centroids = (List>)tuples.get(0).get("i"); + List labels1 = (List)tuples.get(0).get("j"); + List labels2 = (List)tuples.get(0).get("k"); + + assertEquals(cluster1.size(), 2); + assertEquals(cluster2.size(), 2); + assertEquals(centroids.size(), 2); + + //Assert that the docs are not in both clusters + assertTrue(!(labels1.contains("doc1") && labels2.contains("doc1"))); + assertTrue(!(labels1.contains("doc2") && labels2.contains("doc2"))); + assertTrue(!(labels1.contains("doc3") && labels2.contains("doc3"))); + assertTrue(!(labels1.contains("doc4") && labels2.contains("doc4"))); + + //Assert that (doc1 and doc2) or (doc3 and doc4) are in labels1 + assertTrue((labels1.contains("doc1") && labels1.contains("doc2")) || + ((labels1.contains("doc3") && labels1.contains("doc4")))); + + //Assert that (doc1 and doc2) or (doc3 and doc4) are in labels2 + assertTrue((labels2.contains("doc1") && labels2.contains("doc2")) || + ((labels2.contains("doc3") && labels2.contains("doc4")))); + + if(labels1.contains("doc1")) { + assertEquals(centroids.get(0).get(0).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(0).get(1).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(0).get(2).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(0).get(3).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(0).get(4).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(0).get(5).doubleValue(), 0.0, 0.0); + + assertEquals(centroids.get(1).get(0).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(1).get(1).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(1).get(2).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(1).get(3).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(1).get(4).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(1).get(5).doubleValue(), 1.0, 0.0); + } else { + 
assertEquals(centroids.get(0).get(0).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(0).get(1).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(0).get(2).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(0).get(3).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(0).get(4).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(0).get(5).doubleValue(), 1.0, 0.0); + + assertEquals(centroids.get(1).get(0).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(1).get(1).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(1).get(2).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(1).get(3).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(1).get(4).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(1).get(5).doubleValue(), 0.0, 0.0); + } + } @Test public void testEBEMultiply() throws Exception {