MATH-1640: Do not try to outguess the caller.
This commit is contained in:
parent
645d85a8c7
commit
c6b4ca908c
|
@ -19,6 +19,7 @@ package org.apache.commons.math4.legacy.ml.clustering;
|
||||||
|
|
||||||
import org.apache.commons.math4.legacy.exception.NullArgumentException;
|
import org.apache.commons.math4.legacy.exception.NullArgumentException;
|
||||||
import org.apache.commons.math4.legacy.exception.ConvergenceException;
|
import org.apache.commons.math4.legacy.exception.ConvergenceException;
|
||||||
|
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
|
||||||
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
|
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
|
||||||
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
|
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
|
||||||
import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
|
import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
|
||||||
|
@ -79,7 +80,7 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
|
||||||
* @param k the number of clusters to split the data into
|
* @param k the number of clusters to split the data into
|
||||||
*/
|
*/
|
||||||
public KMeansPlusPlusClusterer(final int k) {
|
public KMeansPlusPlusClusterer(final int k) {
|
||||||
this(k, -1);
|
this(k, Integer.MAX_VALUE);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Build a clusterer.
|
/** Build a clusterer.
|
||||||
|
@ -104,8 +105,8 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
|
||||||
*
|
*
|
||||||
* @param k the number of clusters to split the data into
|
* @param k the number of clusters to split the data into
|
||||||
* @param maxIterations the maximum number of iterations to run the algorithm for.
|
* @param maxIterations the maximum number of iterations to run the algorithm for.
|
||||||
* If negative, no maximum will be used.
|
|
||||||
* @param measure the distance measure to use
|
* @param measure the distance measure to use
|
||||||
|
* @throws NotStrictlyPositiveException if {@code k <= 0}.
|
||||||
*/
|
*/
|
||||||
public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
|
public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
|
||||||
this(k, maxIterations, measure, RandomSource.MT_64.create());
|
this(k, maxIterations, measure, RandomSource.MT_64.create());
|
||||||
|
@ -132,17 +133,27 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
|
||||||
*
|
*
|
||||||
* @param k the number of clusters to split the data into
|
* @param k the number of clusters to split the data into
|
||||||
* @param maxIterations the maximum number of iterations to run the algorithm for.
|
* @param maxIterations the maximum number of iterations to run the algorithm for.
|
||||||
* If negative, no maximum will be used.
|
|
||||||
* @param measure the distance measure to use
|
* @param measure the distance measure to use
|
||||||
* @param random random generator to use for choosing initial centers
|
* @param random random generator to use for choosing initial centers
|
||||||
* @param emptyStrategy strategy to use for handling empty clusters that
|
* @param emptyStrategy strategy to use for handling empty clusters that
|
||||||
* may appear during algorithm iterations
|
* may appear during algorithm iterations
|
||||||
|
* @throws NotStrictlyPositiveException if {@code k <= 0} or
|
||||||
|
* {@code maxIterations <= 0}.
|
||||||
*/
|
*/
|
||||||
public KMeansPlusPlusClusterer(final int k, final int maxIterations,
|
public KMeansPlusPlusClusterer(final int k,
|
||||||
|
final int maxIterations,
|
||||||
final DistanceMeasure measure,
|
final DistanceMeasure measure,
|
||||||
final UniformRandomProvider random,
|
final UniformRandomProvider random,
|
||||||
final EmptyClusterStrategy emptyStrategy) {
|
final EmptyClusterStrategy emptyStrategy) {
|
||||||
super(measure);
|
super(measure);
|
||||||
|
|
||||||
|
if (k <= 0) {
|
||||||
|
throw new NotStrictlyPositiveException(k);
|
||||||
|
}
|
||||||
|
if (maxIterations <= 0) {
|
||||||
|
throw new NotStrictlyPositiveException(maxIterations);
|
||||||
|
}
|
||||||
|
|
||||||
this.numberOfClusters = k;
|
this.numberOfClusters = k;
|
||||||
this.maxIterations = maxIterations;
|
this.maxIterations = maxIterations;
|
||||||
this.random = random;
|
this.random = random;
|
||||||
|
@ -195,8 +206,7 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
|
||||||
assignPointsToClusters(clusters, points, assignments);
|
assignPointsToClusters(clusters, points, assignments);
|
||||||
|
|
||||||
// iterate through updating the centers until we're done
|
// iterate through updating the centers until we're done
|
||||||
final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
|
for (int count = 0; count < maxIterations; count++) {
|
||||||
for (int count = 0; count < max; count++) {
|
|
||||||
boolean hasEmptyCluster = clusters.stream().anyMatch(cluster->cluster.getPoints().isEmpty());
|
boolean hasEmptyCluster = clusters.stream().anyMatch(cluster->cluster.getPoints().isEmpty());
|
||||||
List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
|
List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
|
||||||
int changes = assignPointsToClusters(newClusters, points, assignments);
|
int changes = assignPointsToClusters(newClusters, points, assignments);
|
||||||
|
|
|
@ -133,27 +133,21 @@ public class KMeansPlusPlusClustererTest {
|
||||||
public void testSmallDistances() {
|
public void testSmallDistances() {
|
||||||
// Create a bunch of CloseDoublePoints. Most are identical, but one is different by a
|
// Create a bunch of CloseDoublePoints. Most are identical, but one is different by a
|
||||||
// small distance.
|
// small distance.
|
||||||
int[] repeatedArray = { 0 };
|
final int[] repeatedArray = { 0 };
|
||||||
int[] uniqueArray = { 1 };
|
final int[] uniqueArray = { 1 };
|
||||||
DoublePoint repeatedPoint = new DoublePoint(repeatedArray);
|
final DoublePoint repeatedPoint = new DoublePoint(repeatedArray);
|
||||||
DoublePoint uniquePoint = new DoublePoint(uniqueArray);
|
final DoublePoint uniquePoint = new DoublePoint(uniqueArray);
|
||||||
|
|
||||||
Collection<DoublePoint> points = new ArrayList<>();
|
final Collection<DoublePoint> points = new ArrayList<>();
|
||||||
final int NUM_REPEATED_POINTS = 10 * 1000;
|
final int numRepeated = 10000;
|
||||||
for (int i = 0; i < NUM_REPEATED_POINTS; ++i) {
|
for (int i = 0; i < numRepeated; i++) {
|
||||||
points.add(repeatedPoint);
|
points.add(repeatedPoint);
|
||||||
}
|
}
|
||||||
points.add(uniquePoint);
|
points.add(uniquePoint);
|
||||||
|
|
||||||
// Ask a KMeansPlusPlusClusterer to run zero iterations (i.e., to simply choose initial
|
final KMeansPlusPlusClusterer<DoublePoint> clusterer =
|
||||||
// cluster centers).
|
new KMeansPlusPlusClusterer<>(2, 1, new CloseDistance(), random);
|
||||||
final int NUM_CLUSTERS = 2;
|
final List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
|
||||||
final int NUM_ITERATIONS = 0;
|
|
||||||
|
|
||||||
KMeansPlusPlusClusterer<DoublePoint> clusterer =
|
|
||||||
new KMeansPlusPlusClusterer<>(NUM_CLUSTERS, NUM_ITERATIONS,
|
|
||||||
new CloseDistance(), random);
|
|
||||||
List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
|
|
||||||
|
|
||||||
// Check that one of the chosen centers is the unique point.
|
// Check that one of the chosen centers is the unique point.
|
||||||
boolean uniquePointIsCenter = false;
|
boolean uniquePointIsCenter = false;
|
||||||
|
|
|
@ -58,9 +58,9 @@ public class MiniBatchKMeansClustererTest {
|
||||||
final UniformRandomProvider rng = RandomSource.MT_64.create();
|
final UniformRandomProvider rng = RandomSource.MT_64.create();
|
||||||
List<DoublePoint> data = generateCircles(rng);
|
List<DoublePoint> data = generateCircles(rng);
|
||||||
KMeansPlusPlusClusterer<DoublePoint> kMeans =
|
KMeansPlusPlusClusterer<DoublePoint> kMeans =
|
||||||
new KMeansPlusPlusClusterer<>(4, -1, DEFAULT_MEASURE, rng);
|
new KMeansPlusPlusClusterer<>(4, Integer.MAX_VALUE, DEFAULT_MEASURE, rng);
|
||||||
MiniBatchKMeansClusterer<DoublePoint> miniBatchKMeans =
|
MiniBatchKMeansClusterer<DoublePoint> miniBatchKMeans =
|
||||||
new MiniBatchKMeansClusterer<>(4, -1, 100, 3, 300, 10, DEFAULT_MEASURE, rng,
|
new MiniBatchKMeansClusterer<>(4, Integer.MAX_VALUE, 100, 3, 300, 10, DEFAULT_MEASURE, rng,
|
||||||
KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE);
|
KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE);
|
||||||
// Test 100 times between KMeansPlusPlusClusterer and MiniBatchKMeansClusterer
|
// Test 100 times between KMeansPlusPlusClusterer and MiniBatchKMeansClusterer
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
|
|
|
@ -63,7 +63,7 @@ public class CalinskiHarabaszTest {
|
||||||
double actualBestScore = 0.0;
|
double actualBestScore = 0.0;
|
||||||
for (int i = 0; i < 5; i++) {
|
for (int i = 0; i < 5; i++) {
|
||||||
final int k = i + 2;
|
final int k = i + 2;
|
||||||
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, -1, distanceMeasure, rnd);
|
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, Integer.MAX_VALUE, distanceMeasure, rnd);
|
||||||
List<CentroidCluster<DoublePoint>> clusters = kMeans.cluster(points);
|
List<CentroidCluster<DoublePoint>> clusters = kMeans.cluster(points);
|
||||||
double score = evaluator.score(clusters);
|
double score = evaluator.score(clusters);
|
||||||
if (score > expectBestScore) {
|
if (score > expectBestScore) {
|
||||||
|
@ -89,7 +89,7 @@ public class CalinskiHarabaszTest {
|
||||||
double actualBestScore = 0.0;
|
double actualBestScore = 0.0;
|
||||||
for (int i = 0; i < 5; i++) {
|
for (int i = 0; i < 5; i++) {
|
||||||
final int k = i + 2;
|
final int k = i + 2;
|
||||||
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, -1, distanceMeasure, rnd);
|
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, Integer.MAX_VALUE, distanceMeasure, rnd);
|
||||||
List<CentroidCluster<DoublePoint>> clusters = kMeans.cluster(points);
|
List<CentroidCluster<DoublePoint>> clusters = kMeans.cluster(points);
|
||||||
double score = evaluator.score(clusters);
|
double score = evaluator.score(clusters);
|
||||||
if (score > expectBestScore) {
|
if (score > expectBestScore) {
|
||||||
|
|
Loading…
Reference in New Issue