This commit is contained in:
Gilles 2014-11-11 23:04:21 +01:00
commit 9786ad4de8
5 changed files with 58 additions and 9 deletions

4
.gitignore vendored
View File

@ -8,3 +8,7 @@ target
/lib /lib
/site-content /site-content
*.class *.class
*.iml
*.ipr
*.iws
.idea

View File

@ -73,6 +73,10 @@ Users are encouraged to upgrade to this version as this release not
2. A few methods in the FastMath class are in fact slower than their 2. A few methods in the FastMath class are in fact slower than their
counterpart in either Math or StrictMath (cf. MATH-740 and MATH-901). counterpart in either Math or StrictMath (cf. MATH-740 and MATH-901).
"> ">
<action dev="tn" type="fix" issue="MATH-1165" due-to="Pashutan Modaresi">
"FuzzyKMeansClusterer" threw an exception when one of the data
points was equal to a cluster center.
</action>
<action dev="erans" type="add" issue="MATH-1144"> <action dev="erans" type="add" issue="MATH-1144">
Interface to allow parameter validation in "o.a.c.m.fitting.leastsquares": Interface to allow parameter validation in "o.a.c.m.fitting.leastsquares":
the point computed by the optimizer can be modified before evaluation. the point computed by the optimizer can be modified before evaluation.

View File

@ -51,6 +51,7 @@ public class LeastSquaresFactory {
* @param model the model function. Produces the computed values. * @param model the model function. Produces the computed values.
* @param observed the observed (target) values * @param observed the observed (target) values
* @param start the initial guess. * @param start the initial guess.
* @param weight the weight matrix
* @param checker convergence checker * @param checker convergence checker
* @param maxEvaluations the maximum number of times to evaluate the model * @param maxEvaluations the maximum number of times to evaluate the model
* @param maxIterations the maximum number of times to iterate in the algorithm * @param maxIterations the maximum number of times to iterate in the algorithm
@ -74,7 +75,7 @@ public class LeastSquaresFactory {
observed, observed,
start, start,
checker, checker,
maxEvaluations, maxEvaluations,
maxIterations, maxIterations,
lazyEvaluation, lazyEvaluation,
paramValidator); paramValidator);

View File

@ -346,18 +346,32 @@ public class FuzzyKMeansClusterer<T extends Clusterable> extends Clusterer<T> {
private void updateMembershipMatrix() { private void updateMembershipMatrix() {
for (int i = 0; i < points.size(); i++) { for (int i = 0; i < points.size(); i++) {
final T point = points.get(i); final T point = points.get(i);
double maxMembership = 0.0; double maxMembership = Double.MIN_VALUE;
int newCluster = -1; int newCluster = -1;
for (int j = 0; j < clusters.size(); j++) { for (int j = 0; j < clusters.size(); j++) {
double sum = 0.0; double sum = 0.0;
final double distA = FastMath.abs(distance(point, clusters.get(j).getCenter())); final double distA = FastMath.abs(distance(point, clusters.get(j).getCenter()));
for (final CentroidCluster<T> c : clusters) { if (distA != 0.0) {
final double distB = FastMath.abs(distance(point, c.getCenter())); for (final CentroidCluster<T> c : clusters) {
sum += FastMath.pow(distA / distB, 2.0 / (fuzziness - 1.0)); final double distB = FastMath.abs(distance(point, c.getCenter()));
if (distB == 0.0) {
sum = Double.POSITIVE_INFINITY;
break;
}
sum += FastMath.pow(distA / distB, 2.0 / (fuzziness - 1.0));
}
} }
membershipMatrix[i][j] = 1.0 / sum; double membership;
if (sum == 0.0) {
membership = 1.0;
} else if (sum == Double.POSITIVE_INFINITY) {
membership = 0.0;
} else {
membership = 1.0 / sum;
}
membershipMatrix[i][j] = membership;
if (membershipMatrix[i][j] > maxMembership) { if (membershipMatrix[i][j] > maxMembership) {
maxMembership = membershipMatrix[i][j]; maxMembership = membershipMatrix[i][j];

View File

@ -16,9 +16,6 @@
*/ */
package org.apache.commons.math3.ml.clustering; package org.apache.commons.math3.ml.clustering;
import org.hamcrest.CoreMatchers;
import org.junit.Assert;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -29,6 +26,8 @@ import org.apache.commons.math3.ml.distance.CanberraDistance;
import org.apache.commons.math3.ml.distance.DistanceMeasure; import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.random.JDKRandomGenerator; import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator; import org.apache.commons.math3.random.RandomGenerator;
import org.hamcrest.CoreMatchers;
import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
/** /**
@ -106,4 +105,31 @@ public class FuzzyKMeansClustererTest {
Assert.assertThat(clusterer.getRandomGenerator(), CoreMatchers.is(random)); Assert.assertThat(clusterer.getRandomGenerator(), CoreMatchers.is(random));
} }
/**
 * Clusters a data set holding a single point with k = 1 and fuzziness 2.0,
 * and checks that exactly one cluster is returned.
 */
@Test
public void testSingleCluster() {
    final FuzzyKMeansClusterer<DoublePoint> clusterer =
        new FuzzyKMeansClusterer<DoublePoint>(1, 2.0);

    // The only data point of the set.
    final List<DoublePoint> data = new ArrayList<DoublePoint>();
    data.add(new DoublePoint(new double[] { 1, 1 }));

    final List<CentroidCluster<DoublePoint>> result = clusterer.cluster(data);
    final int clusterCount = result.size();
    Assert.assertEquals(1, clusterCount);
}
/**
 * Clusters four 2D points, two of which are nearly identical, into k = 3
 * clusters with fuzziness 2.0, and checks that three clusters are returned.
 */
@Test
public void testClusterCenterEqualsPoints() {
    // Coordinates of the data points, in insertion order.
    final double[][] coordinates = {
        { 1, 1 },
        { 1.00001, 1.00001 },
        { 2, 2 },
        { 3, 3 }
    };

    final List<DoublePoint> data = new ArrayList<DoublePoint>();
    for (final double[] xy : coordinates) {
        data.add(new DoublePoint(xy));
    }

    final FuzzyKMeansClusterer<DoublePoint> clusterer =
        new FuzzyKMeansClusterer<DoublePoint>(3, 2.0);
    final List<CentroidCluster<DoublePoint>> result = clusterer.cluster(data);

    final int clusterCount = result.size();
    Assert.assertEquals(3, clusterCount);
}
} }