MATH-1509: Miscellaneous code style adjustments.
This commit is contained in:
parent
c55d43f382
commit
844ffbeeee
|
@ -17,8 +17,6 @@
|
||||||
|
|
||||||
package org.apache.commons.math4.ml.clustering;
|
package org.apache.commons.math4.ml.clustering;
|
||||||
|
|
||||||
import org.apache.commons.math4.exception.ConvergenceException;
|
|
||||||
import org.apache.commons.math4.exception.MathIllegalArgumentException;
|
|
||||||
import org.apache.commons.math4.exception.NumberIsTooSmallException;
|
import org.apache.commons.math4.exception.NumberIsTooSmallException;
|
||||||
import org.apache.commons.math4.ml.distance.DistanceMeasure;
|
import org.apache.commons.math4.ml.distance.DistanceMeasure;
|
||||||
import org.apache.commons.math4.util.MathUtils;
|
import org.apache.commons.math4.util.MathUtils;
|
||||||
|
@ -36,7 +34,8 @@ import java.util.List;
|
||||||
*
|
*
|
||||||
* @param <T> Type of the points to cluster.
|
* @param <T> Type of the points to cluster.
|
||||||
*/
|
*/
|
||||||
public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusPlusClusterer<T> {
|
public class MiniBatchKMeansClusterer<T extends Clusterable>
|
||||||
|
extends KMeansPlusPlusClusterer<T> {
|
||||||
/** Batch data size in iteration. */
|
/** Batch data size in iteration. */
|
||||||
private final int batchSize;
|
private final int batchSize;
|
||||||
/** Iteration count of initialize the centers. */
|
/** Iteration count of initialize the centers. */
|
||||||
|
@ -65,19 +64,34 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
||||||
* @param random Random generator.
|
* @param random Random generator.
|
||||||
* @param emptyStrategy Strategy for handling empty clusters that may appear during algorithm iterations.
|
* @param emptyStrategy Strategy for handling empty clusters that may appear during algorithm iterations.
|
||||||
*/
|
*/
|
||||||
public MiniBatchKMeansClusterer(final int k, final int maxIterations, final int batchSize, final int initIterations,
|
public MiniBatchKMeansClusterer(final int k,
|
||||||
final int initBatchSize, final int maxNoImprovementTimes,
|
final int maxIterations,
|
||||||
final DistanceMeasure measure, final UniformRandomProvider random,
|
final int batchSize,
|
||||||
|
final int initIterations,
|
||||||
|
final int initBatchSize,
|
||||||
|
final int maxNoImprovementTimes,
|
||||||
|
final DistanceMeasure measure,
|
||||||
|
final UniformRandomProvider random,
|
||||||
final EmptyClusterStrategy emptyStrategy) {
|
final EmptyClusterStrategy emptyStrategy) {
|
||||||
super(k, maxIterations, measure, random, emptyStrategy);
|
super(k, maxIterations, measure, random, emptyStrategy);
|
||||||
if (batchSize < 1) throw new NumberIsTooSmallException(batchSize, 1, true);
|
|
||||||
else this.batchSize = batchSize;
|
if (batchSize < 1) {
|
||||||
if (initIterations < 1) throw new NumberIsTooSmallException(initIterations, 1, true);
|
throw new NumberIsTooSmallException(batchSize, 1, true);
|
||||||
else this.initIterations = initIterations;
|
}
|
||||||
if (initBatchSize < 1) throw new NumberIsTooSmallException(initBatchSize, 1, true);
|
if (initIterations < 1) {
|
||||||
else this.initBatchSize = initBatchSize;
|
throw new NumberIsTooSmallException(initIterations, 1, true);
|
||||||
if (maxNoImprovementTimes < 1) throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
|
}
|
||||||
else this.maxNoImprovementTimes = maxNoImprovementTimes;
|
if (initBatchSize < 1) {
|
||||||
|
throw new NumberIsTooSmallException(initBatchSize, 1, true);
|
||||||
|
}
|
||||||
|
if (maxNoImprovementTimes < 1) {
|
||||||
|
throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
this.batchSize = batchSize;
|
||||||
|
this.initIterations = initIterations;
|
||||||
|
this.initBatchSize = initBatchSize;
|
||||||
|
this.maxNoImprovementTimes = maxNoImprovementTimes;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -85,43 +99,45 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
||||||
*
|
*
|
||||||
* @param points Points to cluster (cannot be {@code null}).
|
* @param points Points to cluster (cannot be {@code null}).
|
||||||
* @return the clusters.
|
* @return the clusters.
|
||||||
* @throws MathIllegalArgumentException if number of points is
|
* @throws org.apache.commons.math4.exception.MathIllegalArgumentException
|
||||||
* smaller than the number of clusters.
|
* if the number of points is smaller than the number of clusters.
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public List<CentroidCluster<T>> cluster(final Collection<T> points) throws MathIllegalArgumentException, ConvergenceException {
|
public List<CentroidCluster<T>> cluster(final Collection<T> points) {
|
||||||
// sanity checks
|
// Sanity check.
|
||||||
MathUtils.checkNotNull(points);
|
MathUtils.checkNotNull(points);
|
||||||
|
|
||||||
// number of clusters has to be smaller or equal the number of data points
|
|
||||||
if (points.size() < getK()) {
|
if (points.size() < getK()) {
|
||||||
throw new NumberIsTooSmallException(points.size(), getK(), false);
|
throw new NumberIsTooSmallException(points.size(), getK(), false);
|
||||||
}
|
}
|
||||||
|
|
||||||
final int pointSize = points.size();
|
final int pointSize = points.size();
|
||||||
final int batchCount = pointSize / batchSize + ((pointSize % batchSize > 0) ? 1 : 0);
|
final int batchCount = pointSize / batchSize + (pointSize % batchSize > 0 ? 1 : 0);
|
||||||
final int max = this.getMaxIterations() < 0 ? Integer.MAX_VALUE : (this.getMaxIterations() * batchCount);
|
final int max = getMaxIterations() < 0 ?
|
||||||
|
Integer.MAX_VALUE :
|
||||||
|
getMaxIterations() * batchCount;
|
||||||
|
|
||||||
List<CentroidCluster<T>> clusters = initialCenters(points);
|
List<CentroidCluster<T>> clusters = initialCenters(points);
|
||||||
// Loop execute the mini batch steps until reach the max loop times, or cannot improvement anymore.
|
|
||||||
final MiniBatchImprovementEvaluator evaluator = new MiniBatchImprovementEvaluator();
|
final MiniBatchImprovementEvaluator evaluator = new MiniBatchImprovementEvaluator();
|
||||||
for (int i = 0; i < max; i++) {
|
for (int i = 0; i < max; i++) {
|
||||||
//Clear points in clusters
|
|
||||||
clearClustersPoints(clusters);
|
clearClustersPoints(clusters);
|
||||||
//Random sampling a mini batch of points.
|
final List<T> batchPoints = randomMiniBatch(points, batchSize, getRandomGenerator());
|
||||||
final List<T> batchPoints = randomMiniBatch(points, batchSize);
|
// Training step.
|
||||||
// Processing the mini batch training step
|
|
||||||
final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
|
final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
|
||||||
final double squareDistance = pair.getFirst();
|
final double squareDistance = pair.getFirst();
|
||||||
clusters = pair.getSecond();
|
clusters = pair.getSecond();
|
||||||
// Evaluate the training can finished early.
|
// Check whether the training can finished early.
|
||||||
if (evaluator.convergence(squareDistance, pointSize)) break;
|
if (evaluator.converge(squareDistance, pointSize)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//Add every mini batch points to their nearest cluster.
|
// Add every mini batch points to their nearest cluster.
|
||||||
clearClustersPoints(clusters);
|
clearClustersPoints(clusters);
|
||||||
for (final T point : points) {
|
for (final T point : points) {
|
||||||
addToNearestCentroidCluster(point, clusters);
|
addToNearestCentroidCluster(point, clusters);
|
||||||
}
|
}
|
||||||
|
|
||||||
return clusters;
|
return clusters;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,9 +159,8 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
||||||
* @param clusters Centers of the clusters.
|
* @param clusters Centers of the clusters.
|
||||||
* @return the squared distance of all the batch points to the nearest center.
|
* @return the squared distance of all the batch points to the nearest center.
|
||||||
*/
|
*/
|
||||||
private Pair<Double, List<CentroidCluster<T>>> step(
|
private Pair<Double, List<CentroidCluster<T>>> step(final List<T> batchPoints,
|
||||||
final List<T> batchPoints,
|
final List<CentroidCluster<T>> clusters) {
|
||||||
final List<CentroidCluster<T>> clusters) {
|
|
||||||
// Add every mini batch points to their nearest cluster.
|
// Add every mini batch points to their nearest cluster.
|
||||||
for (final T point : batchPoints) {
|
for (final T point : batchPoints) {
|
||||||
addToNearestCentroidCluster(point, clusters);
|
addToNearestCentroidCluster(point, clusters);
|
||||||
|
@ -157,6 +172,7 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
||||||
final double d = addToNearestCentroidCluster(point, newClusters);
|
final double d = addToNearestCentroidCluster(point, newClusters);
|
||||||
squareDistance += d * d;
|
squareDistance += d * d;
|
||||||
}
|
}
|
||||||
|
|
||||||
return new Pair<>(squareDistance, newClusters);
|
return new Pair<>(squareDistance, newClusters);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -164,14 +180,16 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
||||||
* Extract a random subset of the given {@code points}.
|
* Extract a random subset of the given {@code points}.
|
||||||
*
|
*
|
||||||
* @param points All the points to cluster.
|
* @param points All the points to cluster.
|
||||||
* @param batchSize Mini batch size.
|
* @param n Mini batch size.
|
||||||
* @return mini batch of all the points.
|
* @param rng Random generator.
|
||||||
|
* @return {@code n} points.
|
||||||
*/
|
*/
|
||||||
private List<T> randomMiniBatch(final Collection<T> points,
|
private List<T> randomMiniBatch(final Collection<T> points,
|
||||||
final int batchSize) {
|
final int n,
|
||||||
|
final UniformRandomProvider rng) {
|
||||||
final ArrayList<T> list = new ArrayList<>(points);
|
final ArrayList<T> list = new ArrayList<>(points);
|
||||||
ListSampler.shuffle(getRandomGenerator(), list);
|
ListSampler.shuffle(rng, list);
|
||||||
return list.subList(0, batchSize);
|
return list.subList(0, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -182,12 +200,15 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
||||||
*/
|
*/
|
||||||
private List<CentroidCluster<T>> initialCenters(final Collection<T> points) {
|
private List<CentroidCluster<T>> initialCenters(final Collection<T> points) {
|
||||||
final List<T> validPoints = initBatchSize < points.size() ?
|
final List<T> validPoints = initBatchSize < points.size() ?
|
||||||
randomMiniBatch(points, initBatchSize) : new ArrayList<>(points);
|
randomMiniBatch(points, initBatchSize, getRandomGenerator()) :
|
||||||
|
new ArrayList<>(points);
|
||||||
double nearestSquareDistance = Double.POSITIVE_INFINITY;
|
double nearestSquareDistance = Double.POSITIVE_INFINITY;
|
||||||
List<CentroidCluster<T>> bestCenters = null;
|
List<CentroidCluster<T>> bestCenters = null;
|
||||||
|
|
||||||
for (int i = 0; i < initIterations; i++) {
|
for (int i = 0; i < initIterations; i++) {
|
||||||
final List<T> initialPoints = (initBatchSize < points.size()) ?
|
final List<T> initialPoints = (initBatchSize < points.size()) ?
|
||||||
randomMiniBatch(points, initBatchSize) : new ArrayList<>(points);
|
randomMiniBatch(points, initBatchSize, getRandomGenerator()) :
|
||||||
|
new ArrayList<>(points);
|
||||||
final List<CentroidCluster<T>> clusters = getCentroidInitializer().selectCentroids(initialPoints, getK());
|
final List<CentroidCluster<T>> clusters = getCentroidInitializer().selectCentroids(initialPoints, getK());
|
||||||
final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
|
final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
|
||||||
final double squareDistance = pair.getFirst();
|
final double squareDistance = pair.getFirst();
|
||||||
|
@ -208,10 +229,12 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
||||||
* @param clusters Clusters.
|
* @param clusters Clusters.
|
||||||
* @return the distance between point and the closest center.
|
* @return the distance between point and the closest center.
|
||||||
*/
|
*/
|
||||||
private double addToNearestCentroidCluster(final T point, final List<CentroidCluster<T>> clusters) {
|
private double addToNearestCentroidCluster(final T point,
|
||||||
|
final List<CentroidCluster<T>> clusters) {
|
||||||
double minDistance = Double.POSITIVE_INFINITY;
|
double minDistance = Double.POSITIVE_INFINITY;
|
||||||
CentroidCluster<T> closestCentroidCluster = null;
|
CentroidCluster<T> closestCentroidCluster = null;
|
||||||
// Iterate clusters and find out closest cluster to the point
|
|
||||||
|
// Find cluster closest to the point.
|
||||||
for (CentroidCluster<T> centroidCluster : clusters) {
|
for (CentroidCluster<T> centroidCluster : clusters) {
|
||||||
final double distance = distance(point, centroidCluster.getCenter());
|
final double distance = distance(point, centroidCluster.getCenter());
|
||||||
if (distance < minDistance) {
|
if (distance < minDistance) {
|
||||||
|
@ -219,8 +242,9 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
||||||
closestCentroidCluster = centroidCluster;
|
closestCentroidCluster = centroidCluster;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert closestCentroidCluster != null;
|
MathUtils.checkNotNull(closestCentroidCluster);
|
||||||
closestCentroidCluster.addPoint(point);
|
closestCentroidCluster.addPoint(point);
|
||||||
|
|
||||||
return minDistance;
|
return minDistance;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -229,42 +253,42 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
|
||||||
* The evaluator checks whether improvement occurred during the
|
* The evaluator checks whether improvement occurred during the
|
||||||
* {@link #maxNoImprovementTimes allowed number of successive iterations}.
|
* {@link #maxNoImprovementTimes allowed number of successive iterations}.
|
||||||
*/
|
*/
|
||||||
class MiniBatchImprovementEvaluator {
|
private class MiniBatchImprovementEvaluator {
|
||||||
/** Missing doc. */
|
/** Missing doc. */
|
||||||
private Double ewaInertia = null;
|
private double ewaInertia = Double.NaN;
|
||||||
/** Missing doc. */
|
/** Missing doc. */
|
||||||
private double ewaInertiaMin = Double.POSITIVE_INFINITY;
|
private double ewaInertiaMin = Double.POSITIVE_INFINITY;
|
||||||
/** Missing doc. */
|
/** Missing doc. */
|
||||||
private int noImprovementTimes = 0;
|
private int noImprovementTimes = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Evaluate whether the iteration should finish where square has no improvement for appointed times
|
* Stopping criterion.
|
||||||
*
|
*
|
||||||
* @param squareDistance the total square distance of the mini batch points to their nearest center.
|
* @param squareDistance Total square distance from the batch points
|
||||||
* @param pointSize size of the the data points.
|
* to their nearest center.
|
||||||
* @return true if no improvement for appointed times, otherwise false
|
* @param pointSize Number of data points.
|
||||||
|
* @return {@code true} if no improvement was made after the allowed
|
||||||
|
* number of iterations, {@code false} otherwise.
|
||||||
*/
|
*/
|
||||||
public boolean convergence(final double squareDistance, final int pointSize) {
|
public boolean converge(final double squareDistance,
|
||||||
|
final int pointSize) {
|
||||||
final double batchInertia = squareDistance / batchSize;
|
final double batchInertia = squareDistance / batchSize;
|
||||||
if (ewaInertia == null) {
|
if (Double.isNaN(ewaInertia)) {
|
||||||
ewaInertia = batchInertia;
|
ewaInertia = batchInertia;
|
||||||
} else {
|
} else {
|
||||||
// Refer to sklearn, pointSize+1 maybe intent to avoid the div/0 error,
|
final double alpha = Math.min(batchSize * 2 / (pointSize + 1), 1);
|
||||||
// but java double does not have a div/0 error
|
|
||||||
double alpha = batchSize * 2.0 / (pointSize + 1);
|
|
||||||
alpha = Math.min(alpha, 1.0);
|
|
||||||
ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha;
|
ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Improved
|
|
||||||
if (ewaInertia < ewaInertiaMin) {
|
if (ewaInertia < ewaInertiaMin) {
|
||||||
|
// Improved.
|
||||||
noImprovementTimes = 0;
|
noImprovementTimes = 0;
|
||||||
ewaInertiaMin = ewaInertia;
|
ewaInertiaMin = ewaInertia;
|
||||||
} else {
|
} else {
|
||||||
// No improvement
|
// No improvement.
|
||||||
noImprovementTimes++;
|
++noImprovementTimes;
|
||||||
}
|
}
|
||||||
// Has no improvement continuous for many times
|
|
||||||
return noImprovementTimes >= maxNoImprovementTimes;
|
return noImprovementTimes >= maxNoImprovementTimes;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue