diff --git a/src/changes/changes.xml b/src/changes/changes.xml index 5fa4572de..df71204e5 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -73,6 +73,10 @@ Users are encouraged to upgrade to this version as this release not 2. A few methods in the FastMath class are in fact slower that their counterpart in either Math or StrictMath (cf. MATH-740 and MATH-901). "> + + "FuzzyKMeansClusterer" has thrown an exception in case one of the data + points was equal to a cluster center. + Interface to allow parameter validation in "o.a.c.m.fitting.leastsquares": the point computed by by the optimizer can be modified before evaluation. diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java index ed2204789..5f89934a7 100644 --- a/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java +++ b/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java @@ -346,18 +346,32 @@ public class FuzzyKMeansClusterer extends Clusterer { private void updateMembershipMatrix() { for (int i = 0; i < points.size(); i++) { final T point = points.get(i); - double maxMembership = 0.0; + double maxMembership = Double.MIN_VALUE; int newCluster = -1; for (int j = 0; j < clusters.size(); j++) { double sum = 0.0; final double distA = FastMath.abs(distance(point, clusters.get(j).getCenter())); - for (final CentroidCluster c : clusters) { - final double distB = FastMath.abs(distance(point, c.getCenter())); - sum += FastMath.pow(distA / distB, 2.0 / (fuzziness - 1.0)); + if (distA != 0.0) { + for (final CentroidCluster c : clusters) { + final double distB = FastMath.abs(distance(point, c.getCenter())); + if (distB == 0.0) { + sum = Double.POSITIVE_INFINITY; + break; + } + sum += FastMath.pow(distA / distB, 2.0 / (fuzziness - 1.0)); + } } - membershipMatrix[i][j] = 1.0 / sum; + double membership; + if (sum == 0.0) { + membership = 1.0; + } else if (sum == Double.POSITIVE_INFINITY) { + membership = 0.0; + } else { + membership = 1.0 / sum; + } + membershipMatrix[i][j] = membership; if (membershipMatrix[i][j] > maxMembership) { maxMembership = membershipMatrix[i][j]; diff --git a/src/test/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClustererTest.java b/src/test/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClustererTest.java index a4a6b8c1f..885b9c25d 100644 --- a/src/test/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClustererTest.java +++ b/src/test/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClustererTest.java @@ -16,9 +16,6 @@ */ package org.apache.commons.math3.ml.clustering; -import org.hamcrest.CoreMatchers; -import org.junit.Assert; - import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -29,6 +26,8 @@ import org.apache.commons.math3.ml.distance.CanberraDistance; import org.apache.commons.math3.ml.distance.DistanceMeasure; import org.apache.commons.math3.random.JDKRandomGenerator; import org.apache.commons.math3.random.RandomGenerator; +import org.hamcrest.CoreMatchers; +import org.junit.Assert; import org.junit.Test; /** @@ -106,4 +105,31 @@ public class FuzzyKMeansClustererTest { Assert.assertThat(clusterer.getRandomGenerator(), CoreMatchers.is(random)); } + @Test + public void testSingleCluster() { + final List points = new ArrayList(); + points.add(new DoublePoint(new double[] { 1, 1 })); + + final FuzzyKMeansClusterer transformer = + new FuzzyKMeansClusterer(1, 2.0); + final List> clusters = transformer.cluster(points); + + Assert.assertEquals(1, clusters.size()); + } + + @Test + public void testClusterCenterEqualsPoints() { + final List points = new ArrayList(); + points.add(new DoublePoint(new double[] { 1, 1 })); + points.add(new DoublePoint(new double[] { 1.00001, 1.00001 })); + points.add(new DoublePoint(new double[] { 2, 2 })); + points.add(new DoublePoint(new double[] { 3, 3 })); + + final FuzzyKMeansClusterer transformer = + new FuzzyKMeansClusterer(3, 2.0); + final List> clusters = transformer.cluster(points); + + Assert.assertEquals(3, clusters.size()); + } + }