From 84102c0c4ca34d5d024cf6d02c733c8583fb1b2a Mon Sep 17 00:00:00 2001 From: CT Date: Wed, 11 Mar 2020 01:48:26 +0800 Subject: [PATCH] MATH-1524 Move "chooseInitialCenters" out of the KMeansPlusPlusClusterer --- .../clustering/KMeansPlusPlusClusterer.java | 134 +------------ .../initialization/CentroidInitializer.java | 39 ++++ .../KMeansPlusPlusCentroidInitializer.java | 186 ++++++++++++++++++ .../RandomCentroidInitializer.java | 65 ++++++ .../CentroidInitializerTest.java | 49 +++++ 5 files changed, 347 insertions(+), 126 deletions(-) create mode 100644 src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java create mode 100644 src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java create mode 100644 src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java create mode 100644 src/test/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializerTest.java diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java index e5dea41f4..e05918e0a 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java @@ -26,6 +26,8 @@ import org.apache.commons.math4.exception.ConvergenceException; import org.apache.commons.math4.exception.MathIllegalArgumentException; import org.apache.commons.math4.exception.NumberIsTooSmallException; import org.apache.commons.math4.exception.util.LocalizedFormats; +import org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer; +import org.apache.commons.math4.ml.clustering.initialization.KMeansPlusPlusCentroidInitializer; import org.apache.commons.math4.ml.distance.DistanceMeasure; import org.apache.commons.math4.ml.distance.EuclideanDistance; import org.apache.commons.rng.simple.RandomSource; @@ -70,6 +72,9 @@ public class KMeansPlusPlusClusterer extends Clusterer /** Selected strategy for empty clusters. */ private final EmptyClusterStrategy emptyStrategy; + /** Clusters centroids initializer. */ + private final CentroidInitializer centroidInitializer; + /** Build a clusterer. *

* The default strategy for handling empty clusters that may appear during @@ -148,6 +153,8 @@ public class KMeansPlusPlusClusterer extends Clusterer this.maxIterations = maxIterations; this.random = random; this.emptyStrategy = emptyStrategy; + // Use K-means++ to choose the initial centers. + this.centroidInitializer = new KMeansPlusPlusCentroidInitializer(measure, random); } /** @@ -203,7 +210,7 @@ public class KMeansPlusPlusClusterer extends Clusterer } // create the initial clusters - List> clusters = chooseInitialCenters(points); + List> clusters = centroidInitializer.selectCentroids(points, k); // create an array containing the latest assignment of a point to a cluster // no need to initialize the array, as it will be filled with the first assignment @@ -276,131 +283,6 @@ public class KMeansPlusPlusClusterer extends Clusterer return assignedDifferently; } - /** - * Use K-means++ to choose the initial centers. - * - * @param points the points to choose the initial centers from - * @return the initial centers - */ - private List> chooseInitialCenters(final Collection points) { - - // Convert to list for indexed access. Make it unmodifiable, since removal of items - // would screw up the logic of this method. - final List pointList = Collections.unmodifiableList(new ArrayList<> (points)); - - // The number of points in the list. - final int numPoints = pointList.size(); - - // Set the corresponding element in this array to indicate when - // elements of pointList are no longer available. - final boolean[] taken = new boolean[numPoints]; - - // The resulting list of initial centers. - final List> resultSet = new ArrayList<>(); - - // Choose one center uniformly at random from among the data points. - final int firstPointIndex = random.nextInt(numPoints); - - final T firstPoint = pointList.get(firstPointIndex); - - resultSet.add(new CentroidCluster(firstPoint)); - - // Must mark it as taken - taken[firstPointIndex] = true; - - // To keep track of the minimum distance squared of elements of - // pointList to elements of resultSet. - final double[] minDistSquared = new double[numPoints]; - - // Initialize the elements. Since the only point in resultSet is firstPoint, - // this is very easy. - for (int i = 0; i < numPoints; i++) { - if (i != firstPointIndex) { // That point isn't considered - double d = distance(firstPoint, pointList.get(i)); - minDistSquared[i] = d*d; - } - } - - while (resultSet.size() < k) { - - // Sum up the squared distances for the points in pointList not - // already taken. - double distSqSum = 0.0; - - for (int i = 0; i < numPoints; i++) { - if (!taken[i]) { - distSqSum += minDistSquared[i]; - } - } - - // Add one new data point as a center. Each point x is chosen with - // probability proportional to D(x)2 - final double r = random.nextDouble() * distSqSum; - - // The index of the next point to be added to the resultSet. - int nextPointIndex = -1; - - // Sum through the squared min distances again, stopping when - // sum >= r. - double sum = 0.0; - for (int i = 0; i < numPoints; i++) { - if (!taken[i]) { - sum += minDistSquared[i]; - if (sum >= r) { - nextPointIndex = i; - break; - } - } - } - - // If it's not set to >= 0, the point wasn't found in the previous - // for loop, probably because distances are extremely small. Just pick - // the last available point. - if (nextPointIndex == -1) { - for (int i = numPoints - 1; i >= 0; i--) { - if (!taken[i]) { - nextPointIndex = i; - break; - } - } - } - - // We found one. - if (nextPointIndex >= 0) { - - final T p = pointList.get(nextPointIndex); - - resultSet.add(new CentroidCluster (p)); - - // Mark it as taken. - taken[nextPointIndex] = true; - - if (resultSet.size() < k) { - // Now update elements of minDistSquared. We only have to compute - // the distance to the new center to do this. - for (int j = 0; j < numPoints; j++) { - // Only have to worry about the points still not taken. - if (!taken[j]) { - double d = distance(p, pointList.get(j)); - double d2 = d * d; - if (d2 < minDistSquared[j]) { - minDistSquared[j] = d2; - } - } - } - } - - } else { - // None found -- - // Break from the while loop to prevent - // an infinite loop. - break; - } - } - - return resultSet; - } - /** * Get a random point from the {@link Cluster} with the largest distance variance. * diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java new file mode 100644 index 000000000..dcddc53bc --- /dev/null +++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math4.ml.clustering.initialization; + +import org.apache.commons.math4.ml.clustering.CentroidCluster; +import org.apache.commons.math4.ml.clustering.Clusterable; + +import java.util.Collection; +import java.util.List; + +/** + * Interface abstract the algorithm for clusterer to choose the initial centers. + */ +public interface CentroidInitializer { + + /** + * Choose the initial centers. + * + * @param points the points to choose the initial centers from + * @param k The number of clusters + * @return the initial centers + */ + List> selectCentroids(final Collection points, final int k); +} diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java new file mode 100644 index 000000000..f0ab288c2 --- /dev/null +++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math4.ml.clustering.initialization; + +import org.apache.commons.math4.ml.clustering.CentroidCluster; +import org.apache.commons.math4.ml.clustering.Clusterable; +import org.apache.commons.math4.ml.distance.DistanceMeasure; +import org.apache.commons.rng.UniformRandomProvider; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +/** + * Use K-means++ to choose the initial centers. + * + * @see K-means++ (wikipedia) + */ +public class KMeansPlusPlusCentroidInitializer implements CentroidInitializer { + private final DistanceMeasure measure; + private final UniformRandomProvider random; + + /** + * Build a K-means++ CentroidInitializer + * @param measure the distance measure to use + * @param random the random to use. + */ + public KMeansPlusPlusCentroidInitializer(final DistanceMeasure measure, final UniformRandomProvider random) { + this.measure = measure; + this.random = random; + } + + /** + * Use K-means++ to choose the initial centers. + * + * @param points the points to choose the initial centers from + * @param k The number of clusters + * @return the initial centers + */ + @Override + public List> selectCentroids(final Collection points, final int k) { + // Convert to list for indexed access. Make it unmodifiable, since removal of items + // would screw up the logic of this method. + final List pointList = Collections.unmodifiableList(new ArrayList<>(points)); + + // The number of points in the list. + final int numPoints = pointList.size(); + + // Set the corresponding element in this array to indicate when + // elements of pointList are no longer available. + final boolean[] taken = new boolean[numPoints]; + + // The resulting list of initial centers. + final List> resultSet = new ArrayList<>(); + + // Choose one center uniformly at random from among the data points. + final int firstPointIndex = random.nextInt(numPoints); + + final T firstPoint = pointList.get(firstPointIndex); + + resultSet.add(new CentroidCluster<>(firstPoint)); + + // Must mark it as taken + taken[firstPointIndex] = true; + + // To keep track of the minimum distance squared of elements of + // pointList to elements of resultSet. + final double[] minDistSquared = new double[numPoints]; + + // Initialize the elements. Since the only point in resultSet is firstPoint, + // this is very easy. + for (int i = 0; i < numPoints; i++) { + if (i != firstPointIndex) { // That point isn't considered + double d = distance(firstPoint, pointList.get(i)); + minDistSquared[i] = d * d; + } + } + + while (resultSet.size() < k) { + + // Sum up the squared distances for the points in pointList not + // already taken. + double distSqSum = 0.0; + + for (int i = 0; i < numPoints; i++) { + if (!taken[i]) { + distSqSum += minDistSquared[i]; + } + } + + // Add one new data point as a center. Each point x is chosen with + // probability proportional to D(x)2 + final double r = random.nextDouble() * distSqSum; + + // The index of the next point to be added to the resultSet. + int nextPointIndex = -1; + + // Sum through the squared min distances again, stopping when + // sum >= r. + double sum = 0.0; + for (int i = 0; i < numPoints; i++) { + if (!taken[i]) { + sum += minDistSquared[i]; + if (sum >= r) { + nextPointIndex = i; + break; + } + } + } + + // If it's not set to >= 0, the point wasn't found in the previous + // for loop, probably because distances are extremely small. Just pick + // the last available point. + if (nextPointIndex == -1) { + for (int i = numPoints - 1; i >= 0; i--) { + if (!taken[i]) { + nextPointIndex = i; + break; + } + } + } + + // We found one. + if (nextPointIndex >= 0) { + + final T p = pointList.get(nextPointIndex); + + resultSet.add(new CentroidCluster<>(p)); + + // Mark it as taken. + taken[nextPointIndex] = true; + + if (resultSet.size() < k) { + // Now update elements of minDistSquared. We only have to compute + // the distance to the new center to do this. + for (int j = 0; j < numPoints; j++) { + // Only have to worry about the points still not taken. + if (!taken[j]) { + double d = distance(p, pointList.get(j)); + double d2 = d * d; + if (d2 < minDistSquared[j]) { + minDistSquared[j] = d2; + } + } + } + } + + } else { + // None found -- + // Break from the while loop to prevent + // an infinite loop. + break; + } + } + + return resultSet; + } + + /** + * Calculates the distance between two {@link Clusterable} instances + * with the configured {@link DistanceMeasure}. + * + * @param p1 the first clusterable + * @param p2 the second clusterable + * @return the distance between the two clusterables + */ + protected double distance(final Clusterable p1, final Clusterable p2) { + return measure.compute(p1.getPoint(), p2.getPoint()); + } +} diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java new file mode 100644 index 000000000..f3f561d15 --- /dev/null +++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math4.ml.clustering.initialization; + +import org.apache.commons.math4.ml.clustering.CentroidCluster; +import org.apache.commons.math4.ml.clustering.Clusterable; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.ListSampler; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +/** + * Random choose the initial centers. + */ +public class RandomCentroidInitializer implements CentroidInitializer { + private final UniformRandomProvider random; + + /** + * Build a random RandomCentroidInitializer + * + * @param random the random to use. + */ + public RandomCentroidInitializer(final UniformRandomProvider random) { + this.random = random; + } + + /** + * Random choose the initial centers. + * + * @param points the points to choose the initial centers from + * @param k The number of clusters + * @return the initial centers + */ + @Override + public List> selectCentroids(final Collection points, final int k) { + if (k < 1) { + return Collections.emptyList(); + } + final ArrayList list = new ArrayList<>(points); + ListSampler.shuffle(random, list); + final List> result = new ArrayList<>(k); + for (int i = 0; i < k; i++) { + result.add(new CentroidCluster<>(list.get(i))); + } + return result; + } +} diff --git a/src/test/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializerTest.java b/src/test/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializerTest.java new file mode 100644 index 000000000..989fd14c3 --- /dev/null +++ b/src/test/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializerTest.java @@ -0,0 +1,49 @@ +package org.apache.commons.math4.ml.clustering.initialization; + +import org.apache.commons.math4.ml.clustering.CentroidCluster; +import org.apache.commons.math4.ml.clustering.DoublePoint; +import org.apache.commons.math4.ml.distance.EuclideanDistance; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.simple.RandomSource; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +public class CentroidInitializerTest { + private void test_generate_appropriate_number_of_cluster( + final CentroidInitializer initializer) { + // Generate some data + final List points = new ArrayList<>(); + final UniformRandomProvider rnd = RandomSource.create(RandomSource.MT_64); + for (int i = 0; i < 500; i++) { + double[] p = new double[2]; + p[0] = rnd.nextDouble(); + p[1] = rnd.nextDouble(); + points.add(new DoublePoint(p)); + } + // We can only assert that the centroid initializer + // implementation generate appropriate number of cluster + for (int k = 1; k < 50; k++) { + final List> centroidClusters = + initializer.selectCentroids(points, k); + Assert.assertEquals(k, centroidClusters.size()); + } + } + + @Test + public void test_RandomCentroidInitializer() { + final CentroidInitializer initializer = + new RandomCentroidInitializer(RandomSource.create(RandomSource.MT_64)); + test_generate_appropriate_number_of_cluster(initializer); + } + + @Test + public void test_KMeanPlusPlusCentroidInitializer() { + final CentroidInitializer initializer = + new KMeansPlusPlusCentroidInitializer(new EuclideanDistance(), + RandomSource.create(RandomSource.MT_64)); + test_generate_appropriate_number_of_cluster(initializer); + } +}