diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index 5fa4572de..df71204e5 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -73,6 +73,10 @@ Users are encouraged to upgrade to this version as this release not
2. A few methods in the FastMath class are in fact slower that their
counterpart in either Math or StrictMath (cf. MATH-740 and MATH-901).
">
+
+ "FuzzyKMeansClusterer" has thrown an exception in case one of the data
+ points was equal to a cluster center.
+
Interface to allow parameter validation in "o.a.c.m.fitting.leastsquares":
the point computed by by the optimizer can be modified before evaluation.
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java
index ed2204789..5f89934a7 100644
--- a/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java
@@ -346,18 +346,32 @@ public class FuzzyKMeansClusterer extends Clusterer {
private void updateMembershipMatrix() {
for (int i = 0; i < points.size(); i++) {
final T point = points.get(i);
- double maxMembership = 0.0;
+ double maxMembership = Double.MIN_VALUE;
int newCluster = -1;
for (int j = 0; j < clusters.size(); j++) {
double sum = 0.0;
final double distA = FastMath.abs(distance(point, clusters.get(j).getCenter()));
- for (final CentroidCluster c : clusters) {
- final double distB = FastMath.abs(distance(point, c.getCenter()));
- sum += FastMath.pow(distA / distB, 2.0 / (fuzziness - 1.0));
+ if (distA != 0.0) {
+ for (final CentroidCluster c : clusters) {
+ final double distB = FastMath.abs(distance(point, c.getCenter()));
+ if (distB == 0.0) {
+ sum = Double.POSITIVE_INFINITY;
+ break;
+ }
+ sum += FastMath.pow(distA / distB, 2.0 / (fuzziness - 1.0));
+ }
}
- membershipMatrix[i][j] = 1.0 / sum;
+ double membership;
+ if (sum == 0.0) {
+ membership = 1.0;
+ } else if (sum == Double.POSITIVE_INFINITY) {
+ membership = 0.0;
+ } else {
+ membership = 1.0 / sum;
+ }
+ membershipMatrix[i][j] = membership;
if (membershipMatrix[i][j] > maxMembership) {
maxMembership = membershipMatrix[i][j];
diff --git a/src/test/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClustererTest.java b/src/test/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClustererTest.java
index a4a6b8c1f..885b9c25d 100644
--- a/src/test/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClustererTest.java
+++ b/src/test/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClustererTest.java
@@ -16,9 +16,6 @@
*/
package org.apache.commons.math3.ml.clustering;
-import org.hamcrest.CoreMatchers;
-import org.junit.Assert;
-
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@@ -29,6 +26,8 @@ import org.apache.commons.math3.ml.distance.CanberraDistance;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
+import org.hamcrest.CoreMatchers;
+import org.junit.Assert;
import org.junit.Test;
/**
@@ -106,4 +105,31 @@ public class FuzzyKMeansClustererTest {
Assert.assertThat(clusterer.getRandomGenerator(), CoreMatchers.is(random));
}
+ @Test
+ public void testSingleCluster() {
+ final List points = new ArrayList();
+ points.add(new DoublePoint(new double[] { 1, 1 }));
+
+ final FuzzyKMeansClusterer transformer =
+ new FuzzyKMeansClusterer(1, 2.0);
+ final List> clusters = transformer.cluster(points);
+
+ Assert.assertEquals(1, clusters.size());
+ }
+
+ @Test
+ public void testClusterCenterEqualsPoints() {
+ final List points = new ArrayList();
+ points.add(new DoublePoint(new double[] { 1, 1 }));
+ points.add(new DoublePoint(new double[] { 1.00001, 1.00001 }));
+ points.add(new DoublePoint(new double[] { 2, 2 }));
+ points.add(new DoublePoint(new double[] { 3, 3 }));
+
+ final FuzzyKMeansClusterer transformer =
+ new FuzzyKMeansClusterer(3, 2.0);
+ final List> clusters = transformer.cluster(points);
+
+ Assert.assertEquals(3, clusters.size());
+ }
+
}