Fixed k-means++ to add several strategies to deal with empty clusters that may appear during iterations
JIRA: MATH-429 git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1026667 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
1b238344cd
commit
53ccce4ac1
|
@ -23,7 +23,7 @@ import org.apache.commons.math.exception.util.LocalizedFormats;
|
|||
* Error thrown when a numerical computation can not be performed because the
|
||||
* numerical result failed to converge to a finite value.
|
||||
*
|
||||
* @since 3.0
|
||||
* @since 2.2
|
||||
* @version $Revision$ $Date$
|
||||
*/
|
||||
public class ConvergenceException extends MathIllegalStateException {
|
||||
|
|
|
@ -26,7 +26,7 @@ import org.apache.commons.math.exception.util.Localizable;
|
|||
* Base class for all exceptions that signal a mismatch between the
|
||||
* current state and the user's expectations.
|
||||
*
|
||||
* @since 3.0
|
||||
* @since 2.2
|
||||
* @version $Revision$ $Date$
|
||||
*/
|
||||
public class MathIllegalStateException extends IllegalStateException {
|
||||
|
|
|
@ -86,6 +86,7 @@ public enum LocalizedFormats implements Localizable {
|
|||
DISCRETE_CUMULATIVE_PROBABILITY_RETURNED_NAN("Discrete cumulative probability function returned NaN for argument {0}"),
|
||||
DISTRIBUTION_NOT_LOADED("distribution not loaded"),
|
||||
DUPLICATED_ABSCISSA("Abscissa {0} is duplicated at both indices {1} and {2}"),
|
||||
EMPTY_CLUSTER_IN_K_MEANS("empty cluster in k-means"),
|
||||
EMPTY_POLYNOMIALS_COEFFICIENTS_ARRAY("empty polynomials coefficients array"), /* keep */
|
||||
EMPTY_SELECTED_COLUMN_INDEX_ARRAY("empty selected column index array"),
|
||||
EMPTY_SELECTED_ROW_INDEX_ARRAY("empty selected row index array"),
|
||||
|
|
|
@ -22,6 +22,10 @@ import java.util.Collection;
|
|||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import org.apache.commons.math.exception.ConvergenceException;
|
||||
import org.apache.commons.math.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math.stat.descriptive.moment.Variance;
|
||||
|
||||
/**
|
||||
* Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
|
||||
* @param <T> type of the points to cluster
|
||||
|
@ -31,14 +35,49 @@ import java.util.Random;
|
|||
*/
|
||||
public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
|
||||
|
||||
/** Strategies to use for replacing an empty cluster. */
|
||||
public static enum EmptyClusterStrategy {
|
||||
|
||||
/** Split the cluster with largest distance variance. */
|
||||
LARGEST_VARIANCE,
|
||||
|
||||
/** Split the cluster with largest number of points. */
|
||||
LARGEST_POINTS_NUMBER,
|
||||
|
||||
/** Create a cluster around the point farthest from its centroid. */
|
||||
FARTHEST_POINT,
|
||||
|
||||
/** Generate an error. */
|
||||
ERROR
|
||||
|
||||
}
|
||||
|
||||
/** Random generator for choosing initial centers. */
|
||||
private final Random random;
|
||||
|
||||
/** Selected strategy for empty clusters. */
|
||||
private final EmptyClusterStrategy emptyStrategy;
|
||||
|
||||
/** Build a clusterer.
|
||||
* <p>
|
||||
* The default strategy for handling empty clusters that may appear during
|
||||
* algorithm iterations is to split the cluster with largest distance variance.
|
||||
* </p>
|
||||
* @param random random generator to use for choosing initial centers
|
||||
*/
|
||||
public KMeansPlusPlusClusterer(final Random random) {
|
||||
this.random = random;
|
||||
this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
|
||||
}
|
||||
|
||||
/** Build a clusterer.
|
||||
* @param random random generator to use for choosing initial centers
|
||||
* @param emptyStrategy strategy to use for handling empty clusters that
|
||||
* may appear during algorithm iterations
|
||||
* @since 2.2
|
||||
*/
|
||||
public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
|
||||
this.random = random;
|
||||
this.emptyStrategy = emptyStrategy;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -62,9 +101,27 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
|
|||
boolean clusteringChanged = false;
|
||||
List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
|
||||
for (final Cluster<T> cluster : clusters) {
|
||||
final T newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
|
||||
if (!newCenter.equals(cluster.getCenter())) {
|
||||
final T newCenter;
|
||||
if (cluster.getPoints().isEmpty()) {
|
||||
switch (emptyStrategy) {
|
||||
case LARGEST_VARIANCE :
|
||||
newCenter = getPointFromLargestVarianceCluster(clusters);
|
||||
break;
|
||||
case LARGEST_POINTS_NUMBER :
|
||||
newCenter = getPointFromLargestNumberCluster(clusters);
|
||||
break;
|
||||
case FARTHEST_POINT :
|
||||
newCenter = getFarthestPoint(clusters);
|
||||
break;
|
||||
default :
|
||||
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
|
||||
}
|
||||
clusteringChanged = true;
|
||||
} else {
|
||||
newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
|
||||
if (!newCenter.equals(cluster.getCenter())) {
|
||||
clusteringChanged = true;
|
||||
}
|
||||
}
|
||||
newClusters.add(new Cluster<T>(newCenter));
|
||||
}
|
||||
|
@ -140,6 +197,120 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
|
|||
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a random point from the {@link Cluster} with the largest distance variance.
|
||||
*
|
||||
* @param <T> type of the points to cluster
|
||||
* @param clusters the {@link Cluster}s to search
|
||||
* @return a random point from the selected cluster
|
||||
*/
|
||||
private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters) {
|
||||
|
||||
double maxVariance = Double.NEGATIVE_INFINITY;
|
||||
Cluster<T> selected = null;
|
||||
for (final Cluster<T> cluster : clusters) {
|
||||
if (!cluster.getPoints().isEmpty()) {
|
||||
|
||||
// compute the distance variance of the current cluster
|
||||
final T center = cluster.getCenter();
|
||||
final Variance stat = new Variance();
|
||||
for (final T point : cluster.getPoints()) {
|
||||
stat.increment(point.distanceFrom(center));
|
||||
}
|
||||
final double variance = stat.getResult();
|
||||
|
||||
// select the cluster with the largest variance
|
||||
if (variance > maxVariance) {
|
||||
maxVariance = variance;
|
||||
selected = cluster;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// did we find at least one non-empty cluster ?
|
||||
if (selected == null) {
|
||||
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
|
||||
}
|
||||
|
||||
// extract a random point from the cluster
|
||||
final List<T> selectedPoints = selected.getPoints();
|
||||
return selectedPoints.remove(random.nextInt(selectedPoints.size()));
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a random point from the {@link Cluster} with the largest number of points
|
||||
*
|
||||
* @param <T> type of the points to cluster
|
||||
* @param clusters the {@link Cluster}s to search
|
||||
* @return a random point from the selected cluster
|
||||
*/
|
||||
private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) {
|
||||
|
||||
int maxNumber = 0;
|
||||
Cluster<T> selected = null;
|
||||
for (final Cluster<T> cluster : clusters) {
|
||||
|
||||
// get the number of points of the current cluster
|
||||
final int number = cluster.getPoints().size();
|
||||
|
||||
// select the cluster with the largest number of points
|
||||
if (number > maxNumber) {
|
||||
maxNumber = number;
|
||||
selected = cluster;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// did we find at least one non-empty cluster ?
|
||||
if (selected == null) {
|
||||
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
|
||||
}
|
||||
|
||||
// extract a random point from the cluster
|
||||
final List<T> selectedPoints = selected.getPoints();
|
||||
return selectedPoints.remove(random.nextInt(selectedPoints.size()));
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the point farthest to its cluster center
|
||||
*
|
||||
* @param <T> type of the points to cluster
|
||||
* @param clusters the {@link Cluster}s to search
|
||||
* @return point farthest to its cluster center
|
||||
*/
|
||||
private T getFarthestPoint(final Collection<Cluster<T>> clusters) {
|
||||
|
||||
double maxDistance = Double.NEGATIVE_INFINITY;
|
||||
Cluster<T> selectedCluster = null;
|
||||
int selectedPoint = -1;
|
||||
for (final Cluster<T> cluster : clusters) {
|
||||
|
||||
// get the farthest point
|
||||
final T center = cluster.getCenter();
|
||||
final List<T> points = cluster.getPoints();
|
||||
for (int i = 0; i < points.size(); ++i) {
|
||||
final double distance = points.get(i).distanceFrom(center);
|
||||
if (distance > maxDistance) {
|
||||
maxDistance = distance;
|
||||
selectedCluster = cluster;
|
||||
selectedPoint = i;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// did we find at least one non-empty cluster ?
|
||||
if (selectedCluster == null) {
|
||||
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
|
||||
}
|
||||
|
||||
return selectedCluster.getPoints().remove(selectedPoint);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the nearest {@link Cluster} to the given point
|
||||
*
|
||||
|
|
|
@ -58,6 +58,7 @@ DIMENSIONS_MISMATCH = dimensions incoh\u00e9rentes
|
|||
DISCRETE_CUMULATIVE_PROBABILITY_RETURNED_NAN = Discr\u00e8tes fonction de probabilit\u00e9 cumulative retourn\u00e9 NaN \u00e0 l''argument de {0}
|
||||
DISTRIBUTION_NOT_LOADED = aucune distribution n''a \u00e9t\u00e9 charg\u00e9e
|
||||
DUPLICATED_ABSCISSA = Abscisse {0} dupliqu\u00e9e aux indices {1} et {2}
|
||||
EMPTY_CLUSTER_IN_K_MEANS = groupe vide dans l''algorithme des k-moyennes
|
||||
EMPTY_POLYNOMIALS_COEFFICIENTS_ARRAY = tableau de coefficients polyn\u00f4miaux vide
|
||||
EMPTY_SELECTED_COLUMN_INDEX_ARRAY = tableau des indices de colonnes s\u00e9lectionn\u00e9es vide
|
||||
EMPTY_SELECTED_ROW_INDEX_ARRAY = tableau des indices de lignes s\u00e9lectionn\u00e9es vide
|
||||
|
|
|
@ -85,6 +85,10 @@ The <action> type attribute can be add,update,fix,remove.
|
|||
</action>
|
||||
</release>
|
||||
<release version="2.2" date="TBD" description="TBD">
|
||||
<action dev="luc" type="fix" issue="MATH-429">
|
||||
Fixed k-means++ to add several strategies to deal with empty clusters that may appear
|
||||
during iterations
|
||||
</action>
|
||||
<action dev="luc" type="update" issue="MATH-417">
|
||||
Improved Percentile performance by using a selection algorithm instead of a
|
||||
complete sort, and by allowing caching data array and pivots when several
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class KMeansPlusPlusClustererTest {
|
||||
|
@ -116,4 +117,53 @@ public class KMeansPlusPlusClustererTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCertainSpace() {
|
||||
KMeansPlusPlusClusterer.EmptyClusterStrategy[] strategies = {
|
||||
KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE,
|
||||
KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_POINTS_NUMBER,
|
||||
KMeansPlusPlusClusterer.EmptyClusterStrategy.FARTHEST_POINT
|
||||
};
|
||||
for (KMeansPlusPlusClusterer.EmptyClusterStrategy strategy : strategies) {
|
||||
KMeansPlusPlusClusterer<EuclideanIntegerPoint> transformer =
|
||||
new KMeansPlusPlusClusterer<EuclideanIntegerPoint>(new Random(1746432956321l), strategy);
|
||||
int numberOfVariables = 27;
|
||||
// initialise testvalues
|
||||
int position1 = 1;
|
||||
int position2 = position1 + numberOfVariables;
|
||||
int position3 = position2 + numberOfVariables;
|
||||
int position4 = position3 + numberOfVariables;
|
||||
// testvalues will be multiplied
|
||||
int multiplier = 1000000;
|
||||
|
||||
EuclideanIntegerPoint[] breakingPoints = new EuclideanIntegerPoint[numberOfVariables];
|
||||
// define the space which will break the cluster algorithm
|
||||
for (int i = 0; i < numberOfVariables; i++) {
|
||||
int points[] = { position1, position2, position3, position4 };
|
||||
// multiply the values
|
||||
for (int j = 0; j < points.length; j++) {
|
||||
points[j] = points[j] * multiplier;
|
||||
}
|
||||
EuclideanIntegerPoint euclideanIntegerPoint = new EuclideanIntegerPoint(points);
|
||||
breakingPoints[i] = euclideanIntegerPoint;
|
||||
position1 = position1 + numberOfVariables;
|
||||
position2 = position2 + numberOfVariables;
|
||||
position3 = position3 + numberOfVariables;
|
||||
position4 = position4 + numberOfVariables;
|
||||
}
|
||||
|
||||
for (int n = 2; n < 27; ++n) {
|
||||
List<Cluster<EuclideanIntegerPoint>> clusters =
|
||||
transformer.cluster(Arrays.asList(breakingPoints), n, 100);
|
||||
Assert.assertEquals(n, clusters.size());
|
||||
int sum = 0;
|
||||
for (Cluster<EuclideanIntegerPoint> cluster : clusters) {
|
||||
sum += cluster.getPoints().size();
|
||||
}
|
||||
Assert.assertEquals(numberOfVariables, sum);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue