Added multiple trial runs to the K-means++ clustering algorithm.
JIRA: MATH-548 git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1137759 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
234db70446
commit
ff92629a3b
|
@ -84,6 +84,60 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
|
|||
this.emptyStrategy = emptyStrategy;
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the K-means++ clustering algorithm.
|
||||
*
|
||||
* @param points the points to cluster
|
||||
* @param k the number of clusters to split the data into
|
||||
* @param maxIterations the maximum number of iterations to run the algorithm
|
||||
* for. If negative, no maximum will be used
|
||||
* @return a list of clusters containing the points
|
||||
* @throws MathIllegalArgumentException if the data points are null or the number
|
||||
* of clusters is larger than the number of data points
|
||||
*/
|
||||
public List<Cluster<T>> cluster(final Collection<T> points, final int k,
|
||||
int numTrials, int maxIterationsPerTrial)
|
||||
throws MathIllegalArgumentException {
|
||||
|
||||
// at first, we have not found any clusters list yet
|
||||
List<Cluster<T>> best = null;
|
||||
double bestVarianceSum = Double.POSITIVE_INFINITY;
|
||||
|
||||
// do several clustering trials
|
||||
for (int i = 0; i < numTrials; ++i) {
|
||||
|
||||
// compute a clusters list
|
||||
List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial);
|
||||
|
||||
// compute the variance of the current list
|
||||
double varianceSum = 0.0;
|
||||
for (final Cluster<T> cluster : clusters) {
|
||||
if (!cluster.getPoints().isEmpty()) {
|
||||
|
||||
// compute the distance variance of the current cluster
|
||||
final T center = cluster.getCenter();
|
||||
final Variance stat = new Variance();
|
||||
for (final T point : cluster.getPoints()) {
|
||||
stat.increment(point.distanceFrom(center));
|
||||
}
|
||||
varianceSum += stat.getResult();
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
if (varianceSum <= bestVarianceSum) {
|
||||
// this one is the best we have found so far, remember it
|
||||
best = clusters;
|
||||
bestVarianceSum = varianceSum;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// return the best clusters list found
|
||||
return best;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the K-means++ clustering algorithm.
|
||||
*
|
||||
|
|
|
@ -52,6 +52,9 @@ The <action> type attribute can be add,update,fix,remove.
|
|||
If the output is not quite correct, check for invisible trailing spaces!
|
||||
-->
|
||||
<release version="3.0" date="TBD" description="TBD">
|
||||
<action dev="luc" type="add" issue="MATH-548">
|
||||
K-means++ clustering can now run multiple trials
|
||||
</action>
|
||||
<action dev="luc" type="add" issue="MATH-591">
|
||||
Added a way to compute sub-lines intersections, considering sub-lines either
|
||||
as open sets or closed sets
|
||||
|
|
|
@ -65,7 +65,7 @@ public class KMeansPlusPlusClustererTest {
|
|||
|
||||
};
|
||||
List<Cluster<EuclideanIntegerPoint>> clusters =
|
||||
transformer.cluster(Arrays.asList(points), 3, 10);
|
||||
transformer.cluster(Arrays.asList(points), 3, 5, 10);
|
||||
|
||||
Assert.assertEquals(3, clusters.size());
|
||||
boolean cluster1Found = false;
|
||||
|
|
Loading…
Reference in New Issue