From f2eebe68d856ff681f5d7b0282a02bf6d396a9ea Mon Sep 17 00:00:00 2001 From: CT Date: Thu, 26 Mar 2020 17:04:10 +0800 Subject: [PATCH 1/3] MATH-1524: Remove package initialization and reuse chooseInitialCenters as package-private --- .../clustering/KMeansPlusPlusClusterer.java | 152 ++++++++++++-- .../clustering/MiniBatchKMeansClusterer.java | 2 +- .../initialization/CentroidInitializer.java | 40 ---- .../KMeansPlusPlusCentroidInitializer.java | 186 ------------------ .../RandomCentroidInitializer.java | 65 ------ .../initialization/package-info.java | 24 --- .../CentroidInitializerTest.java | 65 ------ 7 files changed, 133 insertions(+), 401 deletions(-) delete mode 100644 src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java delete mode 100644 src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java delete mode 100644 src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java delete mode 100644 src/main/java/org/apache/commons/math4/ml/clustering/initialization/package-info.java delete mode 100644 src/test/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializerTest.java diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java index 7bc72eeed..7036ccb0c 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java @@ -17,21 +17,20 @@ package org.apache.commons.math4.ml.clustering; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; - import org.apache.commons.math4.exception.ConvergenceException; import org.apache.commons.math4.exception.NumberIsTooSmallException; import org.apache.commons.math4.exception.util.LocalizedFormats; -import org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer; -import org.apache.commons.math4.ml.clustering.initialization.KMeansPlusPlusCentroidInitializer; import org.apache.commons.math4.ml.distance.DistanceMeasure; import org.apache.commons.math4.ml.distance.EuclideanDistance; import org.apache.commons.math4.stat.descriptive.moment.Variance; import org.apache.commons.math4.util.MathUtils; -import org.apache.commons.rng.simple.RandomSource; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.simple.RandomSource; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; /** * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm. @@ -70,9 +69,6 @@ public class KMeansPlusPlusClusterer extends Clusterer /** Selected strategy for empty clusters. */ private final EmptyClusterStrategy emptyStrategy; - /** Clusters centroids initializer. */ - private final CentroidInitializer centroidInitializer; - /** Build a clusterer. *

* The default strategy for handling empty clusters that may appear during @@ -151,8 +147,6 @@ public class KMeansPlusPlusClusterer extends Clusterer this.maxIterations = maxIterations; this.random = random; this.emptyStrategy = emptyStrategy; - // Use K-means++ to choose the initial centers. - this.centroidInitializer = new KMeansPlusPlusCentroidInitializer(measure, random); } /** @@ -193,7 +187,7 @@ public class KMeansPlusPlusClusterer extends Clusterer } // create the initial clusters - List> clusters = centroidInitializer.selectCentroids(points, k); + List> clusters = chooseInitialCenters(points); // create an array containing the latest assignment of a point to a cluster // no need to initialize the array, as it will be filled with the first assignment @@ -231,13 +225,6 @@ public class KMeansPlusPlusClusterer extends Clusterer return emptyStrategy; } - /** - * @return the CentroidInitializer - */ - CentroidInitializer getCentroidInitializer() { - return centroidInitializer; - } - /** * Adjust the clusters's centers with means of points * @param clusters the origin clusters @@ -296,6 +283,131 @@ public class KMeansPlusPlusClusterer extends Clusterer return assignedDifferently; } + /** + * Use K-means++ to choose the initial centers. + * + * @param points the points to choose the initial centers from + * @return the initial centers + */ + List> chooseInitialCenters(final Collection points) { + + // Convert to list for indexed access. Make it unmodifiable, since removal of items + // would screw up the logic of this method. + final List pointList = Collections.unmodifiableList(new ArrayList<> (points)); + + // The number of points in the list. + final int numPoints = pointList.size(); + + // Set the corresponding element in this array to indicate when + // elements of pointList are no longer available. + final boolean[] taken = new boolean[numPoints]; + + // The resulting list of initial centers. + final List> resultSet = new ArrayList<>(); + + // Choose one center uniformly at random from among the data points. + final int firstPointIndex = random.nextInt(numPoints); + + final T firstPoint = pointList.get(firstPointIndex); + + resultSet.add(new CentroidCluster(firstPoint)); + + // Must mark it as taken + taken[firstPointIndex] = true; + + // To keep track of the minimum distance squared of elements of + // pointList to elements of resultSet. + final double[] minDistSquared = new double[numPoints]; + + // Initialize the elements. Since the only point in resultSet is firstPoint, + // this is very easy. + for (int i = 0; i < numPoints; i++) { + if (i != firstPointIndex) { // That point isn't considered + double d = distance(firstPoint, pointList.get(i)); + minDistSquared[i] = d*d; + } + } + + while (resultSet.size() < k) { + + // Sum up the squared distances for the points in pointList not + // already taken. + double distSqSum = 0.0; + + for (int i = 0; i < numPoints; i++) { + if (!taken[i]) { + distSqSum += minDistSquared[i]; + } + } + + // Add one new data point as a center. Each point x is chosen with + // probability proportional to D(x)2 + final double r = random.nextDouble() * distSqSum; + + // The index of the next point to be added to the resultSet. + int nextPointIndex = -1; + + // Sum through the squared min distances again, stopping when + // sum >= r. + double sum = 0.0; + for (int i = 0; i < numPoints; i++) { + if (!taken[i]) { + sum += minDistSquared[i]; + if (sum >= r) { + nextPointIndex = i; + break; + } + } + } + + // If it's not set to >= 0, the point wasn't found in the previous + // for loop, probably because distances are extremely small. Just pick + // the last available point. + if (nextPointIndex == -1) { + for (int i = numPoints - 1; i >= 0; i--) { + if (!taken[i]) { + nextPointIndex = i; + break; + } + } + } + + // We found one. + if (nextPointIndex >= 0) { + + final T p = pointList.get(nextPointIndex); + + resultSet.add(new CentroidCluster (p)); + + // Mark it as taken. + taken[nextPointIndex] = true; + + if (resultSet.size() < k) { + // Now update elements of minDistSquared. We only have to compute + // the distance to the new center to do this. + for (int j = 0; j < numPoints; j++) { + // Only have to worry about the points still not taken. + if (!taken[j]) { + double d = distance(p, pointList.get(j)); + double d2 = d * d; + if (d2 < minDistSquared[j]) { + minDistSquared[j] = d2; + } + } + } + } + + } else { + // None found -- + // Break from the while loop to prevent + // an infinite loop. + break; + } + } + + return resultSet; + } + /** * Get a random point from the {@link Cluster} with the largest distance variance. * diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java b/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java index e131e0a4d..e54f25d98 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java @@ -195,7 +195,7 @@ public class MiniBatchKMeansClusterer final List initialPoints = (initBatchSize < points.size()) ? ListSampler.sample(getRandomGenerator(), points, initBatchSize) : new ArrayList<>(points); - final List> clusters = getCentroidInitializer().selectCentroids(initialPoints, getK()); + final List> clusters = chooseInitialCenters(initialPoints); final Pair>> pair = step(validPoints, clusters); final double squareDistance = pair.getFirst(); final List> newClusters = pair.getSecond(); diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java deleted file mode 100644 index 0378f5516..000000000 --- a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.commons.math4.ml.clustering.initialization; - -import org.apache.commons.math4.ml.clustering.CentroidCluster; -import org.apache.commons.math4.ml.clustering.Clusterable; - -import java.util.Collection; -import java.util.List; - -/** - * Interface abstract the algorithm for clusterer to choose the initial centers. - */ -public interface CentroidInitializer { - - /** - * Choose the initial centers. - * - * @param Type of points to cluster. - * @param points the points to choose the initial centers from - * @param k The number of clusters - * @return the initial centers - */ - List> selectCentroids(final Collection points, final int k); -} diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java deleted file mode 100644 index f0ab288c2..000000000 --- a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java +++ /dev/null @@ -1,186 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.commons.math4.ml.clustering.initialization; - -import org.apache.commons.math4.ml.clustering.CentroidCluster; -import org.apache.commons.math4.ml.clustering.Clusterable; -import org.apache.commons.math4.ml.distance.DistanceMeasure; -import org.apache.commons.rng.UniformRandomProvider; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; - -/** - * Use K-means++ to choose the initial centers. - * - * @see K-means++ (wikipedia) - */ -public class KMeansPlusPlusCentroidInitializer implements CentroidInitializer { - private final DistanceMeasure measure; - private final UniformRandomProvider random; - - /** - * Build a K-means++ CentroidInitializer - * @param measure the distance measure to use - * @param random the random to use. - */ - public KMeansPlusPlusCentroidInitializer(final DistanceMeasure measure, final UniformRandomProvider random) { - this.measure = measure; - this.random = random; - } - - /** - * Use K-means++ to choose the initial centers. - * - * @param points the points to choose the initial centers from - * @param k The number of clusters - * @return the initial centers - */ - @Override - public List> selectCentroids(final Collection points, final int k) { - // Convert to list for indexed access. Make it unmodifiable, since removal of items - // would screw up the logic of this method. - final List pointList = Collections.unmodifiableList(new ArrayList<>(points)); - - // The number of points in the list. - final int numPoints = pointList.size(); - - // Set the corresponding element in this array to indicate when - // elements of pointList are no longer available. - final boolean[] taken = new boolean[numPoints]; - - // The resulting list of initial centers. - final List> resultSet = new ArrayList<>(); - - // Choose one center uniformly at random from among the data points. - final int firstPointIndex = random.nextInt(numPoints); - - final T firstPoint = pointList.get(firstPointIndex); - - resultSet.add(new CentroidCluster<>(firstPoint)); - - // Must mark it as taken - taken[firstPointIndex] = true; - - // To keep track of the minimum distance squared of elements of - // pointList to elements of resultSet. - final double[] minDistSquared = new double[numPoints]; - - // Initialize the elements. Since the only point in resultSet is firstPoint, - // this is very easy. - for (int i = 0; i < numPoints; i++) { - if (i != firstPointIndex) { // That point isn't considered - double d = distance(firstPoint, pointList.get(i)); - minDistSquared[i] = d * d; - } - } - - while (resultSet.size() < k) { - - // Sum up the squared distances for the points in pointList not - // already taken. - double distSqSum = 0.0; - - for (int i = 0; i < numPoints; i++) { - if (!taken[i]) { - distSqSum += minDistSquared[i]; - } - } - - // Add one new data point as a center. Each point x is chosen with - // probability proportional to D(x)2 - final double r = random.nextDouble() * distSqSum; - - // The index of the next point to be added to the resultSet. - int nextPointIndex = -1; - - // Sum through the squared min distances again, stopping when - // sum >= r. - double sum = 0.0; - for (int i = 0; i < numPoints; i++) { - if (!taken[i]) { - sum += minDistSquared[i]; - if (sum >= r) { - nextPointIndex = i; - break; - } - } - } - - // If it's not set to >= 0, the point wasn't found in the previous - // for loop, probably because distances are extremely small. Just pick - // the last available point. - if (nextPointIndex == -1) { - for (int i = numPoints - 1; i >= 0; i--) { - if (!taken[i]) { - nextPointIndex = i; - break; - } - } - } - - // We found one. - if (nextPointIndex >= 0) { - - final T p = pointList.get(nextPointIndex); - - resultSet.add(new CentroidCluster<>(p)); - - // Mark it as taken. - taken[nextPointIndex] = true; - - if (resultSet.size() < k) { - // Now update elements of minDistSquared. We only have to compute - // the distance to the new center to do this. - for (int j = 0; j < numPoints; j++) { - // Only have to worry about the points still not taken. - if (!taken[j]) { - double d = distance(p, pointList.get(j)); - double d2 = d * d; - if (d2 < minDistSquared[j]) { - minDistSquared[j] = d2; - } - } - } - } - - } else { - // None found -- - // Break from the while loop to prevent - // an infinite loop. - break; - } - } - - return resultSet; - } - - /** - * Calculates the distance between two {@link Clusterable} instances - * with the configured {@link DistanceMeasure}. - * - * @param p1 the first clusterable - * @param p2 the second clusterable - * @return the distance between the two clusterables - */ - protected double distance(final Clusterable p1, final Clusterable p2) { - return measure.compute(p1.getPoint(), p2.getPoint()); - } -} diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java deleted file mode 100644 index f3f561d15..000000000 --- a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.commons.math4.ml.clustering.initialization; - -import org.apache.commons.math4.ml.clustering.CentroidCluster; -import org.apache.commons.math4.ml.clustering.Clusterable; -import org.apache.commons.rng.UniformRandomProvider; -import org.apache.commons.rng.sampling.ListSampler; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; - -/** - * Random choose the initial centers. - */ -public class RandomCentroidInitializer implements CentroidInitializer { - private final UniformRandomProvider random; - - /** - * Build a random RandomCentroidInitializer - * - * @param random the random to use. - */ - public RandomCentroidInitializer(final UniformRandomProvider random) { - this.random = random; - } - - /** - * Random choose the initial centers. - * - * @param points the points to choose the initial centers from - * @param k The number of clusters - * @return the initial centers - */ - @Override - public List> selectCentroids(final Collection points, final int k) { - if (k < 1) { - return Collections.emptyList(); - } - final ArrayList list = new ArrayList<>(points); - ListSampler.shuffle(random, list); - final List> result = new ArrayList<>(k); - for (int i = 0; i < k; i++) { - result.add(new CentroidCluster<>(list.get(i))); - } - return result; - } -} diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/package-info.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/package-info.java deleted file mode 100644 index 3c335ae8e..000000000 --- a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/package-info.java +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * Clusters initialization methods. - * - * Centroid cluster initializers should implement the - * {@link org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer CentroidInitializer} - * interface. - */ -package org.apache.commons.math4.ml.clustering.initialization; diff --git a/src/test/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializerTest.java b/src/test/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializerTest.java deleted file mode 100644 index 329dce71c..000000000 --- a/src/test/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializerTest.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.commons.math4.ml.clustering.initialization; - -import org.apache.commons.math4.ml.clustering.CentroidCluster; -import org.apache.commons.math4.ml.clustering.DoublePoint; -import org.apache.commons.math4.ml.distance.EuclideanDistance; -import org.apache.commons.rng.UniformRandomProvider; -import org.apache.commons.rng.simple.RandomSource; -import org.junit.Assert; -import org.junit.Test; - -import java.util.ArrayList; -import java.util.List; - -public class CentroidInitializerTest { - private void test_generate_appropriate_number_of_cluster( - final CentroidInitializer initializer) { - // Generate some data - final List points = new ArrayList<>(); - final UniformRandomProvider rnd = RandomSource.create(RandomSource.MT_64); - for (int i = 0; i < 500; i++) { - double[] p = new double[2]; - p[0] = rnd.nextDouble(); - p[1] = rnd.nextDouble(); - points.add(new DoublePoint(p)); - } - // We can only assert that the centroid initializer - // implementation generate appropriate number of cluster - for (int k = 1; k < 50; k++) { - final List> centroidClusters = - initializer.selectCentroids(points, k); - Assert.assertEquals(k, centroidClusters.size()); - } - } - - @Test - public void test_RandomCentroidInitializer() { - final CentroidInitializer initializer = - new RandomCentroidInitializer(RandomSource.create(RandomSource.MT_64)); - test_generate_appropriate_number_of_cluster(initializer); - } - - @Test - public void test_KMeanPlusPlusCentroidInitializer() { - final CentroidInitializer initializer = - new KMeansPlusPlusCentroidInitializer(new EuclideanDistance(), - RandomSource.create(RandomSource.MT_64)); - test_generate_appropriate_number_of_cluster(initializer); - } -} From 68de72cc0b4c3f349c0b8d6d03ceebfd49cdc00c Mon Sep 17 00:00:00 2001 From: Gilles Sadowski Date: Thu, 26 Mar 2020 17:55:55 +0100 Subject: [PATCH 2/3] Make "" and "" sections consistent (wrt RAT analysis). --- pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/pom.xml b/pom.xml index bb76392ab..08594e0bd 100644 --- a/pom.xml +++ b/pom.xml @@ -489,6 +489,7 @@ .git/** .checkstyle .ekstazi/** + **/target/** From b95a43fa9af718899d03bbc2c10587c069c707f0 Mon Sep 17 00:00:00 2001 From: Gilles Sadowski Date: Thu, 26 Mar 2020 18:18:47 +0100 Subject: [PATCH 3/3] More self-documenting code. --- .../ml/clustering/KMeansPlusPlusClusterer.java | 16 ++++++++-------- .../ml/clustering/MiniBatchKMeansClusterer.java | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java index 7036ccb0c..806b2487b 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java @@ -58,7 +58,7 @@ public class KMeansPlusPlusClusterer extends Clusterer } /** The number of clusters. */ - private final int k; + private final int numberOfClusters; /** The maximum number of iterations. */ private final int maxIterations; @@ -143,7 +143,7 @@ public class KMeansPlusPlusClusterer extends Clusterer final UniformRandomProvider random, final EmptyClusterStrategy emptyStrategy) { super(measure); - this.k = k; + this.numberOfClusters = k; this.maxIterations = maxIterations; this.random = random; this.emptyStrategy = emptyStrategy; @@ -153,8 +153,8 @@ public class KMeansPlusPlusClusterer extends Clusterer * Return the number of clusters this instance will use. * @return the number of clusters */ - public int getK() { - return k; + public int getNumberOfClusters() { + return numberOfClusters; } /** @@ -182,8 +182,8 @@ public class KMeansPlusPlusClusterer extends Clusterer MathUtils.checkNotNull(points); // number of clusters has to be smaller or equal the number of data points - if (points.size() < k) { - throw new NumberIsTooSmallException(points.size(), k, false); + if (points.size() < numberOfClusters) { + throw new NumberIsTooSmallException(points.size(), numberOfClusters, false); } // create the initial clusters @@ -328,7 +328,7 @@ public class KMeansPlusPlusClusterer extends Clusterer } } - while (resultSet.size() < k) { + while (resultSet.size() < numberOfClusters) { // Sum up the squared distances for the points in pointList not // already taken. @@ -382,7 +382,7 @@ public class KMeansPlusPlusClusterer extends Clusterer // Mark it as taken. taken[nextPointIndex] = true; - if (resultSet.size() < k) { + if (resultSet.size() < numberOfClusters) { // Now update elements of minDistSquared. We only have to compute // the distance to the new center to do this. for (int j = 0; j < numPoints; j++) { diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java b/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java index e54f25d98..a563fda3d 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java @@ -106,8 +106,8 @@ public class MiniBatchKMeansClusterer public List> cluster(final Collection points) { // Sanity check. MathUtils.checkNotNull(points); - if (points.size() < getK()) { - throw new NumberIsTooSmallException(points.size(), getK(), false); + if (points.size() < getNumberOfClusters()) { + throw new NumberIsTooSmallException(points.size(), getNumberOfClusters(), false); } final int pointSize = points.size();