MATH-1509: Miscellaneous code style adjustments.

This commit is contained in:
Gilles Sadowski 2020-03-22 14:29:00 +01:00
parent c55d43f382
commit 844ffbeeee
1 changed files with 83 additions and 59 deletions

View File

@ -17,8 +17,6 @@
package org.apache.commons.math4.ml.clustering;
import org.apache.commons.math4.exception.ConvergenceException;
import org.apache.commons.math4.exception.MathIllegalArgumentException;
import org.apache.commons.math4.exception.NumberIsTooSmallException;
import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.math4.util.MathUtils;
@ -36,7 +34,8 @@ import java.util.List;
*
* @param <T> Type of the points to cluster.
*/
public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusPlusClusterer<T> {
public class MiniBatchKMeansClusterer<T extends Clusterable>
extends KMeansPlusPlusClusterer<T> {
/** Batch data size in iteration. */
private final int batchSize;
/** Iteration count of initialize the centers. */
@ -65,19 +64,34 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
* @param random Random generator.
* @param emptyStrategy Strategy for handling empty clusters that may appear during algorithm iterations.
*/
public MiniBatchKMeansClusterer(final int k, final int maxIterations, final int batchSize, final int initIterations,
final int initBatchSize, final int maxNoImprovementTimes,
final DistanceMeasure measure, final UniformRandomProvider random,
public MiniBatchKMeansClusterer(final int k,
final int maxIterations,
final int batchSize,
final int initIterations,
final int initBatchSize,
final int maxNoImprovementTimes,
final DistanceMeasure measure,
final UniformRandomProvider random,
final EmptyClusterStrategy emptyStrategy) {
super(k, maxIterations, measure, random, emptyStrategy);
if (batchSize < 1) throw new NumberIsTooSmallException(batchSize, 1, true);
else this.batchSize = batchSize;
if (initIterations < 1) throw new NumberIsTooSmallException(initIterations, 1, true);
else this.initIterations = initIterations;
if (initBatchSize < 1) throw new NumberIsTooSmallException(initBatchSize, 1, true);
else this.initBatchSize = initBatchSize;
if (maxNoImprovementTimes < 1) throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
else this.maxNoImprovementTimes = maxNoImprovementTimes;
if (batchSize < 1) {
throw new NumberIsTooSmallException(batchSize, 1, true);
}
if (initIterations < 1) {
throw new NumberIsTooSmallException(initIterations, 1, true);
}
if (initBatchSize < 1) {
throw new NumberIsTooSmallException(initBatchSize, 1, true);
}
if (maxNoImprovementTimes < 1) {
throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
}
this.batchSize = batchSize;
this.initIterations = initIterations;
this.initBatchSize = initBatchSize;
this.maxNoImprovementTimes = maxNoImprovementTimes;
}
/**
@ -85,36 +99,37 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
*
* @param points Points to cluster (cannot be {@code null}).
* @return the clusters.
* @throws MathIllegalArgumentException if number of points is
* smaller than the number of clusters.
* @throws org.apache.commons.math4.exception.MathIllegalArgumentException
* if the number of points is smaller than the number of clusters.
*/
@Override
public List<CentroidCluster<T>> cluster(final Collection<T> points) throws MathIllegalArgumentException, ConvergenceException {
// sanity checks
public List<CentroidCluster<T>> cluster(final Collection<T> points) {
// Sanity check.
MathUtils.checkNotNull(points);
// number of clusters has to be smaller or equal the number of data points
if (points.size() < getK()) {
throw new NumberIsTooSmallException(points.size(), getK(), false);
}
final int pointSize = points.size();
final int batchCount = pointSize / batchSize + ((pointSize % batchSize > 0) ? 1 : 0);
final int max = this.getMaxIterations() < 0 ? Integer.MAX_VALUE : (this.getMaxIterations() * batchCount);
final int batchCount = pointSize / batchSize + (pointSize % batchSize > 0 ? 1 : 0);
final int max = getMaxIterations() < 0 ?
Integer.MAX_VALUE :
getMaxIterations() * batchCount;
List<CentroidCluster<T>> clusters = initialCenters(points);
// Loop execute the mini batch steps until reach the max loop times, or cannot improvement anymore.
final MiniBatchImprovementEvaluator evaluator = new MiniBatchImprovementEvaluator();
for (int i = 0; i < max; i++) {
//Clear points in clusters
clearClustersPoints(clusters);
//Random sampling a mini batch of points.
final List<T> batchPoints = randomMiniBatch(points, batchSize);
// Processing the mini batch training step
final List<T> batchPoints = randomMiniBatch(points, batchSize, getRandomGenerator());
// Training step.
final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
final double squareDistance = pair.getFirst();
clusters = pair.getSecond();
// Evaluate the training can finished early.
if (evaluator.convergence(squareDistance, pointSize)) break;
// Check whether the training can finished early.
if (evaluator.converge(squareDistance, pointSize)) {
break;
}
}
// Add every mini batch points to their nearest cluster.
@ -122,6 +137,7 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
for (final T point : points) {
addToNearestCentroidCluster(point, clusters);
}
return clusters;
}
@ -143,8 +159,7 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
* @param clusters Centers of the clusters.
* @return the squared distance of all the batch points to the nearest center.
*/
private Pair<Double, List<CentroidCluster<T>>> step(
final List<T> batchPoints,
private Pair<Double, List<CentroidCluster<T>>> step(final List<T> batchPoints,
final List<CentroidCluster<T>> clusters) {
// Add every mini batch points to their nearest cluster.
for (final T point : batchPoints) {
@ -157,6 +172,7 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
final double d = addToNearestCentroidCluster(point, newClusters);
squareDistance += d * d;
}
return new Pair<>(squareDistance, newClusters);
}
@ -164,14 +180,16 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
* Extract a random subset of the given {@code points}.
*
* @param points All the points to cluster.
* @param batchSize Mini batch size.
* @return mini batch of all the points.
* @param n Mini batch size.
* @param rng Random generator.
* @return {@code n} points.
*/
private List<T> randomMiniBatch(final Collection<T> points,
final int batchSize) {
final int n,
final UniformRandomProvider rng) {
final ArrayList<T> list = new ArrayList<>(points);
ListSampler.shuffle(getRandomGenerator(), list);
return list.subList(0, batchSize);
ListSampler.shuffle(rng, list);
return list.subList(0, n);
}
/**
@ -182,12 +200,15 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
*/
private List<CentroidCluster<T>> initialCenters(final Collection<T> points) {
final List<T> validPoints = initBatchSize < points.size() ?
randomMiniBatch(points, initBatchSize) : new ArrayList<>(points);
randomMiniBatch(points, initBatchSize, getRandomGenerator()) :
new ArrayList<>(points);
double nearestSquareDistance = Double.POSITIVE_INFINITY;
List<CentroidCluster<T>> bestCenters = null;
for (int i = 0; i < initIterations; i++) {
final List<T> initialPoints = (initBatchSize < points.size()) ?
randomMiniBatch(points, initBatchSize) : new ArrayList<>(points);
randomMiniBatch(points, initBatchSize, getRandomGenerator()) :
new ArrayList<>(points);
final List<CentroidCluster<T>> clusters = getCentroidInitializer().selectCentroids(initialPoints, getK());
final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
final double squareDistance = pair.getFirst();
@ -208,10 +229,12 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
* @param clusters Clusters.
* @return the distance between point and the closest center.
*/
private double addToNearestCentroidCluster(final T point, final List<CentroidCluster<T>> clusters) {
private double addToNearestCentroidCluster(final T point,
final List<CentroidCluster<T>> clusters) {
double minDistance = Double.POSITIVE_INFINITY;
CentroidCluster<T> closestCentroidCluster = null;
// Iterate clusters and find out closest cluster to the point
// Find cluster closest to the point.
for (CentroidCluster<T> centroidCluster : clusters) {
final double distance = distance(point, centroidCluster.getCenter());
if (distance < minDistance) {
@ -219,8 +242,9 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
closestCentroidCluster = centroidCluster;
}
}
assert closestCentroidCluster != null;
MathUtils.checkNotNull(closestCentroidCluster);
closestCentroidCluster.addPoint(point);
return minDistance;
}
@ -229,42 +253,42 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
* The evaluator checks whether improvement occurred during the
* {@link #maxNoImprovementTimes allowed number of successive iterations}.
*/
class MiniBatchImprovementEvaluator {
private class MiniBatchImprovementEvaluator {
/** Missing doc. */
private Double ewaInertia = null;
private double ewaInertia = Double.NaN;
/** Missing doc. */
private double ewaInertiaMin = Double.POSITIVE_INFINITY;
/** Missing doc. */
private int noImprovementTimes = 0;
/**
* Evaluate whether the iteration should finish where square has no improvement for appointed times
* Stopping criterion.
*
* @param squareDistance the total square distance of the mini batch points to their nearest center.
* @param pointSize size of the the data points.
* @return true if no improvement for appointed times, otherwise false
* @param squareDistance Total square distance from the batch points
* to their nearest center.
* @param pointSize Number of data points.
* @return {@code true} if no improvement was made after the allowed
* number of iterations, {@code false} otherwise.
*/
public boolean convergence(final double squareDistance, final int pointSize) {
public boolean converge(final double squareDistance,
final int pointSize) {
final double batchInertia = squareDistance / batchSize;
if (ewaInertia == null) {
if (Double.isNaN(ewaInertia)) {
ewaInertia = batchInertia;
} else {
// Refer to sklearn, pointSize+1 maybe intent to avoid the div/0 error,
// but java double does not have a div/0 error
double alpha = batchSize * 2.0 / (pointSize + 1);
alpha = Math.min(alpha, 1.0);
final double alpha = Math.min(batchSize * 2 / (pointSize + 1), 1);
ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha;
}
// Improved
if (ewaInertia < ewaInertiaMin) {
// Improved.
noImprovementTimes = 0;
ewaInertiaMin = ewaInertia;
} else {
// No improvement
noImprovementTimes++;
// No improvement.
++noImprovementTimes;
}
// Has no improvement continuous for many times
return noImprovementTimes >= maxNoImprovementTimes;
}
}