Fixed k-means++ to add several strategies to deal with empty clusters that may appear during iterations

JIRA: MATH-429

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/branches/MATH_2_X@1026666 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2010-10-23 19:30:48 +00:00
parent fa55b9f280
commit ac7334d69b
7 changed files with 385 additions and 3 deletions

View File

@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.exception;
import org.apache.commons.math.exception.util.Localizable;
import org.apache.commons.math.exception.util.LocalizedFormats;
/**
* Error thrown when a numerical computation can not be performed because the
* numerical result failed to converge to a finite value.
*
* @since 2.2
* @version $Revision$ $Date$
*/
public class ConvergenceException extends MathIllegalStateException {
/** Serializable version Id. */
private static final long serialVersionUID = 4330003017885151975L;
/**
* Construct the exception.
*/
public ConvergenceException() {
this(null);
}
/**
* Construct the exception with a specific context.
*
* @param specific Specific contexte pattern.
*/
public ConvergenceException(Localizable specific) {
this(specific,
LocalizedFormats.CONVERGENCE_FAILED,
null);
}
/**
* Construct the exception with a specific context and arguments.
*
* @param specific Specific contexte pattern.
* @param args Arguments.
*/
public ConvergenceException(Localizable specific,
Object ... args) {
super(specific,
LocalizedFormats.CONVERGENCE_FAILED,
args);
}
}

View File

@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.exception;
import java.util.Locale;
import org.apache.commons.math.exception.util.ArgUtils;
import org.apache.commons.math.exception.util.MessageFactory;
import org.apache.commons.math.exception.util.Localizable;
/**
* Base class for all exceptions that signal a mismatch between the
* current state and the user's expectations.
*
* @since 2.2
* @version $Revision$ $Date$
*/
public class MathIllegalStateException extends IllegalStateException {
/** Serializable version Id. */
private static final long serialVersionUID = -6024911025449780478L;
/**
* Pattern used to build the message (specific context).
*/
private final Localizable specific;
/**
* Pattern used to build the message (general problem description).
*/
private final Localizable general;
/**
* Arguments used to build the message.
*/
private final Object[] arguments;
/**
* @param specific Message pattern providing the specific context of
* the error.
* @param general Message pattern explaining the cause of the error.
* @param args Arguments.
*/
public MathIllegalStateException(Localizable specific,
Localizable general,
Object ... args) {
this.specific = specific;
this.general = general;
arguments = ArgUtils.flatten(args);
}
/**
* @param general Message pattern explaining the cause of the error.
* @param args Arguments.
*/
public MathIllegalStateException(Localizable general,
Object ... args) {
this(null, general, args);
}
/**
* Get the message in a specified locale.
*
* @param locale Locale in which the message should be translated.
*
* @return the localized message.
*/
public String getMessage(final Locale locale) {
return MessageFactory.buildMessage(locale, specific, general, arguments);
}
/** {@inheritDoc} */
@Override
public String getMessage() {
return getMessage(Locale.US);
}
/** {@inheritDoc} */
@Override
public String getLocalizedMessage() {
return getMessage(Locale.getDefault());
}
}

View File

