diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java index b9a271a6835..eda53eabd36 100644 --- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java +++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java @@ -310,6 +310,8 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, .withFunctionName("indexOf", IndexOfEvaluator.class) .withFunctionName("columnCount", ColumnCountEvaluator.class) .withFunctionName("rowCount", RowCountEvaluator.class) + .withFunctionName("fuzzyKmeans", FuzzyKmeansEvaluator.class) + .withFunctionName("getMembershipMatrix", GetMembershipMatrixEvaluator.class) // Boolean Stream Evaluators diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/FuzzyKmeansEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/FuzzyKmeansEvaluator.java new file mode 100644 index 00000000000..62a3444ea30 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/FuzzyKmeansEvaluator.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.HashMap; + +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.ml.clustering.CentroidCluster; +import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.apache.commons.math3.ml.clustering.FuzzyKMeansClusterer; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class FuzzyKmeansEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker { + protected static final long serialVersionUID = 1L; + + + private int maxIterations = 1000; + private double fuzziness = 1.2; + + public FuzzyKmeansEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ + super(expression, factory); + + List namedParams = factory.getNamedOperands(expression); + + for(StreamExpressionNamedParameter namedParam : namedParams){ + if(namedParam.getName().equals("fuzziness")){ + this.fuzziness = Double.parseDouble(namedParam.getParameter().toString().trim()); + } else if(namedParam.getName().equals("maxIterations")) { + this.maxIterations = Integer.parseInt(namedParam.getParameter().toString().trim()); + } else { + throw new IOException("Unexpected named parameter:"+namedParam.getName()); + } + } + } + + @Override + public Object doWork(Object value1, Object value2) throws IOException { + + + Matrix matrix = null; + int k = 0; + + + if(value1 instanceof Matrix) { + matrix = (Matrix)value1; + } else { + throw new IOException("The first parameter for fuzzyKmeans should be the observation matrix."); + } + + if(value2 instanceof Number) { + k = ((Number)value2).intValue(); + } else { + throw new IOException("The second parameter for fuzzyKmeans should be k."); + } + + FuzzyKMeansClusterer kmeans = new FuzzyKMeansClusterer(k, + fuzziness, + maxIterations, + new EuclideanDistance()); + List points = new ArrayList(); + double[][] data = matrix.getData(); + + List ids = matrix.getRowLabels(); + + for(int i=0; i> clusters = kmeans.cluster(points); + RealMatrix realMatrix = kmeans.getMembershipMatrix(); + double[][] mmData = realMatrix.getData(); + Matrix mmMatrix = new Matrix(mmData); + mmMatrix.setRowLabels(matrix.getRowLabels()); + return new KmeansEvaluator.ClusterTuple(fields, clusters, matrix.getColumnLabels(),mmMatrix); + } +} + diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetMembershipMatrixEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetMembershipMatrixEvaluator.java new file mode 100644 index 00000000000..c30b66ca490 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetMembershipMatrixEvaluator.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.Locale; + +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class GetMembershipMatrixEvaluator extends RecursiveObjectEvaluator implements OneValueWorker { + private static final long serialVersionUID = 1; + + public GetMembershipMatrixEvaluator(StreamExpression expression, StreamFactory factory) throws IOException { + super(expression, factory); + } + + @Override + public Object doWork(Object value) throws IOException { + if(!(value instanceof KmeansEvaluator.ClusterTuple)){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting a clustering result",toExpression(constructingFactory), value.getClass().getSimpleName())); + } else { + KmeansEvaluator.ClusterTuple clusterTuple = (KmeansEvaluator.ClusterTuple)value; + return clusterTuple.getMembershipMatrix(); + } + } +} \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KmeansEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KmeansEvaluator.java index 410d8bb7d2f..dfac1b35d1a 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KmeansEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KmeansEvaluator.java @@ -29,43 +29,46 @@ import org.apache.solr.client.solrj.io.Tuple; import org.apache.commons.math3.ml.clustering.Clusterable; import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; -public class KmeansEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker { +public class KmeansEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker { protected static final long serialVersionUID = 1L; - + private int maxIterations = 1000; public KmeansEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ super(expression, factory); + + List namedParams = factory.getNamedOperands(expression); + + for(StreamExpressionNamedParameter namedParam : namedParams){ + if(namedParam.getName().equals("maxIterations")) { + this.maxIterations = Integer.parseInt(namedParam.getParameter().toString().trim()); + } else { + throw new IOException("Unexpected named parameter:"+namedParam.getName()); + } + } } @Override - public Object doWork(Object... values) throws IOException { - - if(values.length < 2) { - throw new IOException("kmeans expects atleast two parameters a Matrix of observations and k"); - } + public Object doWork(Object value1, Object value2) throws IOException { Matrix matrix = null; int k = 0; - int maxIterations = 1000; - if(values[0] instanceof Matrix) { - matrix = (Matrix)values[0]; + if(value1 instanceof Matrix) { + matrix = (Matrix)value1; } else { throw new IOException("The first parameter for kmeans should be the observation matrix."); } - if(values[1] instanceof Number) { - k = ((Number)values[1]).intValue(); + if(value2 instanceof Number) { + k = ((Number)value2).intValue(); } else { throw new IOException("The second parameter for kmeans should be k."); } - if(values.length == 3) { - maxIterations = ((Number)values[2]).intValue(); - } KMeansPlusPlusClusterer kmeans = new KMeansPlusPlusClusterer(k, maxIterations); List points = new ArrayList(); @@ -110,6 +113,7 @@ public class KmeansEvaluator extends RecursiveObjectEvaluator implements ManyVal private List columnLabels; private List> clusters; + private Matrix membershipMatrix; public ClusterTuple(Map fields, List> clusters, @@ -119,6 +123,20 @@ public class KmeansEvaluator extends RecursiveObjectEvaluator implements ManyVal this.columnLabels = columnLabels; } + public ClusterTuple(Map fields, + List> clusters, + List columnLabels, + Matrix membershipMatrix) { + super(fields); + this.clusters = clusters; + this.columnLabels = columnLabels; + this.membershipMatrix = membershipMatrix; + } + + public Matrix getMembershipMatrix() { + return this.membershipMatrix; + } + public List getColumnLabels() { return this.columnLabels; } @@ -126,10 +144,6 @@ public class KmeansEvaluator extends RecursiveObjectEvaluator implements ManyVal public List> getClusters() { return this.clusters; } - - - - } } diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java index 493799299dd..ea2a7abe7b7 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java @@ -7076,6 +7076,111 @@ public class StreamExpressionTest extends SolrCloudTestCase { } } + + @Test + public void testFuzzyKmeans() throws Exception { + String cexpr = "let(echo=true," + + " a=array(1,1,1,0,0,0)," + + " b=array(1,1,1,0,0,0)," + + " c=array(0,0,0,1,1,1)," + + " d=array(0,0,0,1,1,1)," + + " e=setRowLabels(matrix(a,b,c,d), " + + " array(doc1, doc2, doc3, doc4))," + + " f=fuzzyKmeans(e, 2)," + + " g=getCluster(f, 0)," + + " h=getCluster(f, 1)," + + " i=getCentroids(f)," + + " j=getRowLabels(g)," + + " k=getRowLabels(h)," + + " l=getMembershipMatrix(f))"; + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + List> cluster1 = (List>)tuples.get(0).get("g"); + List> cluster2 = (List>)tuples.get(0).get("h"); + List> centroids = (List>)tuples.get(0).get("i"); + List> membership = (List>)tuples.get(0).get("l"); + + List labels1 = (List)tuples.get(0).get("j"); + List labels2 = (List)tuples.get(0).get("k"); + + assertEquals(cluster1.size(), 2); + assertEquals(cluster2.size(), 2); + assertEquals(centroids.size(), 2); + + //Assert that the docs are not in both clusters + assertTrue(!(labels1.contains("doc1") && labels2.contains("doc1"))); + assertTrue(!(labels1.contains("doc2") && labels2.contains("doc2"))); + assertTrue(!(labels1.contains("doc3") && labels2.contains("doc3"))); + assertTrue(!(labels1.contains("doc4") && labels2.contains("doc4"))); + + //Assert that (doc1 and doc2) or (doc3 and doc4) are in labels1 + assertTrue((labels1.contains("doc1") && labels1.contains("doc2")) || + ((labels1.contains("doc3") && labels1.contains("doc4")))); + + //Assert that (doc1 and doc2) or (doc3 and doc4) are in labels2 + assertTrue((labels2.contains("doc1") && labels2.contains("doc2")) || + ((labels2.contains("doc3") && labels2.contains("doc4")))); + + + if(labels1.contains("doc1")) { + assertEquals(centroids.get(0).get(0).doubleValue(), 1.0, 0.0001); + assertEquals(centroids.get(0).get(1).doubleValue(), 1.0, 0.0001); + assertEquals(centroids.get(0).get(2).doubleValue(), 1.0, 0.0001); + assertEquals(centroids.get(0).get(3).doubleValue(), 0.0, 0.0001); + assertEquals(centroids.get(0).get(4).doubleValue(), 0.0, 0.0001); + assertEquals(centroids.get(0).get(5).doubleValue(), 0.0, 0.0001); + + assertEquals(centroids.get(1).get(0).doubleValue(), 0.0, 0.0001); + assertEquals(centroids.get(1).get(1).doubleValue(), 0.0, 0.0001); + assertEquals(centroids.get(1).get(2).doubleValue(), 0.0, 0.0001); + assertEquals(centroids.get(1).get(3).doubleValue(), 1.0, 0.0001); + assertEquals(centroids.get(1).get(4).doubleValue(), 1.0, 0.0001); + assertEquals(centroids.get(1).get(5).doubleValue(), 1.0, 0.0001); + + //Assert the membership matrix + assertEquals(membership.get(0).get(0).doubleValue(), 1.0, 0.0001); + assertEquals(membership.get(0).get(1).doubleValue(), 0.0, 0.0001); + assertEquals(membership.get(1).get(0).doubleValue(), 1.0, 0.0001); + assertEquals(membership.get(1).get(1).doubleValue(), 0.0, 0.0001); + assertEquals(membership.get(2).get(0).doubleValue(), 0.0, 0.0001); + assertEquals(membership.get(2).get(1).doubleValue(), 1.0, 0.0001); + assertEquals(membership.get(3).get(0).doubleValue(), 0.0, 0.0001); + assertEquals(membership.get(3).get(1).doubleValue(), 1.0, 0.0001); + + } else { + assertEquals(centroids.get(0).get(0).doubleValue(), 0.0, 0.0001); + assertEquals(centroids.get(0).get(1).doubleValue(), 0.0, 0.0001); + assertEquals(centroids.get(0).get(2).doubleValue(), 0.0, 0.0001); + assertEquals(centroids.get(0).get(3).doubleValue(), 1.0, 0.0001); + assertEquals(centroids.get(0).get(4).doubleValue(), 1.0, 0.0001); + assertEquals(centroids.get(0).get(5).doubleValue(), 1.0, 0.0001); + + assertEquals(centroids.get(1).get(0).doubleValue(), 1.0, 0.0001); + assertEquals(centroids.get(1).get(1).doubleValue(), 1.0, 0.0001); + assertEquals(centroids.get(1).get(2).doubleValue(), 1.0, 0.0001); + assertEquals(centroids.get(1).get(3).doubleValue(), 0.0, 0.0001); + assertEquals(centroids.get(1).get(4).doubleValue(), 0.0, 0.0001); + assertEquals(centroids.get(1).get(5).doubleValue(), 0.0, 0.0001); + + //Assert the membership matrix + assertEquals(membership.get(0).get(0).doubleValue(), 0.0, 0.0001); + assertEquals(membership.get(0).get(1).doubleValue(), 1.0, 0.0001); + assertEquals(membership.get(1).get(0).doubleValue(), 0.0, 0.0001); + assertEquals(membership.get(1).get(1).doubleValue(), 1.0, 0.0001); + assertEquals(membership.get(2).get(0).doubleValue(), 1.0, 0.0001); + assertEquals(membership.get(2).get(1).doubleValue(), 0.0, 0.0001); + assertEquals(membership.get(3).get(0).doubleValue(), 1.0, 0.0001); + assertEquals(membership.get(3).get(1).doubleValue(), 0.0, 0.0001); + } + } + @Test public void testEBEMultiply() throws Exception { String cexpr = "ebeMultiply(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";