Improved k-means++ clustering performance and initial cluster center choice.

JIRA: MATH-584

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1132448 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2011-06-05 16:27:53 +00:00
parent 9dfd54ef0f
commit e00a9b226f
2 changed files with 104 additions and 20 deletions

View File

@ -19,6 +19,7 @@ package org.apache.commons.math.stat.clustering;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;
@ -193,41 +194,121 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
private static <T extends Clusterable<T>> List<Cluster<T>>
chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
// Purpose: pick up to k initial cluster centers from the given points. The first
// center is drawn uniformly at random; each later center is drawn with probability
// proportional to its squared distance from the nearest center already chosen.
//
// NOTE(review): this listing appears to carry two variants of the body side by
// side -- one built on pointSet/dx2 (points are removed from a mutable copy) and
// one built on pointList/taken/minDistSquared (points are only flagged as taken).
// As shown, firstPoint and r are each declared twice and the blocks opened on the
// pointSet/dx2 lines are never closed, so confirm which lines are meant to be
// kept before compiling or editing this method.

// pointSet/dx2 variant: mutable copy of the input from which chosen points are removed.
final List<T> pointSet = new ArrayList<T>(points);
// Convert to list for indexed access. Make it unmodifiable, since removal of items
// would break the index-based bookkeeping (taken / minDistSquared) of this method.
final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
// The number of points in the list.
final int numPoints = pointList.size();
// Set the corresponding element in this array to indicate when
// elements of pointList are no longer available.
final boolean[] taken = new boolean[numPoints];
// The resulting list of initial centers.
final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
// Choose one center uniformly at random from among the data points.
// pointSet/dx2 variant: the chosen point is removed from pointSet.
final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
// pointList variant: only the index is drawn; the list itself is left intact.
final int firstPointIndex = random.nextInt(numPoints);
final T firstPoint = pointList.get(firstPointIndex);
resultSet.add(new Cluster<T>(firstPoint));
// pointSet/dx2 variant: one slot per point remaining in pointSet at this moment.
final double[] dx2 = new double[pointSet.size()];
// Must mark it as taken.
taken[firstPointIndex] = true;
// To keep track of the minimum distance squared of elements of
// pointList to elements of resultSet.
final double[] minDistSquared = new double[numPoints];
// Initialize the elements. Since the only point in resultSet is firstPoint,
// this is very easy. The slot at firstPointIndex is skipped and therefore
// keeps the 0.0 that a freshly allocated double[] holds.
for (int i = 0; i < numPoints; i++) {
if (i != firstPointIndex) { // That point isn't considered
double d = firstPoint.distanceFrom(pointList.get(i));
minDistSquared[i] = d*d;
}
}
while (resultSet.size() < k) {
// For each data point x, compute D(x), the distance between x and
// the nearest center that has already been chosen.
// pointSet/dx2 variant: dx2[i] receives the running (cumulative) sum of the
// squared distances, not the individual squared distance of point i.
double sum = 0;
for (int i = 0; i < pointSet.size(); i++) {
final T p = pointSet.get(i);
// getNearestCluster is defined outside this listing; presumably it returns
// the index in resultSet of the cluster closest to p -- verify.
int nearestClusterIndex = getNearestCluster(resultSet, p);
final Cluster<T> nearest = resultSet.get(nearestClusterIndex);
final double d = p.distanceFrom(nearest.getCenter());
sum += d * d;
dx2[i] = sum;
// Sum up the squared distances for the points in pointList not
// already taken.
double distSqSum = 0.0;
for (int i = 0; i < numPoints; i++) {
if (!taken[i]) {
distSqSum += minDistSquared[i];
}
}
// Add one new data point as a center. Each point x is chosen with
// probability proportional to D(x)^2.
// pointSet/dx2 variant: take the first index whose cumulative sum reaches r
// and remove that point from pointSet.
final double r = random.nextDouble() * sum;
for (int i = 0 ; i < dx2.length; i++) {
if (dx2[i] >= r) {
final T p = pointSet.remove(i);
resultSet.add(new Cluster<T>(p));
break;
// pointList variant: r is a random threshold in [0, distSqSum).
final double r = random.nextDouble() * distSqSum;
// The index of the next point to be added to the resultSet.
int nextPointIndex = -1;
// Sum through the squared min distances again, stopping when
// sum >= r.
double sum = 0.0;
for (int i = 0; i < numPoints; i++) {
if (!taken[i]) {
sum += minDistSquared[i];
if (sum >= r) {
nextPointIndex = i;
break;
}
}
}
// If it's not set to >= 0, the point wasn't found in the previous
// for loop, probably because distances are extremely small. Just pick
// the last available point.
if (nextPointIndex == -1) {
for (int i = numPoints - 1; i >= 0; i--) {
if (!taken[i]) {
nextPointIndex = i;
break;
}
}
}
// We found one. (After the fallback scan above, nextPointIndex can still be
// -1 only when every point is already taken.)
if (nextPointIndex >= 0) {
final T p = pointList.get(nextPointIndex);
resultSet.add(new Cluster<T> (p));
// Mark it as taken.
taken[nextPointIndex] = true;
// Skip the distance update once k centers exist: the while loop is about
// to terminate and minDistSquared will not be read again.
if (resultSet.size() < k) {
// Now update elements of minDistSquared. We only have to compute
// the distance to the new center to do this.
for (int j = 0; j < numPoints; j++) {
// Only have to worry about the points still not taken.
if (!taken[j]) {
double d = p.distanceFrom(pointList.get(j));
double d2 = d * d;
// Keep the smaller value so the entry remains the minimum over all centers.
if (d2 < minDistSquared[j]) {
minDistSquared[j] = d2;
}
}
}
}
} else {
// None found --
// Break from the while loop to prevent
// an infinite loop. The result then holds fewer than k centers.
break;
}
}
return resultSet;
}
/**

View File

@ -52,6 +52,9 @@ The <action> type attribute can be add,update,fix,remove.
If the output is not quite correct, check for invisible trailing spaces!
-->
<release version="3.0" date="TBD" description="TBD">
<action dev="luc" type="fix" issue="MATH-584" due-to="Randall Scarberry">
Improved k-means++ clustering performance and initial cluster center choice.
</action>
<action dev="luc" type="fix" issue="MATH-504" due-to="X. B.">
Fixed tricube function implementation in Loess interpolator.
</action>