@ -84,6 +84,7 @@ public enum LocalizedFormats implements Localizable {
DISCRETE_CUMULATIVE_PROBABILITY_RETURNED_NAN("Discrete cumulative probability function returned NaN for argument {0}"),
DISTRIBUTION_NOT_LOADED("distribution not loaded"),
DUPLICATED_ABSCISSA("Abscissa {0} is duplicated at both indices {1} and {2}"),
EMPTY_CLUSTER_IN_K_MEANS("empty cluster in k-means"),
EMPTY_POLYNOMIALS_COEFFICIENTS_ARRAY("empty polynomials coefficients array"), /* keep */
EMPTY_SELECTED_COLUMN_INDEX_ARRAY("empty selected column index array"),
EMPTY_SELECTED_ROW_INDEX_ARRAY("empty selected row index array"),

View File

@ -22,6 +22,10 @@ import java.util.Collection;
import java.util.List;
import java.util.Random;
import org.apache.commons.math.exception.ConvergenceException;
import org.apache.commons.math.exception.util.LocalizedFormats;
import org.apache.commons.math.stat.descriptive.moment.Variance;
/**
* Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
* @param <T> type of the points to cluster
@ -31,14 +35,49 @@ import java.util.Random;
*/
public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
/** Strategies to use for replacing an empty cluster. */
public static enum EmptyClusterStrategy {
/** Split the cluster with largest distance variance. */
LARGEST_VARIANCE,
/** Split the cluster with largest number of points. */
LARGEST_POINTS_NUMBER,
/** Create a cluster around the point farthest from its centroid. */
FARTHEST_POINT,
/** Generate an error. */
ERROR
}
/** Random generator for choosing initial centers. */
private final Random random;
/** Selected strategy for empty clusters. */
private final EmptyClusterStrategy emptyStrategy;
/** Build a clusterer.
* <p>
* The default strategy for handling empty clusters that may appear during
* algorithm iterations is to split the cluster with largest distance variance.
* </p>
* @param random random generator to use for choosing initial centers
*/
public KMeansPlusPlusClusterer(final Random random) {
this.random = random;
this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
}
/** Build a clusterer.
* @param random random generator to use for choosing initial centers
* @param emptyStrategy strategy to use for handling empty clusters that
* may appear during algorithm iterations
* @since 2.2
*/
public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
this.random = random;
this.emptyStrategy = emptyStrategy;
}
/**
@ -62,9 +101,27 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
boolean clusteringChanged = false;
List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
for (final Cluster<T> cluster : clusters) {
final T newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
if (!newCenter.equals(cluster.getCenter())) {
final T newCenter;
if (cluster.getPoints().isEmpty()) {
switch (emptyStrategy) {
case LARGEST_VARIANCE :
newCenter = getPointFromLargestVarianceCluster(clusters);
break;
case LARGEST_POINTS_NUMBER :
newCenter = getPointFromLargestNumberCluster(clusters);
break;
case FARTHEST_POINT :
newCenter = getFarthestPoint(clusters);
break;
default :
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
}
clusteringChanged = true;
} else {
newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
if (!newCenter.equals(cluster.getCenter())) {
clusteringChanged = true;
}
}
newClusters.add(new Cluster<T>(newCenter));
}
@ -140,6 +197,120 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
}
/**
* Get a random point from the {@link Cluster} with the largest distance variance.
*
* @param <T> type of the points to cluster
* @param clusters the {@link Cluster}s to search
* @return a random point from the selected cluster
*/
private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters) {
double maxVariance = Double.NEGATIVE_INFINITY;
Cluster<T> selected = null;
for (final Cluster<T> cluster : clusters) {
if (!cluster.getPoints().isEmpty()) {
// compute the distance variance of the current cluster
final T center = cluster.getCenter();
final Variance stat = new Variance();
for (final T point : cluster.getPoints()) {
stat.increment(point.distanceFrom(center));
}
final double variance = stat.getResult();
// select the cluster with the largest variance
if (variance > maxVariance) {
maxVariance = variance;
selected = cluster;
}
}
}
// did we find at least one non-empty cluster ?
if (selected == null) {
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
}
// extract a random point from the cluster
final List<T> selectedPoints = selected.getPoints();
return selectedPoints.remove(random.nextInt(selectedPoints.size()));
}
/**
* Get a random point from the {@link Cluster} with the largest number of points
*
* @param <T> type of the points to cluster
* @param clusters the {@link Cluster}s to search
* @return a random point from the selected cluster
*/
private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) {
int maxNumber = 0;
Cluster<T> selected = null;
for (final Cluster<T> cluster : clusters) {
// get the number of points of the current cluster
final int number = cluster.getPoints().size();
// select the cluster with the largest number of points
if (number > maxNumber) {
maxNumber = number;
selected = cluster;
}
}
// did we find at least one non-empty cluster ?
if (selected == null) {
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
}
// extract a random point from the cluster
final List<T> selectedPoints = selected.getPoints();
return selectedPoints.remove(random.nextInt(selectedPoints.size()));
}
/**
* Get the point farthest to its cluster center
*
* @param <T> type of the points to cluster
* @param clusters the {@link Cluster}s to search
* @return point farthest to its cluster center
*/
private T getFarthestPoint(final Collection<Cluster<T>> clusters) {
double maxDistance = Double.NEGATIVE_INFINITY;
Cluster<T> selectedCluster = null;
int selectedPoint = -1;
for (final Cluster<T> cluster : clusters) {
// get the farthest point
final T center = cluster.getCenter();
final List<T> points = cluster.getPoints();
for (int i = 0; i < points.size(); ++i) {
final double distance = points.get(i).distanceFrom(center);
if (distance > maxDistance) {
maxDistance = distance;
selectedCluster = cluster;
selectedPoint = i;
}
}
}
// did we find at least one non-empty cluster ?
if (selectedCluster == null) {
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
}
return selectedCluster.getPoints().remove(selectedPoint);
}
/**
* Returns the nearest {@link Cluster} to the given point
*

View File

@ -56,6 +56,7 @@ DIMENSIONS_MISMATCH_SIMPLE = dimensions incoh\u00e9rentes {0} != {1}
DISCRETE_CUMULATIVE_PROBABILITY_RETURNED_NAN = Discr\u00e8tes fonction de probabilit\u00e9 cumulative retourn\u00e9 NaN \u00e0 l''argument de {0}
DISTRIBUTION_NOT_LOADED = aucune distribution n''a \u00e9t\u00e9 charg\u00e9e
DUPLICATED_ABSCISSA = Abscisse {0} dupliqu\u00e9e aux indices {1} et {2}
EMPTY_CLUSTER_IN_K_MEANS = groupe vide dans l''algorithme des k-moyennes
EMPTY_POLYNOMIALS_COEFFICIENTS_ARRAY = tableau de coefficients polyn\u00f4miaux vide
EMPTY_SELECTED_COLUMN_INDEX_ARRAY = tableau des indices de colonnes s\u00e9lectionn\u00e9es vide
EMPTY_SELECTED_ROW_INDEX_ARRAY = tableau des indices de lignes s\u00e9lectionn\u00e9es vide

View File

@ -52,6 +52,10 @@ The <action> type attribute can be add,update,fix,remove.
If the output is not quite correct, check for invisible trailing spaces!
-->
<release version="2.2" date="TBD" description="TBD">
<action dev="luc" type="fix" issue="MATH-429">
Fixed k-means++ to add several strategies to deal with empty clusters that may appear
during iterations
</action>
<action dev="luc" type="update" issue="MATH-417">
Improved Percentile performance by using a selection algorithm instead of a
complete sort, and by allowing caching data array and pivots when several

View File

@ -24,6 +24,7 @@ import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.junit.Assert;
import org.junit.Test;
public class KMeansPlusPlusClustererTest {
@ -116,4 +117,53 @@ public class KMeansPlusPlusClustererTest {
}
@Test
public void testCertainSpace() {
KMeansPlusPlusClusterer.EmptyClusterStrategy[] strategies = {
KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE,
KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_POINTS_NUMBER,
KMeansPlusPlusClusterer.EmptyClusterStrategy.FARTHEST_POINT
};
for (KMeansPlusPlusClusterer.EmptyClusterStrategy strategy : strategies) {
KMeansPlusPlusClusterer<EuclideanIntegerPoint> transformer =
new KMeansPlusPlusClusterer<EuclideanIntegerPoint>(new Random(1746432956321l), strategy);
int numberOfVariables = 27;
// initialise testvalues
int position1 = 1;
int position2 = position1 + numberOfVariables;
int position3 = position2 + numberOfVariables;
int position4 = position3 + numberOfVariables;
// testvalues will be multiplied
int multiplier = 1000000;
EuclideanIntegerPoint[] breakingPoints = new EuclideanIntegerPoint[numberOfVariables];
// define the space which will break the cluster algorithm
for (int i = 0; i < numberOfVariables; i++) {
int points[] = { position1, position2, position3, position4 };
// multiply the values
for (int j = 0; j < points.length; j++) {
points[j] = points[j] * multiplier;
}
EuclideanIntegerPoint euclideanIntegerPoint = new EuclideanIntegerPoint(points);
breakingPoints[i] = euclideanIntegerPoint;
position1 = position1 + numberOfVariables;
position2 = position2 + numberOfVariables;
position3 = position3 + numberOfVariables;
position4 = position4 + numberOfVariables;
}
for (int n = 2; n < 27; ++n) {
List<Cluster<EuclideanIntegerPoint>> clusters =
transformer.cluster(Arrays.asList(breakingPoints), n, 100);
Assert.assertEquals(n, clusters.size());
int sum = 0;
for (Cluster<EuclideanIntegerPoint> cluster : clusters) {
sum += cluster.getPoints().size();
}
Assert.assertEquals(numberOfVariables, sum);
}
}
}
}