From ac7334d69b7662aa4780053ffb39a553813ac8cd Mon Sep 17 00:00:00 2001 From: Luc Maisonobe Date: Sat, 23 Oct 2010 19:30:48 +0000 Subject: [PATCH] Fixed k-means++ to add several strategies to deal with empty clusters that may appear during iterations JIRA: MATH-429 git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/branches/MATH_2_X@1026666 13f79535-47bb-0310-9956-ffa450edef68 --- .../math/exception/ConvergenceException.java | 61 ++++++ .../exception/MathIllegalStateException.java | 94 ++++++++++ .../math/exception/util/LocalizedFormats.java | 1 + .../clustering/KMeansPlusPlusClusterer.java | 177 +++++++++++++++++- .../LocalizedFormats_fr.properties | 1 + src/site/xdoc/changes.xml | 4 + .../KMeansPlusPlusClustererTest.java | 50 +++++ 7 files changed, 385 insertions(+), 3 deletions(-) create mode 100644 src/main/java/org/apache/commons/math/exception/ConvergenceException.java create mode 100644 src/main/java/org/apache/commons/math/exception/MathIllegalStateException.java diff --git a/src/main/java/org/apache/commons/math/exception/ConvergenceException.java b/src/main/java/org/apache/commons/math/exception/ConvergenceException.java new file mode 100644 index 000000000..5ea585691 --- /dev/null +++ b/src/main/java/org/apache/commons/math/exception/ConvergenceException.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math.exception; + +import org.apache.commons.math.exception.util.Localizable; +import org.apache.commons.math.exception.util.LocalizedFormats; + +/** + * Error thrown when a numerical computation can not be performed because the + * numerical result failed to converge to a finite value. + * + * @since 2.2 + * @version $Revision$ $Date$ + */ +public class ConvergenceException extends MathIllegalStateException { + /** Serializable version Id. */ + private static final long serialVersionUID = 4330003017885151975L; + + /** + * Construct the exception. + */ + public ConvergenceException() { + this(null); + } + /** + * Construct the exception with a specific context. + * + * @param specific Specific contexte pattern. + */ + public ConvergenceException(Localizable specific) { + this(specific, + LocalizedFormats.CONVERGENCE_FAILED, + null); + } + /** + * Construct the exception with a specific context and arguments. + * + * @param specific Specific contexte pattern. + * @param args Arguments. + */ + public ConvergenceException(Localizable specific, + Object ... args) { + super(specific, + LocalizedFormats.CONVERGENCE_FAILED, + args); + } +} diff --git a/src/main/java/org/apache/commons/math/exception/MathIllegalStateException.java b/src/main/java/org/apache/commons/math/exception/MathIllegalStateException.java new file mode 100644 index 000000000..12f9bce13 --- /dev/null +++ b/src/main/java/org/apache/commons/math/exception/MathIllegalStateException.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math.exception; + +import java.util.Locale; + +import org.apache.commons.math.exception.util.ArgUtils; +import org.apache.commons.math.exception.util.MessageFactory; +import org.apache.commons.math.exception.util.Localizable; + +/** + * Base class for all exceptions that signal a mismatch between the + * current state and the user's expectations. + * + * @since 2.2 + * @version $Revision$ $Date$ + */ +public class MathIllegalStateException extends IllegalStateException { + + /** Serializable version Id. */ + private static final long serialVersionUID = -6024911025449780478L; + + /** + * Pattern used to build the message (specific context). + */ + private final Localizable specific; + /** + * Pattern used to build the message (general problem description). + */ + private final Localizable general; + /** + * Arguments used to build the message. + */ + private final Object[] arguments; + + /** + * @param specific Message pattern providing the specific context of + * the error. + * @param general Message pattern explaining the cause of the error. + * @param args Arguments. + */ + public MathIllegalStateException(Localizable specific, + Localizable general, + Object ... args) { + this.specific = specific; + this.general = general; + arguments = ArgUtils.flatten(args); + } + /** + * @param general Message pattern explaining the cause of the error. + * @param args Arguments. + */ + public MathIllegalStateException(Localizable general, + Object ... args) { + this(null, general, args); + } + + /** + * Get the message in a specified locale. + * + * @param locale Locale in which the message should be translated. + * + * @return the localized message. + */ + public String getMessage(final Locale locale) { + return MessageFactory.buildMessage(locale, specific, general, arguments); + } + + /** {@inheritDoc} */ + @Override + public String getMessage() { + return getMessage(Locale.US); + } + + /** {@inheritDoc} */ + @Override + public String getLocalizedMessage() { + return getMessage(Locale.getDefault()); + } +} diff --git a/src/main/java/org/apache/commons/math/exception/util/LocalizedFormats.java b/src/main/java/org/apache/commons/math/exception/util/LocalizedFormats.java index 5b1755f5f..4fa873074 100644 --- a/src/main/java/org/apache/commons/math/exception/util/LocalizedFormats.java +++ b/src/main/java/org/apache/commons/math/exception/util/LocalizedFormats.java @@ -84,6 +84,7 @@ public enum LocalizedFormats implements Localizable { DISCRETE_CUMULATIVE_PROBABILITY_RETURNED_NAN("Discrete cumulative probability function returned NaN for argument {0}"), DISTRIBUTION_NOT_LOADED("distribution not loaded"), DUPLICATED_ABSCISSA("Abscissa {0} is duplicated at both indices {1} and {2}"), + EMPTY_CLUSTER_IN_K_MEANS("empty cluster in k-means"), EMPTY_POLYNOMIALS_COEFFICIENTS_ARRAY("empty polynomials coefficients array"), /* keep */ EMPTY_SELECTED_COLUMN_INDEX_ARRAY("empty selected column index array"), EMPTY_SELECTED_ROW_INDEX_ARRAY("empty selected row index array"), diff --git a/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java index 47355df0b..56790b1c4 100644 --- a/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java +++ b/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java @@ -22,6 +22,10 @@ import java.util.Collection; import java.util.List; import java.util.Random; +import org.apache.commons.math.exception.ConvergenceException; +import org.apache.commons.math.exception.util.LocalizedFormats; +import org.apache.commons.math.stat.descriptive.moment.Variance; + /** * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm. * @param type of the points to cluster @@ -31,14 +35,49 @@ import java.util.Random; */ public class KMeansPlusPlusClusterer> { + /** Strategies to use for replacing an empty cluster. */ + public static enum EmptyClusterStrategy { + + /** Split the cluster with largest distance variance. */ + LARGEST_VARIANCE, + + /** Split the cluster with largest number of points. */ + LARGEST_POINTS_NUMBER, + + /** Create a cluster around the point farthest from its centroid. */ + FARTHEST_POINT, + + /** Generate an error. */ + ERROR + + } + /** Random generator for choosing initial centers. */ private final Random random; + /** Selected strategy for empty clusters. */ + private final EmptyClusterStrategy emptyStrategy; + /** Build a clusterer. + *

+ * The default strategy for handling empty clusters that may appear during + * algorithm iterations is to split the cluster with largest distance variance. + *

* @param random random generator to use for choosing initial centers */ public KMeansPlusPlusClusterer(final Random random) { - this.random = random; + this(random, EmptyClusterStrategy.LARGEST_VARIANCE); + } + + /** Build a clusterer. + * @param random random generator to use for choosing initial centers + * @param emptyStrategy strategy to use for handling empty clusters that + * may appear during algorithm iterations + * @since 2.2 + */ + public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) { + this.random = random; + this.emptyStrategy = emptyStrategy; } /** @@ -62,9 +101,27 @@ public class KMeansPlusPlusClusterer> { boolean clusteringChanged = false; List> newClusters = new ArrayList>(); for (final Cluster cluster : clusters) { - final T newCenter = cluster.getCenter().centroidOf(cluster.getPoints()); - if (!newCenter.equals(cluster.getCenter())) { + final T newCenter; + if (cluster.getPoints().isEmpty()) { + switch (emptyStrategy) { + case LARGEST_VARIANCE : + newCenter = getPointFromLargestVarianceCluster(clusters); + break; + case LARGEST_POINTS_NUMBER : + newCenter = getPointFromLargestNumberCluster(clusters); + break; + case FARTHEST_POINT : + newCenter = getFarthestPoint(clusters); + break; + default : + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } clusteringChanged = true; + } else { + newCenter = cluster.getCenter().centroidOf(cluster.getPoints()); + if (!newCenter.equals(cluster.getCenter())) { + clusteringChanged = true; + } } newClusters.add(new Cluster(newCenter)); } @@ -140,6 +197,120 @@ public class KMeansPlusPlusClusterer> { } + /** + * Get a random point from the {@link Cluster} with the largest distance variance. + * + * @param type of the points to cluster + * @param clusters the {@link Cluster}s to search + * @return a random point from the selected cluster + */ + private T getPointFromLargestVarianceCluster(final Collection> clusters) { + + double maxVariance = Double.NEGATIVE_INFINITY; + Cluster selected = null; + for (final Cluster cluster : clusters) { + if (!cluster.getPoints().isEmpty()) { + + // compute the distance variance of the current cluster + final T center = cluster.getCenter(); + final Variance stat = new Variance(); + for (final T point : cluster.getPoints()) { + stat.increment(point.distanceFrom(center)); + } + final double variance = stat.getResult(); + + // select the cluster with the largest variance + if (variance > maxVariance) { + maxVariance = variance; + selected = cluster; + } + + } + } + + // did we find at least one non-empty cluster ? + if (selected == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + // extract a random point from the cluster + final List selectedPoints = selected.getPoints(); + return selectedPoints.remove(random.nextInt(selectedPoints.size())); + + } + + /** + * Get a random point from the {@link Cluster} with the largest number of points + * + * @param type of the points to cluster + * @param clusters the {@link Cluster}s to search + * @return a random point from the selected cluster + */ + private T getPointFromLargestNumberCluster(final Collection> clusters) { + + int maxNumber = 0; + Cluster selected = null; + for (final Cluster cluster : clusters) { + + // get the number of points of the current cluster + final int number = cluster.getPoints().size(); + + // select the cluster with the largest number of points + if (number > maxNumber) { + maxNumber = number; + selected = cluster; + } + + } + + // did we find at least one non-empty cluster ? + if (selected == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + // extract a random point from the cluster + final List selectedPoints = selected.getPoints(); + return selectedPoints.remove(random.nextInt(selectedPoints.size())); + + } + + /** + * Get the point farthest to its cluster center + * + * @param type of the points to cluster + * @param clusters the {@link Cluster}s to search + * @return point farthest to its cluster center + */ + private T getFarthestPoint(final Collection> clusters) { + + double maxDistance = Double.NEGATIVE_INFINITY; + Cluster selectedCluster = null; + int selectedPoint = -1; + for (final Cluster cluster : clusters) { + + // get the farthest point + final T center = cluster.getCenter(); + final List points = cluster.getPoints(); + for (int i = 0; i < points.size(); ++i) { + final double distance = points.get(i).distanceFrom(center); + if (distance > maxDistance) { + maxDistance = distance; + selectedCluster = cluster; + selectedPoint = i; + } + } + + } + + // did we find at least one non-empty cluster ? + if (selectedCluster == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + return selectedCluster.getPoints().remove(selectedPoint); + + } + /** * Returns the nearest {@link Cluster} to the given point * diff --git a/src/main/resources/META-INF/localization/LocalizedFormats_fr.properties b/src/main/resources/META-INF/localization/LocalizedFormats_fr.properties index c65096853..14f0b3524 100644 --- a/src/main/resources/META-INF/localization/LocalizedFormats_fr.properties +++ b/src/main/resources/META-INF/localization/LocalizedFormats_fr.properties @@ -56,6 +56,7 @@ DIMENSIONS_MISMATCH_SIMPLE = dimensions incoh\u00e9rentes {0} != {1} DISCRETE_CUMULATIVE_PROBABILITY_RETURNED_NAN = Discr\u00e8tes fonction de probabilit\u00e9 cumulative retourn\u00e9 NaN \u00e0 l''argument de {0} DISTRIBUTION_NOT_LOADED = aucune distribution n''a \u00e9t\u00e9 charg\u00e9e DUPLICATED_ABSCISSA = Abscisse {0} dupliqu\u00e9e aux indices {1} et {2} +EMPTY_CLUSTER_IN_K_MEANS = groupe vide dans l''algorithme des k-moyennes EMPTY_POLYNOMIALS_COEFFICIENTS_ARRAY = tableau de coefficients polyn\u00f4miaux vide EMPTY_SELECTED_COLUMN_INDEX_ARRAY = tableau des indices de colonnes s\u00e9lectionn\u00e9es vide EMPTY_SELECTED_ROW_INDEX_ARRAY = tableau des indices de lignes s\u00e9lectionn\u00e9es vide diff --git a/src/site/xdoc/changes.xml b/src/site/xdoc/changes.xml index dc71d95ef..459ac9e6d 100644 --- a/src/site/xdoc/changes.xml +++ b/src/site/xdoc/changes.xml @@ -52,6 +52,10 @@ The type attribute can be add,update,fix,remove. If the output is not quite correct, check for invisible trailing spaces! --> + + Fixed k-means++ to add several strategies to deal with empty clusters that may appear + during iterations + Improved Percentile performance by using a selection algorithm instead of a complete sort, and by allowing caching data array and pivots when several diff --git a/src/test/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClustererTest.java b/src/test/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClustererTest.java index 3113ad1d5..33b2d8af8 100644 --- a/src/test/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClustererTest.java +++ b/src/test/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClustererTest.java @@ -24,6 +24,7 @@ import java.util.Arrays; import java.util.List; import java.util.Random; +import org.junit.Assert; import org.junit.Test; public class KMeansPlusPlusClustererTest { @@ -116,4 +117,53 @@ public class KMeansPlusPlusClustererTest { } + @Test + public void testCertainSpace() { + KMeansPlusPlusClusterer.EmptyClusterStrategy[] strategies = { + KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE, + KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_POINTS_NUMBER, + KMeansPlusPlusClusterer.EmptyClusterStrategy.FARTHEST_POINT + }; + for (KMeansPlusPlusClusterer.EmptyClusterStrategy strategy : strategies) { + KMeansPlusPlusClusterer transformer = + new KMeansPlusPlusClusterer(new Random(1746432956321l), strategy); + int numberOfVariables = 27; + // initialise testvalues + int position1 = 1; + int position2 = position1 + numberOfVariables; + int position3 = position2 + numberOfVariables; + int position4 = position3 + numberOfVariables; + // testvalues will be multiplied + int multiplier = 1000000; + + EuclideanIntegerPoint[] breakingPoints = new EuclideanIntegerPoint[numberOfVariables]; + // define the space which will break the cluster algorithm + for (int i = 0; i < numberOfVariables; i++) { + int points[] = { position1, position2, position3, position4 }; + // multiply the values + for (int j = 0; j < points.length; j++) { + points[j] = points[j] * multiplier; + } + EuclideanIntegerPoint euclideanIntegerPoint = new EuclideanIntegerPoint(points); + breakingPoints[i] = euclideanIntegerPoint; + position1 = position1 + numberOfVariables; + position2 = position2 + numberOfVariables; + position3 = position3 + numberOfVariables; + position4 = position4 + numberOfVariables; + } + + for (int n = 2; n < 27; ++n) { + List> clusters = + transformer.cluster(Arrays.asList(breakingPoints), n, 100); + Assert.assertEquals(n, clusters.size()); + int sum = 0; + for (Cluster cluster : clusters) { + sum += cluster.getPoints().size(); + } + Assert.assertEquals(numberOfVariables, sum); + } + } + + } + }