diff --git a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClusterer.java b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClusterer.java index 989019466..57ab66388 100644 --- a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClusterer.java +++ b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClusterer.java @@ -19,6 +19,7 @@ package org.apache.commons.math4.legacy.ml.clustering; import org.apache.commons.math4.legacy.exception.NullArgumentException; import org.apache.commons.math4.legacy.exception.ConvergenceException; +import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException; import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException; import org.apache.commons.math4.legacy.exception.util.LocalizedFormats; import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure; @@ -79,7 +80,7 @@ public class KMeansPlusPlusClusterer extends Clusterer * @param k the number of clusters to split the data into */ public KMeansPlusPlusClusterer(final int k) { - this(k, -1); + this(k, Integer.MAX_VALUE); } /** Build a clusterer. @@ -104,8 +105,8 @@ public class KMeansPlusPlusClusterer extends Clusterer * * @param k the number of clusters to split the data into * @param maxIterations the maximum number of iterations to run the algorithm for. - * If negative, no maximum will be used. * @param measure the distance measure to use + * @throws NotStrictlyPositiveException if {@code k <= 0} or {@code maxIterations <= 0}. 
*/ public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) { this(k, maxIterations, measure, RandomSource.MT_64.create()); @@ -132,20 +133,30 @@ public class KMeansPlusPlusClusterer extends Clusterer * * @param k the number of clusters to split the data into * @param maxIterations the maximum number of iterations to run the algorithm for. - * If negative, no maximum will be used. * @param measure the distance measure to use * @param random random generator to use for choosing initial centers * @param emptyStrategy strategy to use for handling empty clusters that * may appear during algorithm iterations + * @throws NotStrictlyPositiveException if {@code k <= 0} or + * {@code maxIterations <= 0}. */ - public KMeansPlusPlusClusterer(final int k, final int maxIterations, + public KMeansPlusPlusClusterer(final int k, + final int maxIterations, final DistanceMeasure measure, final UniformRandomProvider random, final EmptyClusterStrategy emptyStrategy) { super(measure); + + if (k <= 0) { + throw new NotStrictlyPositiveException(k); + } + if (maxIterations <= 0) { + throw new NotStrictlyPositiveException(maxIterations); + } + this.numberOfClusters = k; this.maxIterations = maxIterations; - this.random = random; + this.random = random; this.emptyStrategy = emptyStrategy; } @@ -195,8 +206,7 @@ public class KMeansPlusPlusClusterer extends Clusterer assignPointsToClusters(clusters, points, assignments); // iterate through updating the centers until we're done - final int max = (maxIterations < 0) ? 
Integer.MAX_VALUE : maxIterations; - for (int count = 0; count < max; count++) { + for (int count = 0; count < maxIterations; count++) { boolean hasEmptyCluster = clusters.stream().anyMatch(cluster->cluster.getPoints().isEmpty()); List> newClusters = adjustClustersCenters(clusters); int changes = assignPointsToClusters(newClusters, points, assignments); diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClustererTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClustererTest.java index a9e497901..a7f63b77f 100644 --- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClustererTest.java +++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClustererTest.java @@ -133,27 +133,21 @@ public class KMeansPlusPlusClustererTest { public void testSmallDistances() { // Create a bunch of CloseDoublePoints. Most are identical, but one is different by a // small distance. - int[] repeatedArray = { 0 }; - int[] uniqueArray = { 1 }; - DoublePoint repeatedPoint = new DoublePoint(repeatedArray); - DoublePoint uniquePoint = new DoublePoint(uniqueArray); + final int[] repeatedArray = { 0 }; + final int[] uniqueArray = { 1 }; + final DoublePoint repeatedPoint = new DoublePoint(repeatedArray); + final DoublePoint uniquePoint = new DoublePoint(uniqueArray); - Collection points = new ArrayList<>(); - final int NUM_REPEATED_POINTS = 10 * 1000; - for (int i = 0; i < NUM_REPEATED_POINTS; ++i) { + final Collection points = new ArrayList<>(); + final int numRepeated = 10000; + for (int i = 0; i < numRepeated; i++) { points.add(repeatedPoint); } points.add(uniquePoint); - // Ask a KMeansPlusPlusClusterer to run zero iterations (i.e., to simply choose initial - // cluster centers). 
- final int NUM_CLUSTERS = 2; - final int NUM_ITERATIONS = 0; - - KMeansPlusPlusClusterer clusterer = - new KMeansPlusPlusClusterer<>(NUM_CLUSTERS, NUM_ITERATIONS, - new CloseDistance(), random); - List> clusters = clusterer.cluster(points); + final KMeansPlusPlusClusterer clusterer = + new KMeansPlusPlusClusterer<>(2, 1, new CloseDistance(), random); + final List> clusters = clusterer.cluster(points); // Check that one of the chosen centers is the unique point. boolean uniquePointIsCenter = false; diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/MiniBatchKMeansClustererTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/MiniBatchKMeansClustererTest.java index ca6c7d1a6..8b9c222e8 100644 --- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/MiniBatchKMeansClustererTest.java +++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/MiniBatchKMeansClustererTest.java @@ -58,9 +58,9 @@ public class MiniBatchKMeansClustererTest { final UniformRandomProvider rng = RandomSource.MT_64.create(); List data = generateCircles(rng); KMeansPlusPlusClusterer kMeans = - new KMeansPlusPlusClusterer<>(4, -1, DEFAULT_MEASURE, rng); + new KMeansPlusPlusClusterer<>(4, Integer.MAX_VALUE, DEFAULT_MEASURE, rng); MiniBatchKMeansClusterer miniBatchKMeans = - new MiniBatchKMeansClusterer<>(4, -1, 100, 3, 300, 10, DEFAULT_MEASURE, rng, + new MiniBatchKMeansClusterer<>(4, Integer.MAX_VALUE, 100, 3, 300, 10, DEFAULT_MEASURE, rng, KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE); // Test 100 times between KMeansPlusPlusClusterer and MiniBatchKMeansClusterer for (int i = 0; i < 100; i++) { diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/evaluation/CalinskiHarabaszTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/evaluation/CalinskiHarabaszTest.java index 
03f5295b6..f92c18db3 100644 --- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/evaluation/CalinskiHarabaszTest.java +++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/evaluation/CalinskiHarabaszTest.java @@ -63,7 +63,7 @@ public class CalinskiHarabaszTest { double actualBestScore = 0.0; for (int i = 0; i < 5; i++) { final int k = i + 2; - KMeansPlusPlusClusterer kMeans = new KMeansPlusPlusClusterer<>(k, -1, distanceMeasure, rnd); + KMeansPlusPlusClusterer kMeans = new KMeansPlusPlusClusterer<>(k, Integer.MAX_VALUE, distanceMeasure, rnd); List> clusters = kMeans.cluster(points); double score = evaluator.score(clusters); if (score > expectBestScore) { @@ -89,7 +89,7 @@ public class CalinskiHarabaszTest { double actualBestScore = 0.0; for (int i = 0; i < 5; i++) { final int k = i + 2; - KMeansPlusPlusClusterer kMeans = new KMeansPlusPlusClusterer<>(k, -1, distanceMeasure, rnd); + KMeansPlusPlusClusterer kMeans = new KMeansPlusPlusClusterer<>(k, Integer.MAX_VALUE, distanceMeasure, rnd); List> clusters = kMeans.cluster(points); double score = evaluator.score(clusters); if (score > expectBestScore) {