MATH-1640: Do not try to outguess the caller.

This commit is contained in:
Gilles Sadowski 2022-01-22 18:53:17 +01:00
parent 645d85a8c7
commit c6b4ca908c
4 changed files with 31 additions and 27 deletions

View File

@ -19,6 +19,7 @@ package org.apache.commons.math4.legacy.ml.clustering;
import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.legacy.exception.ConvergenceException;
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
@ -79,7 +80,7 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
* @param k the number of clusters to split the data into
*/
public KMeansPlusPlusClusterer(final int k) {
this(k, -1);
this(k, Integer.MAX_VALUE);
}
/** Build a clusterer.
@ -104,8 +105,8 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
*
* @param k the number of clusters to split the data into
* @param maxIterations the maximum number of iterations to run the algorithm for.
* If negative, no maximum will be used.
* @param measure the distance measure to use
* @throws NotStrictlyPositiveException if {@code k <= 0}.
*/
public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
this(k, maxIterations, measure, RandomSource.MT_64.create());
@ -132,20 +133,30 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
*
* @param k the number of clusters to split the data into
* @param maxIterations the maximum number of iterations to run the algorithm for.
* If negative, no maximum will be used.
* @param measure the distance measure to use
* @param random random generator to use for choosing initial centers
* @param emptyStrategy strategy to use for handling empty clusters that
* may appear during algorithm iterations
* @throws NotStrictlyPositiveException if {@code k <= 0} or
* {@code maxIterations <= 0}.
*/
public KMeansPlusPlusClusterer(final int k, final int maxIterations,
public KMeansPlusPlusClusterer(final int k,
final int maxIterations,
final DistanceMeasure measure,
final UniformRandomProvider random,
final EmptyClusterStrategy emptyStrategy) {
super(measure);
if (k <= 0) {
throw new NotStrictlyPositiveException(k);
}
if (maxIterations <= 0) {
throw new NotStrictlyPositiveException(maxIterations);
}
this.numberOfClusters = k;
this.maxIterations = maxIterations;
this.random = random;
this.random = random;
this.emptyStrategy = emptyStrategy;
}
@ -195,8 +206,7 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
assignPointsToClusters(clusters, points, assignments);
// iterate through updating the centers until we're done
final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
for (int count = 0; count < max; count++) {
for (int count = 0; count < maxIterations; count++) {
boolean hasEmptyCluster = clusters.stream().anyMatch(cluster->cluster.getPoints().isEmpty());
List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
int changes = assignPointsToClusters(newClusters, points, assignments);

View File

@ -133,27 +133,21 @@ public class KMeansPlusPlusClustererTest {
public void testSmallDistances() {
// Create a bunch of CloseDoublePoints. Most are identical, but one is different by a
// small distance.
int[] repeatedArray = { 0 };
int[] uniqueArray = { 1 };
DoublePoint repeatedPoint = new DoublePoint(repeatedArray);
DoublePoint uniquePoint = new DoublePoint(uniqueArray);
final int[] repeatedArray = { 0 };
final int[] uniqueArray = { 1 };
final DoublePoint repeatedPoint = new DoublePoint(repeatedArray);
final DoublePoint uniquePoint = new DoublePoint(uniqueArray);
Collection<DoublePoint> points = new ArrayList<>();
final int NUM_REPEATED_POINTS = 10 * 1000;
for (int i = 0; i < NUM_REPEATED_POINTS; ++i) {
final Collection<DoublePoint> points = new ArrayList<>();
final int numRepeated = 10000;
for (int i = 0; i < numRepeated; i++) {
points.add(repeatedPoint);
}
points.add(uniquePoint);
// Ask a KMeansPlusPlusClusterer to run zero iterations (i.e., to simply choose initial
// cluster centers).
final int NUM_CLUSTERS = 2;
final int NUM_ITERATIONS = 0;
KMeansPlusPlusClusterer<DoublePoint> clusterer =
new KMeansPlusPlusClusterer<>(NUM_CLUSTERS, NUM_ITERATIONS,
new CloseDistance(), random);
List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
final KMeansPlusPlusClusterer<DoublePoint> clusterer =
new KMeansPlusPlusClusterer<>(2, 1, new CloseDistance(), random);
final List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
// Check that one of the chosen centers is the unique point.
boolean uniquePointIsCenter = false;

View File

@ -58,9 +58,9 @@ public class MiniBatchKMeansClustererTest {
final UniformRandomProvider rng = RandomSource.MT_64.create();
List<DoublePoint> data = generateCircles(rng);
KMeansPlusPlusClusterer<DoublePoint> kMeans =
new KMeansPlusPlusClusterer<>(4, -1, DEFAULT_MEASURE, rng);
new KMeansPlusPlusClusterer<>(4, Integer.MAX_VALUE, DEFAULT_MEASURE, rng);
MiniBatchKMeansClusterer<DoublePoint> miniBatchKMeans =
new MiniBatchKMeansClusterer<>(4, -1, 100, 3, 300, 10, DEFAULT_MEASURE, rng,
new MiniBatchKMeansClusterer<>(4, Integer.MAX_VALUE, 100, 3, 300, 10, DEFAULT_MEASURE, rng,
KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE);
// Test 100 times between KMeansPlusPlusClusterer and MiniBatchKMeansClusterer
for (int i = 0; i < 100; i++) {

View File

@ -63,7 +63,7 @@ public class CalinskiHarabaszTest {
double actualBestScore = 0.0;
for (int i = 0; i < 5; i++) {
final int k = i + 2;
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, -1, distanceMeasure, rnd);
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, Integer.MAX_VALUE, distanceMeasure, rnd);
List<CentroidCluster<DoublePoint>> clusters = kMeans.cluster(points);
double score = evaluator.score(clusters);
if (score > expectBestScore) {
@ -89,7 +89,7 @@ public class CalinskiHarabaszTest {
double actualBestScore = 0.0;
for (int i = 0; i < 5; i++) {
final int k = i + 2;
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, -1, distanceMeasure, rnd);
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, Integer.MAX_VALUE, distanceMeasure, rnd);
List<CentroidCluster<DoublePoint>> clusters = kMeans.cluster(points);
double score = evaluator.score(clusters);
if (score > expectBestScore) {