MATH-1509: Miscellaneous code style adjustments.

This commit is contained in:
Gilles Sadowski 2020-03-22 14:29:00 +01:00
parent c55d43f382
commit 844ffbeeee
1 changed files with 83 additions and 59 deletions

View File

@ -17,8 +17,6 @@
package org.apache.commons.math4.ml.clustering; package org.apache.commons.math4.ml.clustering;
import org.apache.commons.math4.exception.ConvergenceException;
import org.apache.commons.math4.exception.MathIllegalArgumentException;
import org.apache.commons.math4.exception.NumberIsTooSmallException; import org.apache.commons.math4.exception.NumberIsTooSmallException;
import org.apache.commons.math4.ml.distance.DistanceMeasure; import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.math4.util.MathUtils; import org.apache.commons.math4.util.MathUtils;
@ -36,7 +34,8 @@ import java.util.List;
* *
* @param <T> Type of the points to cluster. * @param <T> Type of the points to cluster.
*/ */
public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusPlusClusterer<T> { public class MiniBatchKMeansClusterer<T extends Clusterable>
extends KMeansPlusPlusClusterer<T> {
/** Batch data size in iteration. */ /** Batch data size in iteration. */
private final int batchSize; private final int batchSize;
/** Iteration count of initialize the centers. */ /** Iteration count of initialize the centers. */
@ -65,19 +64,34 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
* @param random Random generator. * @param random Random generator.
* @param emptyStrategy Strategy for handling empty clusters that may appear during algorithm iterations. * @param emptyStrategy Strategy for handling empty clusters that may appear during algorithm iterations.
*/ */
public MiniBatchKMeansClusterer(final int k, final int maxIterations, final int batchSize, final int initIterations, public MiniBatchKMeansClusterer(final int k,
final int initBatchSize, final int maxNoImprovementTimes, final int maxIterations,
final DistanceMeasure measure, final UniformRandomProvider random, final int batchSize,
final int initIterations,
final int initBatchSize,
final int maxNoImprovementTimes,
final DistanceMeasure measure,
final UniformRandomProvider random,
final EmptyClusterStrategy emptyStrategy) { final EmptyClusterStrategy emptyStrategy) {
super(k, maxIterations, measure, random, emptyStrategy); super(k, maxIterations, measure, random, emptyStrategy);
if (batchSize < 1) throw new NumberIsTooSmallException(batchSize, 1, true);
else this.batchSize = batchSize; if (batchSize < 1) {
if (initIterations < 1) throw new NumberIsTooSmallException(initIterations, 1, true); throw new NumberIsTooSmallException(batchSize, 1, true);
else this.initIterations = initIterations; }
if (initBatchSize < 1) throw new NumberIsTooSmallException(initBatchSize, 1, true); if (initIterations < 1) {
else this.initBatchSize = initBatchSize; throw new NumberIsTooSmallException(initIterations, 1, true);
if (maxNoImprovementTimes < 1) throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true); }
else this.maxNoImprovementTimes = maxNoImprovementTimes; if (initBatchSize < 1) {
throw new NumberIsTooSmallException(initBatchSize, 1, true);
}
if (maxNoImprovementTimes < 1) {
throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
}
this.batchSize = batchSize;
this.initIterations = initIterations;
this.initBatchSize = initBatchSize;
this.maxNoImprovementTimes = maxNoImprovementTimes;
} }
/** /**
@ -85,43 +99,45 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
* *
* @param points Points to cluster (cannot be {@code null}). * @param points Points to cluster (cannot be {@code null}).
* @return the clusters. * @return the clusters.
* @throws MathIllegalArgumentException if number of points is * @throws org.apache.commons.math4.exception.MathIllegalArgumentException
* smaller than the number of clusters. * if the number of points is smaller than the number of clusters.
*/ */
@Override @Override
public List<CentroidCluster<T>> cluster(final Collection<T> points) throws MathIllegalArgumentException, ConvergenceException { public List<CentroidCluster<T>> cluster(final Collection<T> points) {
// sanity checks // Sanity check.
MathUtils.checkNotNull(points); MathUtils.checkNotNull(points);
// number of clusters has to be smaller or equal the number of data points
if (points.size() < getK()) { if (points.size() < getK()) {
throw new NumberIsTooSmallException(points.size(), getK(), false); throw new NumberIsTooSmallException(points.size(), getK(), false);
} }
final int pointSize = points.size(); final int pointSize = points.size();
final int batchCount = pointSize / batchSize + ((pointSize % batchSize > 0) ? 1 : 0); final int batchCount = pointSize / batchSize + (pointSize % batchSize > 0 ? 1 : 0);
final int max = this.getMaxIterations() < 0 ? Integer.MAX_VALUE : (this.getMaxIterations() * batchCount); final int max = getMaxIterations() < 0 ?
Integer.MAX_VALUE :
getMaxIterations() * batchCount;
List<CentroidCluster<T>> clusters = initialCenters(points); List<CentroidCluster<T>> clusters = initialCenters(points);
// Loop execute the mini batch steps until reach the max loop times, or cannot improvement anymore.
final MiniBatchImprovementEvaluator evaluator = new MiniBatchImprovementEvaluator(); final MiniBatchImprovementEvaluator evaluator = new MiniBatchImprovementEvaluator();
for (int i = 0; i < max; i++) { for (int i = 0; i < max; i++) {
//Clear points in clusters
clearClustersPoints(clusters); clearClustersPoints(clusters);
//Random sampling a mini batch of points. final List<T> batchPoints = randomMiniBatch(points, batchSize, getRandomGenerator());
final List<T> batchPoints = randomMiniBatch(points, batchSize); // Training step.
// Processing the mini batch training step
final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters); final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
final double squareDistance = pair.getFirst(); final double squareDistance = pair.getFirst();
clusters = pair.getSecond(); clusters = pair.getSecond();
// Evaluate the training can finished early. // Check whether the training can finished early.
if (evaluator.convergence(squareDistance, pointSize)) break; if (evaluator.converge(squareDistance, pointSize)) {
break;
}
} }
//Add every mini batch points to their nearest cluster. // Add every mini batch points to their nearest cluster.
clearClustersPoints(clusters); clearClustersPoints(clusters);
for (final T point : points) { for (final T point : points) {
addToNearestCentroidCluster(point, clusters); addToNearestCentroidCluster(point, clusters);
} }
return clusters; return clusters;
} }
@ -143,9 +159,8 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
* @param clusters Centers of the clusters. * @param clusters Centers of the clusters.
* @return the squared distance of all the batch points to the nearest center. * @return the squared distance of all the batch points to the nearest center.
*/ */
private Pair<Double, List<CentroidCluster<T>>> step( private Pair<Double, List<CentroidCluster<T>>> step(final List<T> batchPoints,
final List<T> batchPoints, final List<CentroidCluster<T>> clusters) {
final List<CentroidCluster<T>> clusters) {
// Add every mini batch points to their nearest cluster. // Add every mini batch points to their nearest cluster.
for (final T point : batchPoints) { for (final T point : batchPoints) {
addToNearestCentroidCluster(point, clusters); addToNearestCentroidCluster(point, clusters);
@ -157,6 +172,7 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
final double d = addToNearestCentroidCluster(point, newClusters); final double d = addToNearestCentroidCluster(point, newClusters);
squareDistance += d * d; squareDistance += d * d;
} }
return new Pair<>(squareDistance, newClusters); return new Pair<>(squareDistance, newClusters);
} }
@ -164,14 +180,16 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
* Extract a random subset of the given {@code points}. * Extract a random subset of the given {@code points}.
* *
* @param points All the points to cluster. * @param points All the points to cluster.
* @param batchSize Mini batch size. * @param n Mini batch size.
* @return mini batch of all the points. * @param rng Random generator.
* @return {@code n} points.
*/ */
private List<T> randomMiniBatch(final Collection<T> points, private List<T> randomMiniBatch(final Collection<T> points,
final int batchSize) { final int n,
final UniformRandomProvider rng) {
final ArrayList<T> list = new ArrayList<>(points); final ArrayList<T> list = new ArrayList<>(points);
ListSampler.shuffle(getRandomGenerator(), list); ListSampler.shuffle(rng, list);
return list.subList(0, batchSize); return list.subList(0, n);
} }
/** /**
@ -182,12 +200,15 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
*/ */
private List<CentroidCluster<T>> initialCenters(final Collection<T> points) { private List<CentroidCluster<T>> initialCenters(final Collection<T> points) {
final List<T> validPoints = initBatchSize < points.size() ? final List<T> validPoints = initBatchSize < points.size() ?
randomMiniBatch(points, initBatchSize) : new ArrayList<>(points); randomMiniBatch(points, initBatchSize, getRandomGenerator()) :
new ArrayList<>(points);
double nearestSquareDistance = Double.POSITIVE_INFINITY; double nearestSquareDistance = Double.POSITIVE_INFINITY;
List<CentroidCluster<T>> bestCenters = null; List<CentroidCluster<T>> bestCenters = null;
for (int i = 0; i < initIterations; i++) { for (int i = 0; i < initIterations; i++) {
final List<T> initialPoints = (initBatchSize < points.size()) ? final List<T> initialPoints = (initBatchSize < points.size()) ?
randomMiniBatch(points, initBatchSize) : new ArrayList<>(points); randomMiniBatch(points, initBatchSize, getRandomGenerator()) :
new ArrayList<>(points);
final List<CentroidCluster<T>> clusters = getCentroidInitializer().selectCentroids(initialPoints, getK()); final List<CentroidCluster<T>> clusters = getCentroidInitializer().selectCentroids(initialPoints, getK());
final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters); final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
final double squareDistance = pair.getFirst(); final double squareDistance = pair.getFirst();
@ -208,10 +229,12 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
* @param clusters Clusters. * @param clusters Clusters.
* @return the distance between point and the closest center. * @return the distance between point and the closest center.
*/ */
private double addToNearestCentroidCluster(final T point, final List<CentroidCluster<T>> clusters) { private double addToNearestCentroidCluster(final T point,
final List<CentroidCluster<T>> clusters) {
double minDistance = Double.POSITIVE_INFINITY; double minDistance = Double.POSITIVE_INFINITY;
CentroidCluster<T> closestCentroidCluster = null; CentroidCluster<T> closestCentroidCluster = null;
// Iterate clusters and find out closest cluster to the point
// Find cluster closest to the point.
for (CentroidCluster<T> centroidCluster : clusters) { for (CentroidCluster<T> centroidCluster : clusters) {
final double distance = distance(point, centroidCluster.getCenter()); final double distance = distance(point, centroidCluster.getCenter());
if (distance < minDistance) { if (distance < minDistance) {
@ -219,8 +242,9 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
closestCentroidCluster = centroidCluster; closestCentroidCluster = centroidCluster;
} }
} }
assert closestCentroidCluster != null; MathUtils.checkNotNull(closestCentroidCluster);
closestCentroidCluster.addPoint(point); closestCentroidCluster.addPoint(point);
return minDistance; return minDistance;
} }
@ -229,42 +253,42 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
* The evaluator checks whether improvement occurred during the * The evaluator checks whether improvement occurred during the
* {@link #maxNoImprovementTimes allowed number of successive iterations}. * {@link #maxNoImprovementTimes allowed number of successive iterations}.
*/ */
class MiniBatchImprovementEvaluator { private class MiniBatchImprovementEvaluator {
/** Missing doc. */ /** Missing doc. */
private Double ewaInertia = null; private double ewaInertia = Double.NaN;
/** Missing doc. */ /** Missing doc. */
private double ewaInertiaMin = Double.POSITIVE_INFINITY; private double ewaInertiaMin = Double.POSITIVE_INFINITY;
/** Missing doc. */ /** Missing doc. */
private int noImprovementTimes = 0; private int noImprovementTimes = 0;
/** /**
* Evaluate whether the iteration should finish where square has no improvement for appointed times * Stopping criterion.
* *
* @param squareDistance the total square distance of the mini batch points to their nearest center. * @param squareDistance Total square distance from the batch points
* @param pointSize size of the the data points. * to their nearest center.
* @return true if no improvement for appointed times, otherwise false * @param pointSize Number of data points.
* @return {@code true} if no improvement was made after the allowed
* number of iterations, {@code false} otherwise.
*/ */
public boolean convergence(final double squareDistance, final int pointSize) { public boolean converge(final double squareDistance,
final int pointSize) {
final double batchInertia = squareDistance / batchSize; final double batchInertia = squareDistance / batchSize;
if (ewaInertia == null) { if (Double.isNaN(ewaInertia)) {
ewaInertia = batchInertia; ewaInertia = batchInertia;
} else { } else {
// Refer to sklearn, pointSize+1 maybe intent to avoid the div/0 error, final double alpha = Math.min(batchSize * 2 / (pointSize + 1), 1);
// but java double does not have a div/0 error
double alpha = batchSize * 2.0 / (pointSize + 1);
alpha = Math.min(alpha, 1.0);
ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha; ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha;
} }
// Improved
if (ewaInertia < ewaInertiaMin) { if (ewaInertia < ewaInertiaMin) {
// Improved.
noImprovementTimes = 0; noImprovementTimes = 0;
ewaInertiaMin = ewaInertia; ewaInertiaMin = ewaInertia;
} else { } else {
// No improvement // No improvement.
noImprovementTimes++; ++noImprovementTimes;
} }
// Has no improvement continuous for many times
return noImprovementTimes >= maxNoImprovementTimes; return noImprovementTimes >= maxNoImprovementTimes;
} }
} }