Fixed k-means++ to add several strategies to deal with empty clusters that may appear during iterations
JIRA: MATH-429 git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/branches/MATH_2_X@1026666 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
fa55b9f280
commit
ac7334d69b
|
@ -0,0 +1,61 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math.exception;
|
||||
|
||||
import org.apache.commons.math.exception.util.Localizable;
|
||||
import org.apache.commons.math.exception.util.LocalizedFormats;
|
||||
|
||||
/**
|
||||
* Error thrown when a numerical computation can not be performed because the
|
||||
* numerical result failed to converge to a finite value.
|
||||
*
|
||||
* @since 2.2
|
||||
* @version $Revision$ $Date$
|
||||
*/
|
||||
public class ConvergenceException extends MathIllegalStateException {
|
||||
/** Serializable version Id. */
|
||||
private static final long serialVersionUID = 4330003017885151975L;
|
||||
|
||||
/**
|
||||
* Construct the exception.
|
||||
*/
|
||||
public ConvergenceException() {
|
||||
this(null);
|
||||
}
|
||||
/**
|
||||
* Construct the exception with a specific context.
|
||||
*
|
||||
* @param specific Specific contexte pattern.
|
||||
*/
|
||||
public ConvergenceException(Localizable specific) {
|
||||
this(specific,
|
||||
LocalizedFormats.CONVERGENCE_FAILED,
|
||||
null);
|
||||
}
|
||||
/**
|
||||
* Construct the exception with a specific context and arguments.
|
||||
*
|
||||
* @param specific Specific contexte pattern.
|
||||
* @param args Arguments.
|
||||
*/
|
||||
public ConvergenceException(Localizable specific,
|
||||
Object ... args) {
|
||||
super(specific,
|
||||
LocalizedFormats.CONVERGENCE_FAILED,
|
||||
args);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,94 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.commons.math.exception;
|
||||
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math.exception.util.ArgUtils;
|
||||
import org.apache.commons.math.exception.util.MessageFactory;
|
||||
import org.apache.commons.math.exception.util.Localizable;
|
||||
|
||||
/**
|
||||
* Base class for all exceptions that signal a mismatch between the
|
||||
* current state and the user's expectations.
|
||||
*
|
||||
* @since 2.2
|
||||
* @version $Revision$ $Date$
|
||||
*/
|
||||
public class MathIllegalStateException extends IllegalStateException {
|
||||
|
||||
/** Serializable version Id. */
|
||||
private static final long serialVersionUID = -6024911025449780478L;
|
||||
|
||||
/**
|
||||
* Pattern used to build the message (specific context).
|
||||
*/
|
||||
private final Localizable specific;
|
||||
/**
|
||||
* Pattern used to build the message (general problem description).
|
||||
*/
|
||||
private final Localizable general;
|
||||
/**
|
||||
* Arguments used to build the message.
|
||||
*/
|
||||
private final Object[] arguments;
|
||||
|
||||
/**
|
||||
* @param specific Message pattern providing the specific context of
|
||||
* the error.
|
||||
* @param general Message pattern explaining the cause of the error.
|
||||
* @param args Arguments.
|
||||
*/
|
||||
public MathIllegalStateException(Localizable specific,
|
||||
Localizable general,
|
||||
Object ... args) {
|
||||
this.specific = specific;
|
||||
this.general = general;
|
||||
arguments = ArgUtils.flatten(args);
|
||||
}
|
||||
/**
|
||||
* @param general Message pattern explaining the cause of the error.
|
||||
* @param args Arguments.
|
||||
*/
|
||||
public MathIllegalStateException(Localizable general,
|
||||
Object ... args) {
|
||||
this(null, general, args);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the message in a specified locale.
|
||||
*
|
||||
* @param locale Locale in which the message should be translated.
|
||||
*
|
||||
* @return the localized message.
|
||||
*/
|
||||
public String getMessage(final Locale locale) {
|
||||
return MessageFactory.buildMessage(locale, specific, general, arguments);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public String getMessage() {
|
||||
return getMessage(Locale.US);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public String getLocalizedMessage() {
|
||||
return getMessage(Locale.getDefault());
|
||||
}
|
||||
}
|
|
@ -84,6 +84,7 @@ public enum LocalizedFormats implements Localizable {
|
|||
DISCRETE_CUMULATIVE_PROBABILITY_RETURNED_NAN("Discrete cumulative probability function returned NaN for argument {0}"),
|
||||
DISTRIBUTION_NOT_LOADED("distribution not loaded"),
|
||||
DUPLICATED_ABSCISSA("Abscissa {0} is duplicated at both indices {1} and {2}"),
|
||||
EMPTY_CLUSTER_IN_K_MEANS("empty cluster in k-means"),
|
||||
EMPTY_POLYNOMIALS_COEFFICIENTS_ARRAY("empty polynomials coefficients array"), /* keep */
|
||||
EMPTY_SELECTED_COLUMN_INDEX_ARRAY("empty selected column index array"),
|
||||
EMPTY_SELECTED_ROW_INDEX_ARRAY("empty selected row index array"),
|
||||
|
|
|
@ -22,6 +22,10 @@ import java.util.Collection;
|
|||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import org.apache.commons.math.exception.ConvergenceException;
|
||||
import org.apache.commons.math.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math.stat.descriptive.moment.Variance;
|
||||
|
||||
/**
|
||||
* Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
|
||||
* @param <T> type of the points to cluster
|
||||
|
@ -31,14 +35,49 @@ import java.util.Random;
|
|||
*/
|
||||
public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
|
||||
|
||||
/** Strategies to use for replacing an empty cluster. */
|
||||
public static enum EmptyClusterStrategy {
|
||||
|
||||
/** Split the cluster with largest distance variance. */
|
||||
LARGEST_VARIANCE,
|
||||
|
||||
/** Split the cluster with largest number of points. */
|
||||
LARGEST_POINTS_NUMBER,
|
||||
|
||||
/** Create a cluster around the point farthest from its centroid. */
|
||||
FARTHEST_POINT,
|
||||
|
||||
/** Generate an error. */
|
||||
ERROR
|
||||
|
||||
}
|
||||
|
||||
/** Random generator for choosing initial centers. */
|
||||
private final Random random;
|
||||
|
||||
/** Selected strategy for empty clusters. */
|
||||
private final EmptyClusterStrategy emptyStrategy;
|
||||
|
||||
/** Build a clusterer.
|
||||
* <p>
|
||||
* The default strategy for handling empty clusters that may appear during
|
||||
* algorithm iterations is to split the cluster with largest distance variance.
|
||||
* </p>
|
||||
* @param random random generator to use for choosing initial centers
|
||||
*/
|
||||
public KMeansPlusPlusClusterer(final Random random) {
|
||||
this.random = random;
|
||||
this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
|
||||
}
|
||||
|
||||
/** Build a clusterer.
|
||||
* @param random random generator to use for choosing initial centers
|
||||
* @param emptyStrategy strategy to use for handling empty clusters that
|
||||
* may appear during algorithm iterations
|
||||
* @since 2.2
|
||||
*/
|
||||
public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
|
||||
this.random = random;
|
||||
this.emptyStrategy = emptyStrategy;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -62,9 +101,27 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
|
|||
boolean clusteringChanged = false;
|
||||
List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
|
||||
for (final Cluster<T> cluster : clusters) {
|
||||
final T newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
|
||||
if (!newCenter.equals(cluster.getCenter())) {
|
||||
final T newCenter;
|
||||
if (cluster.getPoints().isEmpty()) {
|
||||
switch (emptyStrategy) {
|
||||
case LARGEST_VARIANCE :
|
||||
newCenter = getPointFromLargestVarianceCluster(clusters);
|
||||
break;
|
||||
case LARGEST_POINTS_NUMBER :
|
||||
newCenter = getPointFromLargestNumberCluster(clusters);
|
||||
break;
|
||||
case FARTHEST_POINT :
|
||||
newCenter = getFarthestPoint(clusters);
|
||||
break;
|
||||
default :
|
||||
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
|
||||
}
|
||||
clusteringChanged = true;
|
||||
} else {
|
||||
newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
|
||||
if (!newCenter.equals(cluster.getCenter())) {
|
||||
clusteringChanged = true;
|
||||
}
|
||||
}
|
||||
newClusters.add(new Cluster<T>(newCenter));
|
||||
}
|
||||
|
@ -140,6 +197,120 @@ public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
|
|||
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a random point from the {@link Cluster} with the largest distance variance.
|
||||
*
|
||||
* @param <T> type of the points to cluster
|
||||
* @param clusters the {@link Cluster}s to search
|
||||
* @return a random point from the selected cluster
|
||||
*/
|
||||
private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters) {
|
||||
|
||||
double maxVariance = Double.NEGATIVE_INFINITY;
|
||||
Cluster<T> selected = null;
|
||||
for (final Cluster<T> cluster : clusters) {
|
||||
if (!cluster.getPoints().isEmpty()) {
|
||||
|
||||
// compute the distance variance of the current cluster
|
||||
final T center = cluster.getCenter();
|
||||
final Variance stat = new Variance();
|
||||
for (final T point : cluster.getPoints()) {
|
||||
stat.increment(point.distanceFrom(center));
|
||||
}
|
||||
final double variance = stat.getResult();
|
||||
|
||||
// select the cluster with the largest variance
|
||||
if (variance > maxVariance) {
|
||||
maxVariance = variance;
|
||||
selected = cluster;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// did we find at least one non-empty cluster ?
|
||||
if (selected == null) {
|
||||
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
|
||||
}
|
||||
|
||||
// extract a random point from the cluster
|
||||
final List<T> selectedPoints = selected.getPoints();
|
||||
return selectedPoints.remove(random.nextInt(selectedPoints.size()));
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a random point from the {@link Cluster} with the largest number of points
|
||||
*
|
||||
* @param <T> type of the points to cluster
|
||||
* @param clusters the {@link Cluster}s to search
|
||||
* @return a random point from the selected cluster
|
||||
*/
|
||||
private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) {
|
||||
|
||||
int maxNumber = 0;
|
||||
Cluster<T> selected = null;
|
||||
for (final Cluster<T> cluster : clusters) {
|
||||
|
||||
// get the number of points of the current cluster
|
||||
final int number = cluster.getPoints().size();
|
||||
|
||||
// select the cluster with the largest number of points
|
||||
if (number > maxNumber) {
|
||||
maxNumber = number;
|
||||
selected = cluster;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// did we find at least one non-empty cluster ?
|
||||
if (selected == null) {
|
||||
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
|
||||
}
|
||||
|
||||
// extract a random point from the cluster
|
||||
final List<T> selectedPoints = selected.getPoints();
|
||||
return selectedPoints.remove(random.nextInt(selectedPoints.size()));
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the point farthest to its cluster center
|
||||
*
|
||||
* @param <T> type of the points to cluster
|
||||
* @param clusters the {@link Cluster}s to search
|
||||
* @return point farthest to its cluster center
|
||||
*/
|
||||
private T getFarthestPoint(final Collection<Cluster<T>> clusters) {
|
||||
|
||||
double maxDistance = Double.NEGATIVE_INFINITY;
|
||||
Cluster<T> selectedCluster = null;
|
||||
int selectedPoint = -1;
|
||||
for (final Cluster<T> cluster : clusters) {
|
||||
|
||||
// get the farthest point
|
||||
final T center = cluster.getCenter();
|
||||
final List<T> points = cluster.getPoints();
|
||||
for (int i = 0; i < points.size(); ++i) {
|
||||
final double distance = points.get(i).distanceFrom(center);
|
||||
if (distance > maxDistance) {
|
||||
maxDistance = distance;
|
||||
selectedCluster = cluster;
|
||||
selectedPoint = i;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// did we find at least one non-empty cluster ?
|
||||
if (selectedCluster == null) {
|
||||
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
|
||||
}
|
||||
|
||||
return selectedCluster.getPoints().remove(selectedPoint);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the nearest {@link Cluster} to the given point
|
||||
*
|
||||
|
|
|
@ -56,6 +56,7 @@ DIMENSIONS_MISMATCH_SIMPLE = dimensions incoh\u00e9rentes {0} != {1}
|
|||
DISCRETE_CUMULATIVE_PROBABILITY_RETURNED_NAN = Discr\u00e8tes fonction de probabilit\u00e9 cumulative retourn\u00e9 NaN \u00e0 l''argument de {0}
|
||||
DISTRIBUTION_NOT_LOADED = aucune distribution n''a \u00e9t\u00e9 charg\u00e9e
|
||||
DUPLICATED_ABSCISSA = Abscisse {0} dupliqu\u00e9e aux indices {1} et {2}
|
||||
EMPTY_CLUSTER_IN_K_MEANS = groupe vide dans l''algorithme des k-moyennes
|
||||
EMPTY_POLYNOMIALS_COEFFICIENTS_ARRAY = tableau de coefficients polyn\u00f4miaux vide
|
||||
EMPTY_SELECTED_COLUMN_INDEX_ARRAY = tableau des indices de colonnes s\u00e9lectionn\u00e9es vide
|
||||
EMPTY_SELECTED_ROW_INDEX_ARRAY = tableau des indices de lignes s\u00e9lectionn\u00e9es vide
|
||||
|
|
|
@ -52,6 +52,10 @@ The <action> type attribute can be add,update,fix,remove.
|
|||
If the output is not quite correct, check for invisible trailing spaces!
|
||||
-->
|
||||
<release version="2.2" date="TBD" description="TBD">
|
||||
<action dev="luc" type="fix" issue="MATH-429">
|
||||
Fixed k-means++ to add several strategies to deal with empty clusters that may appear
|
||||
during iterations
|
||||
</action>
|
||||
<action dev="luc" type="update" issue="MATH-417">
|
||||
Improved Percentile performance by using a selection algorithm instead of a
|
||||
complete sort, and by allowing caching data array and pivots when several
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class KMeansPlusPlusClustererTest {
|
||||
|
@ -116,4 +117,53 @@ public class KMeansPlusPlusClustererTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCertainSpace() {
|
||||
KMeansPlusPlusClusterer.EmptyClusterStrategy[] strategies = {
|
||||
KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE,
|
||||
KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_POINTS_NUMBER,
|
||||
KMeansPlusPlusClusterer.EmptyClusterStrategy.FARTHEST_POINT
|
||||
};
|
||||
for (KMeansPlusPlusClusterer.EmptyClusterStrategy strategy : strategies) {
|
||||
KMeansPlusPlusClusterer<EuclideanIntegerPoint> transformer =
|
||||
new KMeansPlusPlusClusterer<EuclideanIntegerPoint>(new Random(1746432956321l), strategy);
|
||||
int numberOfVariables = 27;
|
||||
// initialise testvalues
|
||||
int position1 = 1;
|
||||
int position2 = position1 + numberOfVariables;
|
||||
int position3 = position2 + numberOfVariables;
|
||||
int position4 = position3 + numberOfVariables;
|
||||
// testvalues will be multiplied
|
||||
int multiplier = 1000000;
|
||||
|
||||
EuclideanIntegerPoint[] breakingPoints = new EuclideanIntegerPoint[numberOfVariables];
|
||||
// define the space which will break the cluster algorithm
|
||||
for (int i = 0; i < numberOfVariables; i++) {
|
||||
int points[] = { position1, position2, position3, position4 };
|
||||
// multiply the values
|
||||
for (int j = 0; j < points.length; j++) {
|
||||
points[j] = points[j] * multiplier;
|
||||
}
|
||||
EuclideanIntegerPoint euclideanIntegerPoint = new EuclideanIntegerPoint(points);
|
||||
breakingPoints[i] = euclideanIntegerPoint;
|
||||
position1 = position1 + numberOfVariables;
|
||||
position2 = position2 + numberOfVariables;
|
||||
position3 = position3 + numberOfVariables;
|
||||
position4 = position4 + numberOfVariables;
|
||||
}
|
||||
|
||||
for (int n = 2; n < 27; ++n) {
|
||||
List<Cluster<EuclideanIntegerPoint>> clusters =
|
||||
transformer.cluster(Arrays.asList(breakingPoints), n, 100);
|
||||
Assert.assertEquals(n, clusters.size());
|
||||
int sum = 0;
|
||||
for (Cluster<EuclideanIntegerPoint> cluster : clusters) {
|
||||
sum += cluster.getPoints().size();
|
||||
}
|
||||
Assert.assertEquals(numberOfVariables, sum);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue