From a664b63d420682828e73f5d2ad2cbba1cb14fe66 Mon Sep 17 00:00:00 2001 From: Joel Bernstein Date: Tue, 23 Jan 2018 15:54:37 -0500 Subject: [PATCH] SOLR-11890: Add multiKmeans Stream Evaluator --- .../apache/solr/handler/StreamHandler.java | 1 + .../solrj/io/eval/MultiKmeansEvaluator.java | 108 ++++++++++++++++++ .../solrj/io/stream/StreamExpressionTest.java | 81 +++++++++++++ 3 files changed, 190 insertions(+) create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MultiKmeansEvaluator.java diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java index eda53eabd36..c9616af60f6 100644 --- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java +++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java @@ -312,6 +312,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, .withFunctionName("rowCount", RowCountEvaluator.class) .withFunctionName("fuzzyKmeans", FuzzyKmeansEvaluator.class) .withFunctionName("getMembershipMatrix", GetMembershipMatrixEvaluator.class) + .withFunctionName("multiKmeans", MultiKmeansEvaluator.class) // Boolean Stream Evaluators diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MultiKmeansEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MultiKmeansEvaluator.java new file mode 100644 index 00000000000..86f1d85d6b7 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MultiKmeansEvaluator.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.HashMap; + +import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer; +import org.apache.commons.math3.ml.clustering.MultiKMeansPlusPlusClusterer; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class MultiKmeansEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker { + protected static final long serialVersionUID = 1L; + + private int maxIterations = 1000; + + public MultiKmeansEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ + super(expression, factory); + + List namedParams = factory.getNamedOperands(expression); + + for(StreamExpressionNamedParameter namedParam : namedParams){ + if(namedParam.getName().equals("maxIterations")) { + this.maxIterations = Integer.parseInt(namedParam.getParameter().toString().trim()); + } else { + throw new IOException("Unexpected named parameter:"+namedParam.getName()); + } + } + } + + @Override + public Object doWork(Object... values) throws IOException { + + if(values.length != 3) { + throw new IOException("The multiKmeans function expects three parameters; a matrix to cluster, k and number of trials."); + } + + Object value1 = values[0]; + Object value2 = values[1]; + Object value3 = values[2]; + + Matrix matrix = null; + int k = 0; + int trials=0; + + if(value1 instanceof Matrix) { + matrix = (Matrix)value1; + } else { + throw new IOException("The first parameter for multiKmeans should be the observation matrix."); + } + + if(value2 instanceof Number) { + k = ((Number)value2).intValue(); + } else { + throw new IOException("The second parameter for multiKmeans should be k."); + } + + if(value3 instanceof Number) { + trials= ((Number)value3).intValue(); + } else { + throw new IOException("The third parameter for multiKmeans should be trials."); + } + + KMeansPlusPlusClusterer kmeans = new KMeansPlusPlusClusterer(k, maxIterations); + MultiKMeansPlusPlusClusterer multiKmeans = new MultiKMeansPlusPlusClusterer(kmeans, trials); + + List points = new ArrayList(); + double[][] data = matrix.getData(); + + List ids = matrix.getRowLabels(); + + for(int i=0; i tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + List> cluster1 = (List>)tuples.get(0).get("g"); + List> cluster2 = (List>)tuples.get(0).get("h"); + List> centroids = (List>)tuples.get(0).get("i"); + List labels1 = (List)tuples.get(0).get("j"); + List labels2 = (List)tuples.get(0).get("k"); + + assertEquals(cluster1.size(), 2); + assertEquals(cluster2.size(), 2); + assertEquals(centroids.size(), 2); + + //Assert that the docs are not in both clusters + assertTrue(!(labels1.contains("doc1") && labels2.contains("doc1"))); + assertTrue(!(labels1.contains("doc2") && labels2.contains("doc2"))); + assertTrue(!(labels1.contains("doc3") && labels2.contains("doc3"))); + assertTrue(!(labels1.contains("doc4") && labels2.contains("doc4"))); + + //Assert that (doc1 and doc2) or (doc3 and doc4) are in labels1 + assertTrue((labels1.contains("doc1") && labels1.contains("doc2")) || + ((labels1.contains("doc3") && labels1.contains("doc4")))); + + //Assert that (doc1 and doc2) or (doc3 and doc4) are in labels2 + assertTrue((labels2.contains("doc1") && labels2.contains("doc2")) || + ((labels2.contains("doc3") && labels2.contains("doc4")))); + + if(labels1.contains("doc1")) { + assertEquals(centroids.get(0).get(0).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(0).get(1).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(0).get(2).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(0).get(3).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(0).get(4).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(0).get(5).doubleValue(), 0.0, 0.0); + + assertEquals(centroids.get(1).get(0).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(1).get(1).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(1).get(2).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(1).get(3).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(1).get(4).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(1).get(5).doubleValue(), 1.0, 0.0); + } else { + assertEquals(centroids.get(0).get(0).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(0).get(1).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(0).get(2).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(0).get(3).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(0).get(4).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(0).get(5).doubleValue(), 1.0, 0.0); + + assertEquals(centroids.get(1).get(0).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(1).get(1).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(1).get(2).doubleValue(), 1.0, 0.0); + assertEquals(centroids.get(1).get(3).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(1).get(4).doubleValue(), 0.0, 0.0); + assertEquals(centroids.get(1).get(5).doubleValue(), 0.0, 0.0); + } + } + + + @Test public void testFuzzyKmeans() throws Exception {