From 129bca897596bbf8fd250d21758402d416dbcd17 Mon Sep 17 00:00:00 2001 From: Luc Maisonobe Date: Mon, 4 Apr 2011 18:32:52 +0000 Subject: [PATCH] Improved robustness of k-means++ algorithm, by tracking changes in points assignments to clusters git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1088702 13f79535-47bb-0310-9956-ffa450edef68 --- .../clustering/KMeansPlusPlusClusterer.java | 53 +++++++++++++------ src/site/xdoc/changes.xml | 4 ++ 2 files changed, 40 insertions(+), 17 deletions(-) diff --git a/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java index 30b7ccb8c..1b2ca057e 100644 --- a/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java +++ b/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java @@ -108,12 +108,16 @@ public class KMeansPlusPlusClusterer> { // create the initial clusters List> clusters = chooseInitialCenters(points, k, random); - assignPointsToClusters(clusters, points); + + // create an array containing the latest assignment of a point to a cluster + // no need to initialize the array, as it will be filled with the first assignment + int[] assignments = new int[points.size()]; + assignPointsToClusters(clusters, points, assignments); // iterate through updating the centers until we're done final int max = (maxIterations < 0) ? 
Integer.MAX_VALUE : maxIterations; for (int count = 0; count < max; count++) { - boolean clusteringChanged = false; + boolean emptyCluster = false; List> newClusters = new ArrayList>(); for (final Cluster cluster : clusters) { final T newCenter; @@ -131,20 +135,20 @@ public class KMeansPlusPlusClusterer> { default : throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); } - clusteringChanged = true; + emptyCluster = true; } else { newCenter = cluster.getCenter().centroidOf(cluster.getPoints()); - if (!newCenter.equals(cluster.getCenter())) { - clusteringChanged = true; - } } newClusters.add(new Cluster(newCenter)); } - if (!clusteringChanged) { + int changes = assignPointsToClusters(newClusters, points, assignments); + clusters = newClusters; + + // if there were no more changes in the point-to-cluster assignment + // and there are no empty clusters left, return the current clusters + if (changes == 0 && !emptyCluster) { return clusters; } - assignPointsToClusters(newClusters, points); - clusters = newClusters; } return clusters; } @@ -155,13 +159,25 @@ public class KMeansPlusPlusClusterer> { * @param type of the points to cluster * @param clusters the {@link Cluster}s to add the points to * @param points the points to add to the given {@link Cluster}s + * @return the number of points assigned to a different cluster than in the previous iteration */ - private static > void - assignPointsToClusters(final Collection> clusters, final Collection points) { + private static > int + assignPointsToClusters(final List> clusters, final Collection points, + final int[] assignments) { + int assignedDifferently = 0; + int pointIndex = 0; for (final T p : points) { - Cluster cluster = getNearestCluster(clusters, p); + int clusterIndex = getNearestCluster(clusters, p); + if (clusterIndex != assignments[pointIndex]) { + assignedDifferently++; + } + + Cluster cluster = clusters.get(clusterIndex); cluster.addPoint(p); + assignments[pointIndex++] = clusterIndex; } + + 
return assignedDifferently; } /** @@ -190,7 +206,8 @@ public class KMeansPlusPlusClusterer> { double sum = 0; for (int i = 0; i < pointSet.size(); i++) { final T p = pointSet.get(i); - final Cluster nearest = getNearestCluster(resultSet, p); + int nearestClusterIndex = getNearestCluster(resultSet, p); + final Cluster nearest = resultSet.get(nearestClusterIndex); final double d = p.distanceFrom(nearest.getCenter()); sum += d * d; dx2[i] = sum; @@ -329,18 +346,20 @@ public class KMeansPlusPlusClusterer> { * @param type of the points to cluster * @param clusters the {@link Cluster}s to search * @param point the point to find the nearest {@link Cluster} for - * @return the nearest {@link Cluster} to the given point + * @return the index of the nearest {@link Cluster} to the given point */ - private static > Cluster + private static > int getNearestCluster(final Collection> clusters, final T point) { double minDistance = Double.MAX_VALUE; - Cluster minCluster = null; + int clusterIndex = 0; + int minCluster = 0; for (final Cluster c : clusters) { final double distance = point.distanceFrom(c.getCenter()); if (distance < minDistance) { minDistance = distance; - minCluster = c; + minCluster = clusterIndex; } + clusterIndex++; } return minCluster; } diff --git a/src/site/xdoc/changes.xml b/src/site/xdoc/changes.xml index a953b91d2..c8b1ec95a 100644 --- a/src/site/xdoc/changes.xml +++ b/src/site/xdoc/changes.xml @@ -52,6 +52,10 @@ The type attribute can be add,update,fix,remove. If the output is not quite correct, check for invisible trailing spaces! --> + + Improved robustness of k-means++ algorithm, by tracking changes in points assignments + to clusters. + Changed MathUtils.round(double,int,int) to propagate rather than wrap runtime exceptions. Instead of MathRuntimeException, this method