MATH-1509: Miscellaneous code style adjustments.
This commit is contained in:
parent
c55d43f382
commit
844ffbeeee
|
@ -17,8 +17,6 @@
|
|||
|
||||
package org.apache.commons.math4.ml.clustering;
|
||||
|
||||
import org.apache.commons.math4.exception.ConvergenceException;
|
||||
import org.apache.commons.math4.exception.MathIllegalArgumentException;
|
||||
import org.apache.commons.math4.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math4.ml.distance.DistanceMeasure;
|
||||
import org.apache.commons.math4.util.MathUtils;
|
||||
|
@ -36,7 +34,8 @@ import java.util.List;
|
|||
*
|
||||
* @param <T> Type of the points to cluster.
|
||||
*/
|
||||
public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusPlusClusterer<T> {
|
||||
public class MiniBatchKMeansClusterer<T extends Clusterable>
|
||||
extends KMeansPlusPlusClusterer<T> {
|
||||
/** Batch data size in iteration. */
|
||||
private final int batchSize;
|
||||
/** Iteration count of initialize the centers. */
|
||||
|
@ -65,19 +64,34 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
|||
* @param random Random generator.
|
||||
* @param emptyStrategy Strategy for handling empty clusters that may appear during algorithm iterations.
|
||||
*/
|
||||
public MiniBatchKMeansClusterer(final int k, final int maxIterations, final int batchSize, final int initIterations,
|
||||
final int initBatchSize, final int maxNoImprovementTimes,
|
||||
final DistanceMeasure measure, final UniformRandomProvider random,
|
||||
public MiniBatchKMeansClusterer(final int k,
|
||||
final int maxIterations,
|
||||
final int batchSize,
|
||||
final int initIterations,
|
||||
final int initBatchSize,
|
||||
final int maxNoImprovementTimes,
|
||||
final DistanceMeasure measure,
|
||||
final UniformRandomProvider random,
|
||||
final EmptyClusterStrategy emptyStrategy) {
|
||||
super(k, maxIterations, measure, random, emptyStrategy);
|
||||
if (batchSize < 1) throw new NumberIsTooSmallException(batchSize, 1, true);
|
||||
else this.batchSize = batchSize;
|
||||
if (initIterations < 1) throw new NumberIsTooSmallException(initIterations, 1, true);
|
||||
else this.initIterations = initIterations;
|
||||
if (initBatchSize < 1) throw new NumberIsTooSmallException(initBatchSize, 1, true);
|
||||
else this.initBatchSize = initBatchSize;
|
||||
if (maxNoImprovementTimes < 1) throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
|
||||
else this.maxNoImprovementTimes = maxNoImprovementTimes;
|
||||
|
||||
if (batchSize < 1) {
|
||||
throw new NumberIsTooSmallException(batchSize, 1, true);
|
||||
}
|
||||
if (initIterations < 1) {
|
||||
throw new NumberIsTooSmallException(initIterations, 1, true);
|
||||
}
|
||||
if (initBatchSize < 1) {
|
||||
throw new NumberIsTooSmallException(initBatchSize, 1, true);
|
||||
}
|
||||
if (maxNoImprovementTimes < 1) {
|
||||
throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
|
||||
}
|
||||
|
||||
this.batchSize = batchSize;
|
||||
this.initIterations = initIterations;
|
||||
this.initBatchSize = initBatchSize;
|
||||
this.maxNoImprovementTimes = maxNoImprovementTimes;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -85,36 +99,37 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
|||
*
|
||||
* @param points Points to cluster (cannot be {@code null}).
|
||||
* @return the clusters.
|
||||
* @throws MathIllegalArgumentException if number of points is
|
||||
* smaller than the number of clusters.
|
||||
* @throws org.apache.commons.math4.exception.MathIllegalArgumentException
|
||||
* if the number of points is smaller than the number of clusters.
|
||||
*/
|
||||
@Override
|
||||
public List<CentroidCluster<T>> cluster(final Collection<T> points) throws MathIllegalArgumentException, ConvergenceException {
|
||||
// sanity checks
|
||||
public List<CentroidCluster<T>> cluster(final Collection<T> points) {
|
||||
// Sanity check.
|
||||
MathUtils.checkNotNull(points);
|
||||
|
||||
// number of clusters has to be smaller or equal the number of data points
|
||||
if (points.size() < getK()) {
|
||||
throw new NumberIsTooSmallException(points.size(), getK(), false);
|
||||
}
|
||||
|
||||
final int pointSize = points.size();
|
||||
final int batchCount = pointSize / batchSize + ((pointSize % batchSize > 0) ? 1 : 0);
|
||||
final int max = this.getMaxIterations() < 0 ? Integer.MAX_VALUE : (this.getMaxIterations() * batchCount);
|
||||
final int batchCount = pointSize / batchSize + (pointSize % batchSize > 0 ? 1 : 0);
|
||||
final int max = getMaxIterations() < 0 ?
|
||||
Integer.MAX_VALUE :
|
||||
getMaxIterations() * batchCount;
|
||||
|
||||
List<CentroidCluster<T>> clusters = initialCenters(points);
|
||||
// Loop execute the mini batch steps until reach the max loop times, or cannot improvement anymore.
|
||||
|
||||
final MiniBatchImprovementEvaluator evaluator = new MiniBatchImprovementEvaluator();
|
||||
for (int i = 0; i < max; i++) {
|
||||
//Clear points in clusters
|
||||
clearClustersPoints(clusters);
|
||||
//Random sampling a mini batch of points.
|
||||
final List<T> batchPoints = randomMiniBatch(points, batchSize);
|
||||
// Processing the mini batch training step
|
||||
final List<T> batchPoints = randomMiniBatch(points, batchSize, getRandomGenerator());
|
||||
// Training step.
|
||||
final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
|
||||
final double squareDistance = pair.getFirst();
|
||||
clusters = pair.getSecond();
|
||||
// Evaluate the training can finished early.
|
||||
if (evaluator.convergence(squareDistance, pointSize)) break;
|
||||
// Check whether the training can finished early.
|
||||
if (evaluator.converge(squareDistance, pointSize)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Add every mini batch points to their nearest cluster.
|
||||
|
@ -122,6 +137,7 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
|||
for (final T point : points) {
|
||||
addToNearestCentroidCluster(point, clusters);
|
||||
}
|
||||
|
||||
return clusters;
|
||||
}
|
||||
|
||||
|
@ -143,8 +159,7 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
|||
* @param clusters Centers of the clusters.
|
||||
* @return the squared distance of all the batch points to the nearest center.
|
||||
*/
|
||||
private Pair<Double, List<CentroidCluster<T>>> step(
|
||||
final List<T> batchPoints,
|
||||
private Pair<Double, List<CentroidCluster<T>>> step(final List<T> batchPoints,
|
||||
final List<CentroidCluster<T>> clusters) {
|
||||
// Add every mini batch points to their nearest cluster.
|
||||
for (final T point : batchPoints) {
|
||||
|
@ -157,6 +172,7 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
|||
final double d = addToNearestCentroidCluster(point, newClusters);
|
||||
squareDistance += d * d;
|
||||
}
|
||||
|
||||
return new Pair<>(squareDistance, newClusters);
|
||||
}
|
||||
|
||||
|
@ -164,14 +180,16 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
|||
* Extract a random subset of the given {@code points}.
|
||||
*
|
||||
* @param points All the points to cluster.
|
||||
* @param batchSize Mini batch size.
|
||||
* @return mini batch of all the points.
|
||||
* @param n Mini batch size.
|
||||
* @param rng Random generator.
|
||||
* @return {@code n} points.
|
||||
*/
|
||||
private List<T> randomMiniBatch(final Collection<T> points,
|
||||
final int batchSize) {
|
||||
final int n,
|
||||
final UniformRandomProvider rng) {
|
||||
final ArrayList<T> list = new ArrayList<>(points);
|
||||
ListSampler.shuffle(getRandomGenerator(), list);
|
||||
return list.subList(0, batchSize);
|
||||
ListSampler.shuffle(rng, list);
|
||||
return list.subList(0, n);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -182,12 +200,15 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
|||
*/
|
||||
private List<CentroidCluster<T>> initialCenters(final Collection<T> points) {
|
||||
final List<T> validPoints = initBatchSize < points.size() ?
|
||||
randomMiniBatch(points, initBatchSize) : new ArrayList<>(points);
|
||||
randomMiniBatch(points, initBatchSize, getRandomGenerator()) :
|
||||
new ArrayList<>(points);
|
||||
double nearestSquareDistance = Double.POSITIVE_INFINITY;
|
||||
List<CentroidCluster<T>> bestCenters = null;
|
||||
|
||||
for (int i = 0; i < initIterations; i++) {
|
||||
final List<T> initialPoints = (initBatchSize < points.size()) ?
|
||||
randomMiniBatch(points, initBatchSize) : new ArrayList<>(points);
|
||||
randomMiniBatch(points, initBatchSize, getRandomGenerator()) :
|
||||
new ArrayList<>(points);
|
||||
final List<CentroidCluster<T>> clusters = getCentroidInitializer().selectCentroids(initialPoints, getK());
|
||||
final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
|
||||
final double squareDistance = pair.getFirst();
|
||||
|
@ -208,10 +229,12 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
|||
* @param clusters Clusters.
|
||||
* @return the distance between point and the closest center.
|
||||
*/
|
||||
private double addToNearestCentroidCluster(final T point, final List<CentroidCluster<T>> clusters) {
|
||||
private double addToNearestCentroidCluster(final T point,
|
||||
final List<CentroidCluster<T>> clusters) {
|
||||
double minDistance = Double.POSITIVE_INFINITY;
|
||||
CentroidCluster<T> closestCentroidCluster = null;
|
||||
// Iterate clusters and find out closest cluster to the point
|
||||
|
||||
// Find cluster closest to the point.
|
||||
for (CentroidCluster<T> centroidCluster : clusters) {
|
||||
final double distance = distance(point, centroidCluster.getCenter());
|
||||
if (distance < minDistance) {
|
||||
|
@ -219,8 +242,9 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
|||
closestCentroidCluster = centroidCluster;
|
||||
}
|
||||
}
|
||||
assert closestCentroidCluster != null;
|
||||
MathUtils.checkNotNull(closestCentroidCluster);
|
||||
closestCentroidCluster.addPoint(point);
|
||||
|
||||
return minDistance;
|
||||
}
|
||||
|
||||
|
@ -229,42 +253,42 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
|||
* The evaluator checks whether improvement occurred during the
|
||||
* {@link #maxNoImprovementTimes allowed number of successive iterations}.
|
||||
*/
|
||||
class MiniBatchImprovementEvaluator {
|
||||
private class MiniBatchImprovementEvaluator {
|
||||
/** Missing doc. */
|
||||
private Double ewaInertia = null;
|
||||
private double ewaInertia = Double.NaN;
|
||||
/** Missing doc. */
|
||||
private double ewaInertiaMin = Double.POSITIVE_INFINITY;
|
||||
/** Missing doc. */
|
||||
private int noImprovementTimes = 0;
|
||||
|
||||
/**
|
||||
* Evaluate whether the iteration should finish where square has no improvement for appointed times
|
||||
* Stopping criterion.
|
||||
*
|
||||
* @param squareDistance the total square distance of the mini batch points to their nearest center.
|
||||
* @param pointSize size of the the data points.
|
||||
* @return true if no improvement for appointed times, otherwise false
|
||||
* @param squareDistance Total square distance from the batch points
|
||||
* to their nearest center.
|
||||
* @param pointSize Number of data points.
|
||||
* @return {@code true} if no improvement was made after the allowed
|
||||
* number of iterations, {@code false} otherwise.
|
||||
*/
|
||||
public boolean convergence(final double squareDistance, final int pointSize) {
|
||||
public boolean converge(final double squareDistance,
|
||||
final int pointSize) {
|
||||
final double batchInertia = squareDistance / batchSize;
|
||||
if (ewaInertia == null) {
|
||||
if (Double.isNaN(ewaInertia)) {
|
||||
ewaInertia = batchInertia;
|
||||
} else {
|
||||
// Refer to sklearn, pointSize+1 maybe intent to avoid the div/0 error,
|
||||
// but java double does not have a div/0 error
|
||||
double alpha = batchSize * 2.0 / (pointSize + 1);
|
||||
alpha = Math.min(alpha, 1.0);
|
||||
final double alpha = Math.min(batchSize * 2 / (pointSize + 1), 1);
|
||||
ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha;
|
||||
}
|
||||
|
||||
// Improved
|
||||
if (ewaInertia < ewaInertiaMin) {
|
||||
// Improved.
|
||||
noImprovementTimes = 0;
|
||||
ewaInertiaMin = ewaInertia;
|
||||
} else {
|
||||
// No improvement
|
||||
noImprovementTimes++;
|
||||
// No improvement.
|
||||
++noImprovementTimes;
|
||||
}
|
||||
// Has no improvement continuous for many times
|
||||
|
||||
return noImprovementTimes >= maxNoImprovementTimes;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue