diff --git a/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java index b73ac9d3e..e09bbc357 100644 --- a/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java +++ b/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java @@ -172,7 +172,7 @@ public class KMeansPlusPlusClusterer> { while (resultSet.size() < k) { // For each data point x, compute D(x), the distance between x and // the nearest center that has already been chosen. - int sum = 0; + double sum = 0; for (int i = 0; i < pointSet.size(); i++) { final T p = pointSet.get(i); final Cluster nearest = getNearestCluster(resultSet, p); diff --git a/src/site/xdoc/changes.xml b/src/site/xdoc/changes.xml index 154001edc..e4f9c3431 100644 --- a/src/site/xdoc/changes.xml +++ b/src/site/xdoc/changes.xml @@ -52,6 +52,9 @@ The type attribute can be add,update,fix,remove. If the output is not quite correct, check for invisible trailing spaces! --> + + Fixed bug in "KMeansPlusPlusClusterer". + All exceptions defined in Commons Math provide a context and a compound message list. diff --git a/src/test/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClustererTest.java b/src/test/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClustererTest.java index 33b2d8af8..0bb04203d 100644 --- a/src/test/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClustererTest.java +++ b/src/test/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClustererTest.java @@ -20,7 +20,9 @@ package org.apache.commons.math.stat.clustering; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.Random; @@ -166,4 +168,84 @@ public class KMeansPlusPlusClustererTest { } + /** + * A helper class for testSmallDistances(). This class is similar to EuclideanIntegerPoint, but + * it defines a different distanceFrom() method that tends to return distances less than 1. + */ + private class CloseIntegerPoint implements Clusterable { + public CloseIntegerPoint(EuclideanIntegerPoint point) { + euclideanPoint = point; + } + + public double distanceFrom(CloseIntegerPoint p) { + return euclideanPoint.distanceFrom(p.euclideanPoint) * 0.001; + } + + public CloseIntegerPoint centroidOf(Collection p) { + Collection euclideanPoints = + new ArrayList(); + for (CloseIntegerPoint point : p) { + euclideanPoints.add(point.euclideanPoint); + } + return new CloseIntegerPoint(euclideanPoint.centroidOf(euclideanPoints)); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof CloseIntegerPoint)) { + return false; + } + CloseIntegerPoint p = (CloseIntegerPoint) o; + + return euclideanPoint.equals(p.euclideanPoint); + } + + @Override + public int hashCode() { + return euclideanPoint.hashCode(); + } + + private EuclideanIntegerPoint euclideanPoint; + } + + /** + * Test points that are very close together. See issue MATH-546. + */ + @Test + public void testSmallDistances() { + // Create a bunch of CloseIntegerPoints. Most are identical, but one is different by a + // small distance. + int[] repeatedArray = { 0 }; + int[] uniqueArray = { 1 }; + CloseIntegerPoint repeatedPoint = + new CloseIntegerPoint(new EuclideanIntegerPoint(repeatedArray)); + CloseIntegerPoint uniquePoint = + new CloseIntegerPoint(new EuclideanIntegerPoint(uniqueArray)); + + Collection points = new ArrayList(); + final int NUM_REPEATED_POINTS = 10 * 1000; + for (int i = 0; i < NUM_REPEATED_POINTS; ++i) { + points.add(repeatedPoint); + } + points.add(uniquePoint); + + // Ask a KMeansPlusPlusClusterer to run zero iterations (i.e., to simply choose initial + // cluster centers). + final long RANDOM_SEED = 0; + final int NUM_CLUSTERS = 2; + final int NUM_ITERATIONS = 0; + KMeansPlusPlusClusterer clusterer = + new KMeansPlusPlusClusterer(new Random(RANDOM_SEED)); + List> clusters = + clusterer.cluster(points, NUM_CLUSTERS, NUM_ITERATIONS); + + // Check that one of the chosen centers is the unique point. + boolean uniquePointIsCenter = false; + for (Cluster cluster : clusters) { + if (cluster.getCenter().equals(uniquePoint)) { + uniquePointIsCenter = true; + } + } + assertTrue(uniquePointIsCenter); + } }