MATH-1640: Do not try to outguess the caller.
This commit is contained in:
parent
645d85a8c7
commit
c6b4ca908c
|
@ -19,6 +19,7 @@ package org.apache.commons.math4.legacy.ml.clustering;
|
|||
|
||||
import org.apache.commons.math4.legacy.exception.NullArgumentException;
|
||||
import org.apache.commons.math4.legacy.exception.ConvergenceException;
|
||||
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
|
||||
|
@ -79,7 +80,7 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
|
|||
* @param k the number of clusters to split the data into
|
||||
*/
|
||||
public KMeansPlusPlusClusterer(final int k) {
|
||||
this(k, -1);
|
||||
this(k, Integer.MAX_VALUE);
|
||||
}
|
||||
|
||||
/** Build a clusterer.
|
||||
|
@ -104,8 +105,8 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
|
|||
*
|
||||
* @param k the number of clusters to split the data into
|
||||
* @param maxIterations the maximum number of iterations to run the algorithm for.
|
||||
* If negative, no maximum will be used.
|
||||
* @param measure the distance measure to use
|
||||
* @throws NotStrictlyPositiveException if {@code k <= 0}.
|
||||
*/
|
||||
public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
|
||||
this(k, maxIterations, measure, RandomSource.MT_64.create());
|
||||
|
@ -132,20 +133,30 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
|
|||
*
|
||||
* @param k the number of clusters to split the data into
|
||||
* @param maxIterations the maximum number of iterations to run the algorithm for.
|
||||
* If negative, no maximum will be used.
|
||||
* @param measure the distance measure to use
|
||||
* @param random random generator to use for choosing initial centers
|
||||
* @param emptyStrategy strategy to use for handling empty clusters that
|
||||
* may appear during algorithm iterations
|
||||
* @throws NotStrictlyPositiveException if {@code k <= 0} or
|
||||
* {@code maxIterations <= 0}.
|
||||
*/
|
||||
public KMeansPlusPlusClusterer(final int k, final int maxIterations,
|
||||
public KMeansPlusPlusClusterer(final int k,
|
||||
final int maxIterations,
|
||||
final DistanceMeasure measure,
|
||||
final UniformRandomProvider random,
|
||||
final EmptyClusterStrategy emptyStrategy) {
|
||||
super(measure);
|
||||
|
||||
if (k <= 0) {
|
||||
throw new NotStrictlyPositiveException(k);
|
||||
}
|
||||
if (maxIterations <= 0) {
|
||||
throw new NotStrictlyPositiveException(maxIterations);
|
||||
}
|
||||
|
||||
this.numberOfClusters = k;
|
||||
this.maxIterations = maxIterations;
|
||||
this.random = random;
|
||||
this.random = random;
|
||||
this.emptyStrategy = emptyStrategy;
|
||||
}
|
||||
|
||||
|
@ -195,8 +206,7 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
|
|||
assignPointsToClusters(clusters, points, assignments);
|
||||
|
||||
// iterate through updating the centers until we're done
|
||||
final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
|
||||
for (int count = 0; count < max; count++) {
|
||||
for (int count = 0; count < maxIterations; count++) {
|
||||
boolean hasEmptyCluster = clusters.stream().anyMatch(cluster->cluster.getPoints().isEmpty());
|
||||
List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
|
||||
int changes = assignPointsToClusters(newClusters, points, assignments);
|
||||
|
|
|
@ -133,27 +133,21 @@ public class KMeansPlusPlusClustererTest {
|
|||
public void testSmallDistances() {
|
||||
// Create a bunch of CloseDoublePoints. Most are identical, but one is different by a
|
||||
// small distance.
|
||||
int[] repeatedArray = { 0 };
|
||||
int[] uniqueArray = { 1 };
|
||||
DoublePoint repeatedPoint = new DoublePoint(repeatedArray);
|
||||
DoublePoint uniquePoint = new DoublePoint(uniqueArray);
|
||||
final int[] repeatedArray = { 0 };
|
||||
final int[] uniqueArray = { 1 };
|
||||
final DoublePoint repeatedPoint = new DoublePoint(repeatedArray);
|
||||
final DoublePoint uniquePoint = new DoublePoint(uniqueArray);
|
||||
|
||||
Collection<DoublePoint> points = new ArrayList<>();
|
||||
final int NUM_REPEATED_POINTS = 10 * 1000;
|
||||
for (int i = 0; i < NUM_REPEATED_POINTS; ++i) {
|
||||
final Collection<DoublePoint> points = new ArrayList<>();
|
||||
final int numRepeated = 10000;
|
||||
for (int i = 0; i < numRepeated; i++) {
|
||||
points.add(repeatedPoint);
|
||||
}
|
||||
points.add(uniquePoint);
|
||||
|
||||
// Ask a KMeansPlusPlusClusterer to run zero iterations (i.e., to simply choose initial
|
||||
// cluster centers).
|
||||
final int NUM_CLUSTERS = 2;
|
||||
final int NUM_ITERATIONS = 0;
|
||||
|
||||
KMeansPlusPlusClusterer<DoublePoint> clusterer =
|
||||
new KMeansPlusPlusClusterer<>(NUM_CLUSTERS, NUM_ITERATIONS,
|
||||
new CloseDistance(), random);
|
||||
List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
|
||||
final KMeansPlusPlusClusterer<DoublePoint> clusterer =
|
||||
new KMeansPlusPlusClusterer<>(2, 1, new CloseDistance(), random);
|
||||
final List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
|
||||
|
||||
// Check that one of the chosen centers is the unique point.
|
||||
boolean uniquePointIsCenter = false;
|
||||
|
|
|
@ -58,9 +58,9 @@ public class MiniBatchKMeansClustererTest {
|
|||
final UniformRandomProvider rng = RandomSource.MT_64.create();
|
||||
List<DoublePoint> data = generateCircles(rng);
|
||||
KMeansPlusPlusClusterer<DoublePoint> kMeans =
|
||||
new KMeansPlusPlusClusterer<>(4, -1, DEFAULT_MEASURE, rng);
|
||||
new KMeansPlusPlusClusterer<>(4, Integer.MAX_VALUE, DEFAULT_MEASURE, rng);
|
||||
MiniBatchKMeansClusterer<DoublePoint> miniBatchKMeans =
|
||||
new MiniBatchKMeansClusterer<>(4, -1, 100, 3, 300, 10, DEFAULT_MEASURE, rng,
|
||||
new MiniBatchKMeansClusterer<>(4, Integer.MAX_VALUE, 100, 3, 300, 10, DEFAULT_MEASURE, rng,
|
||||
KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE);
|
||||
// Test 100 times between KMeansPlusPlusClusterer and MiniBatchKMeansClusterer
|
||||
for (int i = 0; i < 100; i++) {
|
||||
|
|
|
@ -63,7 +63,7 @@ public class CalinskiHarabaszTest {
|
|||
double actualBestScore = 0.0;
|
||||
for (int i = 0; i < 5; i++) {
|
||||
final int k = i + 2;
|
||||
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, -1, distanceMeasure, rnd);
|
||||
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, Integer.MAX_VALUE, distanceMeasure, rnd);
|
||||
List<CentroidCluster<DoublePoint>> clusters = kMeans.cluster(points);
|
||||
double score = evaluator.score(clusters);
|
||||
if (score > expectBestScore) {
|
||||
|
@ -89,7 +89,7 @@ public class CalinskiHarabaszTest {
|
|||
double actualBestScore = 0.0;
|
||||
for (int i = 0; i < 5; i++) {
|
||||
final int k = i + 2;
|
||||
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, -1, distanceMeasure, rnd);
|
||||
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, Integer.MAX_VALUE, distanceMeasure, rnd);
|
||||
List<CentroidCluster<DoublePoint>> clusters = kMeans.cluster(points);
|
||||
double score = evaluator.score(clusters);
|
||||
if (score > expectBestScore) {
|
||||
|
|
Loading…
Reference in New Issue