MATH-1640: Do not try to outguess the caller.

This commit is contained in:
Gilles Sadowski 2022-01-22 18:53:17 +01:00
parent 645d85a8c7
commit c6b4ca908c
4 changed files with 31 additions and 27 deletions

View File

@ -19,6 +19,7 @@ package org.apache.commons.math4.legacy.ml.clustering;
import org.apache.commons.math4.legacy.exception.NullArgumentException; import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.legacy.exception.ConvergenceException; import org.apache.commons.math4.legacy.exception.ConvergenceException;
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException; import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats; import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure; import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
@ -79,7 +80,7 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
* @param k the number of clusters to split the data into * @param k the number of clusters to split the data into
*/ */
public KMeansPlusPlusClusterer(final int k) { public KMeansPlusPlusClusterer(final int k) {
this(k, -1); this(k, Integer.MAX_VALUE);
} }
/** Build a clusterer. /** Build a clusterer.
@ -104,8 +105,8 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
* *
* @param k the number of clusters to split the data into * @param k the number of clusters to split the data into
* @param maxIterations the maximum number of iterations to run the algorithm for. * @param maxIterations the maximum number of iterations to run the algorithm for.
* If negative, no maximum will be used.
* @param measure the distance measure to use * @param measure the distance measure to use
* @throws NotStrictlyPositiveException if {@code k <= 0}.
*/ */
public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) { public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
this(k, maxIterations, measure, RandomSource.MT_64.create()); this(k, maxIterations, measure, RandomSource.MT_64.create());
@ -132,17 +133,27 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
* *
* @param k the number of clusters to split the data into * @param k the number of clusters to split the data into
* @param maxIterations the maximum number of iterations to run the algorithm for. * @param maxIterations the maximum number of iterations to run the algorithm for.
* If negative, no maximum will be used.
* @param measure the distance measure to use * @param measure the distance measure to use
* @param random random generator to use for choosing initial centers * @param random random generator to use for choosing initial centers
* @param emptyStrategy strategy to use for handling empty clusters that * @param emptyStrategy strategy to use for handling empty clusters that
* may appear during algorithm iterations * may appear during algorithm iterations
* @throws NotStrictlyPositiveException if {@code k <= 0} or
* {@code maxIterations <= 0}.
*/ */
public KMeansPlusPlusClusterer(final int k, final int maxIterations, public KMeansPlusPlusClusterer(final int k,
final int maxIterations,
final DistanceMeasure measure, final DistanceMeasure measure,
final UniformRandomProvider random, final UniformRandomProvider random,
final EmptyClusterStrategy emptyStrategy) { final EmptyClusterStrategy emptyStrategy) {
super(measure); super(measure);
if (k <= 0) {
throw new NotStrictlyPositiveException(k);
}
if (maxIterations <= 0) {
throw new NotStrictlyPositiveException(maxIterations);
}
this.numberOfClusters = k; this.numberOfClusters = k;
this.maxIterations = maxIterations; this.maxIterations = maxIterations;
this.random = random; this.random = random;
@ -195,8 +206,7 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
assignPointsToClusters(clusters, points, assignments); assignPointsToClusters(clusters, points, assignments);
// iterate through updating the centers until we're done // iterate through updating the centers until we're done
final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations; for (int count = 0; count < maxIterations; count++) {
for (int count = 0; count < max; count++) {
boolean hasEmptyCluster = clusters.stream().anyMatch(cluster->cluster.getPoints().isEmpty()); boolean hasEmptyCluster = clusters.stream().anyMatch(cluster->cluster.getPoints().isEmpty());
List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters); List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
int changes = assignPointsToClusters(newClusters, points, assignments); int changes = assignPointsToClusters(newClusters, points, assignments);

View File

@ -133,27 +133,21 @@ public class KMeansPlusPlusClustererTest {
public void testSmallDistances() { public void testSmallDistances() {
// Create a bunch of CloseDoublePoints. Most are identical, but one is different by a // Create a bunch of CloseDoublePoints. Most are identical, but one is different by a
// small distance. // small distance.
int[] repeatedArray = { 0 }; final int[] repeatedArray = { 0 };
int[] uniqueArray = { 1 }; final int[] uniqueArray = { 1 };
DoublePoint repeatedPoint = new DoublePoint(repeatedArray); final DoublePoint repeatedPoint = new DoublePoint(repeatedArray);
DoublePoint uniquePoint = new DoublePoint(uniqueArray); final DoublePoint uniquePoint = new DoublePoint(uniqueArray);
Collection<DoublePoint> points = new ArrayList<>(); final Collection<DoublePoint> points = new ArrayList<>();
final int NUM_REPEATED_POINTS = 10 * 1000; final int numRepeated = 10000;
for (int i = 0; i < NUM_REPEATED_POINTS; ++i) { for (int i = 0; i < numRepeated; i++) {
points.add(repeatedPoint); points.add(repeatedPoint);
} }
points.add(uniquePoint); points.add(uniquePoint);
// Ask a KMeansPlusPlusClusterer to run zero iterations (i.e., to simply choose initial final KMeansPlusPlusClusterer<DoublePoint> clusterer =
// cluster centers). new KMeansPlusPlusClusterer<>(2, 1, new CloseDistance(), random);
final int NUM_CLUSTERS = 2; final List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
final int NUM_ITERATIONS = 0;
KMeansPlusPlusClusterer<DoublePoint> clusterer =
new KMeansPlusPlusClusterer<>(NUM_CLUSTERS, NUM_ITERATIONS,
new CloseDistance(), random);
List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
// Check that one of the chosen centers is the unique point. // Check that one of the chosen centers is the unique point.
boolean uniquePointIsCenter = false; boolean uniquePointIsCenter = false;

View File

@ -58,9 +58,9 @@ public class MiniBatchKMeansClustererTest {
final UniformRandomProvider rng = RandomSource.MT_64.create(); final UniformRandomProvider rng = RandomSource.MT_64.create();
List<DoublePoint> data = generateCircles(rng); List<DoublePoint> data = generateCircles(rng);
KMeansPlusPlusClusterer<DoublePoint> kMeans = KMeansPlusPlusClusterer<DoublePoint> kMeans =
new KMeansPlusPlusClusterer<>(4, -1, DEFAULT_MEASURE, rng); new KMeansPlusPlusClusterer<>(4, Integer.MAX_VALUE, DEFAULT_MEASURE, rng);
MiniBatchKMeansClusterer<DoublePoint> miniBatchKMeans = MiniBatchKMeansClusterer<DoublePoint> miniBatchKMeans =
new MiniBatchKMeansClusterer<>(4, -1, 100, 3, 300, 10, DEFAULT_MEASURE, rng, new MiniBatchKMeansClusterer<>(4, Integer.MAX_VALUE, 100, 3, 300, 10, DEFAULT_MEASURE, rng,
KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE); KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE);
// Test 100 times between KMeansPlusPlusClusterer and MiniBatchKMeansClusterer // Test 100 times between KMeansPlusPlusClusterer and MiniBatchKMeansClusterer
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {

View File

@ -63,7 +63,7 @@ public class CalinskiHarabaszTest {
double actualBestScore = 0.0; double actualBestScore = 0.0;
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
final int k = i + 2; final int k = i + 2;
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, -1, distanceMeasure, rnd); KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, Integer.MAX_VALUE, distanceMeasure, rnd);
List<CentroidCluster<DoublePoint>> clusters = kMeans.cluster(points); List<CentroidCluster<DoublePoint>> clusters = kMeans.cluster(points);
double score = evaluator.score(clusters); double score = evaluator.score(clusters);
if (score > expectBestScore) { if (score > expectBestScore) {
@ -89,7 +89,7 @@ public class CalinskiHarabaszTest {
double actualBestScore = 0.0; double actualBestScore = 0.0;
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
final int k = i + 2; final int k = i + 2;
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, -1, distanceMeasure, rnd); KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, Integer.MAX_VALUE, distanceMeasure, rnd);
List<CentroidCluster<DoublePoint>> clusters = kMeans.cluster(points); List<CentroidCluster<DoublePoint>> clusters = kMeans.cluster(points);
double score = evaluator.score(clusters); double score = evaluator.score(clusters);
if (score > expectBestScore) { if (score > expectBestScore) {