Improved robustness of k-means++ algorithm, by tracking changes in points assignments to clusters

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1088702 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2011-04-04 18:32:52 +00:00
parent b7d7b20ad0
commit 129bca8975
2 changed files with 40 additions and 17 deletions

View File

@ -108,12 +108,16 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
// create the initial clusters // create the initial clusters
List<Cluster<T>> clusters = chooseInitialCenters(points, k, random); List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
assignPointsToClusters(clusters, points);
// create an array containing the latest assignment of a point to a cluster
// no need to initialize the array, as it will be filled with the first assignment
int[] assignments = new int[points.size()];
assignPointsToClusters(clusters, points, assignments);
// iterate through updating the centers until we're done // iterate through updating the centers until we're done
final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations; final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
for (int count = 0; count < max; count++) { for (int count = 0; count < max; count++) {
boolean clusteringChanged = false; boolean emptyCluster = false;
List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>(); List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
for (final Cluster<T> cluster : clusters) { for (final Cluster<T> cluster : clusters) {
final T newCenter; final T newCenter;
@ -131,20 +135,20 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
default : default :
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
} }
clusteringChanged = true; emptyCluster = true;
} else { } else {
newCenter = cluster.getCenter().centroidOf(cluster.getPoints()); newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
if (!newCenter.equals(cluster.getCenter())) {
clusteringChanged = true;
}
} }
newClusters.add(new Cluster<T>(newCenter)); newClusters.add(new Cluster<T>(newCenter));
} }
if (!clusteringChanged) { int changes = assignPointsToClusters(newClusters, points, assignments);
clusters = newClusters;
// if there were no more changes in the point-to-cluster assignment
// and there are no empty clusters left, return the current clusters
if (changes == 0 && !emptyCluster) {
return clusters; return clusters;
} }
assignPointsToClusters(newClusters, points);
clusters = newClusters;
} }
return clusters; return clusters;
} }
@ -155,13 +159,25 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
* @param <T> type of the points to cluster * @param <T> type of the points to cluster
* @param clusters the {@link Cluster}s to add the points to * @param clusters the {@link Cluster}s to add the points to
* @param points the points to add to the given {@link Cluster}s * @param points the points to add to the given {@link Cluster}s
* @return the number of points assigned to different clusters as the iteration before
*/ */
private static <T extends Clusterable<T>> void private static <T extends Clusterable<T>> int
assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) { assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points,
final int[] assignments) {
int assignedDifferently = 0;
int pointIndex = 0;
for (final T p : points) { for (final T p : points) {
Cluster<T> cluster = getNearestCluster(clusters, p); int clusterIndex = getNearestCluster(clusters, p);
if (clusterIndex != assignments[pointIndex]) {
assignedDifferently++;
}
Cluster<T> cluster = clusters.get(clusterIndex);
cluster.addPoint(p); cluster.addPoint(p);
assignments[pointIndex++] = clusterIndex;
} }
return assignedDifferently;
} }
/** /**
@ -190,7 +206,8 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
double sum = 0; double sum = 0;
for (int i = 0; i < pointSet.size(); i++) { for (int i = 0; i < pointSet.size(); i++) {
final T p = pointSet.get(i); final T p = pointSet.get(i);
final Cluster<T> nearest = getNearestCluster(resultSet, p); int nearestClusterIndex = getNearestCluster(resultSet, p);
final Cluster<T> nearest = resultSet.get(nearestClusterIndex);
final double d = p.distanceFrom(nearest.getCenter()); final double d = p.distanceFrom(nearest.getCenter());
sum += d * d; sum += d * d;
dx2[i] = sum; dx2[i] = sum;
@ -329,18 +346,20 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
* @param <T> type of the points to cluster * @param <T> type of the points to cluster
* @param clusters the {@link Cluster}s to search * @param clusters the {@link Cluster}s to search
* @param point the point to find the nearest {@link Cluster} for * @param point the point to find the nearest {@link Cluster} for
* @return the nearest {@link Cluster} to the given point * @return the index of the nearest {@link Cluster} to the given point
*/ */
private static <T extends Clusterable<T>> Cluster<T> private static <T extends Clusterable<T>> int
getNearestCluster(final Collection<Cluster<T>> clusters, final T point) { getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
double minDistance = Double.MAX_VALUE; double minDistance = Double.MAX_VALUE;
Cluster<T> minCluster = null; int clusterIndex = 0;
int minCluster = 0;
for (final Cluster<T> c : clusters) { for (final Cluster<T> c : clusters) {
final double distance = point.distanceFrom(c.getCenter()); final double distance = point.distanceFrom(c.getCenter());
if (distance < minDistance) { if (distance < minDistance) {
minDistance = distance; minDistance = distance;
minCluster = c; minCluster = clusterIndex;
} }
clusterIndex++;
} }
return minCluster; return minCluster;
} }

View File

@ -52,6 +52,10 @@ The <action> type attribute can be add,update,fix,remove.
If the output is not quite correct, check for invisible trailing spaces! If the output is not quite correct, check for invisible trailing spaces!
--> -->
<release version="3.0" date="TBD" description="TBD"> <release version="3.0" date="TBD" description="TBD">
<action dev="luc" type="fix" issue="MATH-547" due-to="Thomas Neidhart">
Improved robustness of k-means++ algorithm, by tracking changes in points assignments
to clusters.
</action>
<action dev="psteitz" type="update" issue="MATH-555"> <action dev="psteitz" type="update" issue="MATH-555">
Changed MathUtils.round(double,int,int) to propagate rather than Changed MathUtils.round(double,int,int) to propagate rather than
wrap runtime exceptions. Instead of MathRuntimeException, this method wrap runtime exceptions. Instead of MathRuntimeException, this method