mirror of
https://github.com/apache/commons-math.git
synced 2025-02-06 01:59:13 +00:00
Improved k-means++ clustering performances and initial cluster center choice.
JIRA: MATH-584 git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1132448 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
9dfd54ef0f
commit
e00a9b226f
@ -19,6 +19,7 @@ package org.apache.commons.math.stat.clustering;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
@ -193,41 +194,121 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
|
||||
private static <T extends Clusterable<T>> List<Cluster<T>>
|
||||
chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
|
||||
|
||||
final List<T> pointSet = new ArrayList<T>(points);
|
||||
// Convert to list for indexed access. Make it unmodifiable, since removal of items
|
||||
// would screw up the logic of this method.
|
||||
final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
|
||||
|
||||
// The number of points in the list.
|
||||
final int numPoints = pointList.size();
|
||||
|
||||
// Set the corresponding element in this array to indicate when
|
||||
// elements of pointList are no longer available.
|
||||
final boolean[] taken = new boolean[numPoints];
|
||||
|
||||
// The resulting list of initial centers.
|
||||
final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
|
||||
|
||||
// Choose one center uniformly at random from among the data points.
|
||||
final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
|
||||
final int firstPointIndex = random.nextInt(numPoints);
|
||||
|
||||
final T firstPoint = pointList.get(firstPointIndex);
|
||||
|
||||
resultSet.add(new Cluster<T>(firstPoint));
|
||||
|
||||
final double[] dx2 = new double[pointSet.size()];
|
||||
// Must mark it as taken
|
||||
taken[firstPointIndex] = true;
|
||||
|
||||
// To keep track of the minimum distance squared of elements of
|
||||
// pointList to elements of resultSet.
|
||||
final double[] minDistSquared = new double[numPoints];
|
||||
|
||||
// Initialize the elements. Since the only point in resultSet is firstPoint,
|
||||
// this is very easy.
|
||||
for (int i = 0; i < numPoints; i++) {
|
||||
if (i != firstPointIndex) { // That point isn't considered
|
||||
double d = firstPoint.distanceFrom(pointList.get(i));
|
||||
minDistSquared[i] = d*d;
|
||||
}
|
||||
}
|
||||
|
||||
while (resultSet.size() < k) {
|
||||
// For each data point x, compute D(x), the distance between x and
|
||||
// the nearest center that has already been chosen.
|
||||
double sum = 0;
|
||||
for (int i = 0; i < pointSet.size(); i++) {
|
||||
final T p = pointSet.get(i);
|
||||
int nearestClusterIndex = getNearestCluster(resultSet, p);
|
||||
final Cluster<T> nearest = resultSet.get(nearestClusterIndex);
|
||||
final double d = p.distanceFrom(nearest.getCenter());
|
||||
sum += d * d;
|
||||
dx2[i] = sum;
|
||||
|
||||
// Sum up the squared distances for the points in pointList not
|
||||
// already taken.
|
||||
double distSqSum = 0.0;
|
||||
|
||||
for (int i = 0; i < numPoints; i++) {
|
||||
if (!taken[i]) {
|
||||
distSqSum += minDistSquared[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Add one new data point as a center. Each point x is chosen with
|
||||
// probability proportional to D(x)2
|
||||
final double r = random.nextDouble() * sum;
|
||||
for (int i = 0 ; i < dx2.length; i++) {
|
||||
if (dx2[i] >= r) {
|
||||
final T p = pointSet.remove(i);
|
||||
resultSet.add(new Cluster<T>(p));
|
||||
break;
|
||||
final double r = random.nextDouble() * distSqSum;
|
||||
|
||||
// The index of the next point to be added to the resultSet.
|
||||
int nextPointIndex = -1;
|
||||
|
||||
// Sum through the squared min distances again, stopping when
|
||||
// sum >= r.
|
||||
double sum = 0.0;
|
||||
for (int i = 0; i < numPoints; i++) {
|
||||
if (!taken[i]) {
|
||||
sum += minDistSquared[i];
|
||||
if (sum >= r) {
|
||||
nextPointIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If it's not set to >= 0, the point wasn't found in the previous
|
||||
// for loop, probably because distances are extremely small. Just pick
|
||||
// the last available point.
|
||||
if (nextPointIndex == -1) {
|
||||
for (int i = numPoints - 1; i >= 0; i--) {
|
||||
if (!taken[i]) {
|
||||
nextPointIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We found one.
|
||||
if (nextPointIndex >= 0) {
|
||||
|
||||
final T p = pointList.get(nextPointIndex);
|
||||
|
||||
resultSet.add(new Cluster<T> (p));
|
||||
|
||||
// Mark it as taken.
|
||||
taken[nextPointIndex] = true;
|
||||
|
||||
if (resultSet.size() < k) {
|
||||
// Now update elements of minDistSquared. We only have to compute
|
||||
// the distance to the new center to do this.
|
||||
for (int j = 0; j < numPoints; j++) {
|
||||
// Only have to worry about the points still not taken.
|
||||
if (!taken[j]) {
|
||||
double d = p.distanceFrom(pointList.get(j));
|
||||
double d2 = d * d;
|
||||
if (d2 < minDistSquared[j]) {
|
||||
minDistSquared[j] = d2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
// None found --
|
||||
// Break from the while loop to prevent
|
||||
// an infinite loop.
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return resultSet;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -52,6 +52,9 @@ The <action> type attribute can be add,update,fix,remove.
|
||||
If the output is not quite correct, check for invisible trailing spaces!
|
||||
-->
|
||||
<release version="3.0" date="TBD" description="TBD">
|
||||
<action dev="luc" type="fix" issue="MATH-584" due-to="Randall Scarberry">
|
||||
Improved k-means++ clustering performances and initial cluster center choice.
|
||||
</action>
|
||||
<action dev="luc" type="fix" issue="MATH-504" due-to="X. B.">
|
||||
Fixed tricube function implementation in Loess interpolator.
|
||||
</action>
|
||||
|
Loading…
x
Reference in New Issue
Block a user