Improved robustness of k-means++ algorithm, by tracking changes in points assignments to clusters
git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1088702 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
b7d7b20ad0
commit
129bca8975
|
@ -108,12 +108,16 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
|
||||||
|
|
||||||
// create the initial clusters
|
// create the initial clusters
|
||||||
List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
|
List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
|
||||||
assignPointsToClusters(clusters, points);
|
|
||||||
|
// create an array containing the latest assignment of a point to a cluster
|
||||||
|
// no need to initialize the array, as it will be filled with the first assignment
|
||||||
|
int[] assignments = new int[points.size()];
|
||||||
|
assignPointsToClusters(clusters, points, assignments);
|
||||||
|
|
||||||
// iterate through updating the centers until we're done
|
// iterate through updating the centers until we're done
|
||||||
final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
|
final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
|
||||||
for (int count = 0; count < max; count++) {
|
for (int count = 0; count < max; count++) {
|
||||||
boolean clusteringChanged = false;
|
boolean emptyCluster = false;
|
||||||
List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
|
List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
|
||||||
for (final Cluster<T> cluster : clusters) {
|
for (final Cluster<T> cluster : clusters) {
|
||||||
final T newCenter;
|
final T newCenter;
|
||||||
|
@ -131,20 +135,20 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
|
||||||
default :
|
default :
|
||||||
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
|
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
|
||||||
}
|
}
|
||||||
clusteringChanged = true;
|
emptyCluster = true;
|
||||||
} else {
|
} else {
|
||||||
newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
|
newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
|
||||||
if (!newCenter.equals(cluster.getCenter())) {
|
|
||||||
clusteringChanged = true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
newClusters.add(new Cluster<T>(newCenter));
|
newClusters.add(new Cluster<T>(newCenter));
|
||||||
}
|
}
|
||||||
if (!clusteringChanged) {
|
int changes = assignPointsToClusters(newClusters, points, assignments);
|
||||||
|
clusters = newClusters;
|
||||||
|
|
||||||
|
// if there were no more changes in the point-to-cluster assignment
|
||||||
|
// and there are no empty clusters left, return the current clusters
|
||||||
|
if (changes == 0 && !emptyCluster) {
|
||||||
return clusters;
|
return clusters;
|
||||||
}
|
}
|
||||||
assignPointsToClusters(newClusters, points);
|
|
||||||
clusters = newClusters;
|
|
||||||
}
|
}
|
||||||
return clusters;
|
return clusters;
|
||||||
}
|
}
|
||||||
|
@ -155,13 +159,25 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
|
||||||
* @param <T> type of the points to cluster
|
* @param <T> type of the points to cluster
|
||||||
* @param clusters the {@link Cluster}s to add the points to
|
* @param clusters the {@link Cluster}s to add the points to
|
||||||
* @param points the points to add to the given {@link Cluster}s
|
* @param points the points to add to the given {@link Cluster}s
|
||||||
|
* @return the number of points assigned to different clusters as the iteration before
|
||||||
*/
|
*/
|
||||||
private static <T extends Clusterable<T>> void
|
private static <T extends Clusterable<T>> int
|
||||||
assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) {
|
assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points,
|
||||||
|
final int[] assignments) {
|
||||||
|
int assignedDifferently = 0;
|
||||||
|
int pointIndex = 0;
|
||||||
for (final T p : points) {
|
for (final T p : points) {
|
||||||
Cluster<T> cluster = getNearestCluster(clusters, p);
|
int clusterIndex = getNearestCluster(clusters, p);
|
||||||
|
if (clusterIndex != assignments[pointIndex]) {
|
||||||
|
assignedDifferently++;
|
||||||
|
}
|
||||||
|
|
||||||
|
Cluster<T> cluster = clusters.get(clusterIndex);
|
||||||
cluster.addPoint(p);
|
cluster.addPoint(p);
|
||||||
|
assignments[pointIndex++] = clusterIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return assignedDifferently;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -190,7 +206,8 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
|
||||||
double sum = 0;
|
double sum = 0;
|
||||||
for (int i = 0; i < pointSet.size(); i++) {
|
for (int i = 0; i < pointSet.size(); i++) {
|
||||||
final T p = pointSet.get(i);
|
final T p = pointSet.get(i);
|
||||||
final Cluster<T> nearest = getNearestCluster(resultSet, p);
|
int nearestClusterIndex = getNearestCluster(resultSet, p);
|
||||||
|
final Cluster<T> nearest = resultSet.get(nearestClusterIndex);
|
||||||
final double d = p.distanceFrom(nearest.getCenter());
|
final double d = p.distanceFrom(nearest.getCenter());
|
||||||
sum += d * d;
|
sum += d * d;
|
||||||
dx2[i] = sum;
|
dx2[i] = sum;
|
||||||
|
@ -329,18 +346,20 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
|
||||||
* @param <T> type of the points to cluster
|
* @param <T> type of the points to cluster
|
||||||
* @param clusters the {@link Cluster}s to search
|
* @param clusters the {@link Cluster}s to search
|
||||||
* @param point the point to find the nearest {@link Cluster} for
|
* @param point the point to find the nearest {@link Cluster} for
|
||||||
* @return the nearest {@link Cluster} to the given point
|
* @return the index of the nearest {@link Cluster} to the given point
|
||||||
*/
|
*/
|
||||||
private static <T extends Clusterable<T>> Cluster<T>
|
private static <T extends Clusterable<T>> int
|
||||||
getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
|
getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
|
||||||
double minDistance = Double.MAX_VALUE;
|
double minDistance = Double.MAX_VALUE;
|
||||||
Cluster<T> minCluster = null;
|
int clusterIndex = 0;
|
||||||
|
int minCluster = 0;
|
||||||
for (final Cluster<T> c : clusters) {
|
for (final Cluster<T> c : clusters) {
|
||||||
final double distance = point.distanceFrom(c.getCenter());
|
final double distance = point.distanceFrom(c.getCenter());
|
||||||
if (distance < minDistance) {
|
if (distance < minDistance) {
|
||||||
minDistance = distance;
|
minDistance = distance;
|
||||||
minCluster = c;
|
minCluster = clusterIndex;
|
||||||
}
|
}
|
||||||
|
clusterIndex++;
|
||||||
}
|
}
|
||||||
return minCluster;
|
return minCluster;
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,6 +52,10 @@ The <action> type attribute can be add,update,fix,remove.
|
||||||
If the output is not quite correct, check for invisible trailing spaces!
|
If the output is not quite correct, check for invisible trailing spaces!
|
||||||
-->
|
-->
|
||||||
<release version="3.0" date="TBD" description="TBD">
|
<release version="3.0" date="TBD" description="TBD">
|
||||||
|
<action dev="luc" type="fix" issue="MATH-547" due-to="Thomas Neidhart">
|
||||||
|
Improved robustness of k-means++ algorithm, by tracking changes in points assignments
|
||||||
|
to clusters.
|
||||||
|
</action>
|
||||||
<action dev="psteitz" type="update" issue="MATH-555">
|
<action dev="psteitz" type="update" issue="MATH-555">
|
||||||
Changed MathUtils.round(double,int,int) to propagate rather than
|
Changed MathUtils.round(double,int,int) to propagate rather than
|
||||||
wrap runtime exceptions. Instead of MathRuntimeException, this method
|
wrap runtime exceptions. Instead of MathRuntimeException, this method
|
||||||
|
|
Loading…
Reference in New Issue