This commit is contained in:
Gilles Sadowski 2020-03-22 11:39:54 +01:00
parent c251395aaf
commit 103276c53b
1 changed files with 51 additions and 58 deletions

View File

@ -31,51 +31,39 @@ import java.util.Collection;
import java.util.List;
/**
* A very fast clustering algorithm base on KMeans(Refer to Python sklearn.cluster.MiniBatchKMeans)
* Use a partial points in initialize cluster centers, and mini batch in iterations.
* It finish in few seconds when clustering millions of data, and has few differences between KMeans.
* See https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf
* Clustering algorithm <a href="https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf">
* based on KMeans</a>.
*
* @param <T> Type of the points to cluster
* @param <T> Type of the points to cluster.
*/
public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusPlusClusterer<T> {
/**
* Batch data size in iteration.
*/
/** Batch data size in iteration. */
private final int batchSize;
/**
* Iteration count of initialize the centers.
*/
/** Iteration count of initialize the centers. */
private final int initIterations;
/**
* Data size of batch to initialize the centers, default 3*k
*/
/** Data size of batch to initialize the centers. */
private final int initBatchSize;
/**
* Max iterate times when no improvement on step iterations.
*/
/** Maximum number of iterations during which no improvement is occuring. */
private final int maxNoImprovementTimes;
/**
* Build a clusterer.
*
* @param k the number of clusters to split the data into
* @param maxIterations the maximum number of iterations to run the algorithm for all the points,
* for mini batch actual {@code iterations <= maxIterations * points.size() / batchSize}.
* If negative, no maximum will be used.
* @param batchSize the mini batch size for training iterations.
* @param initIterations the iterations to find out the best clusters centers with mini batch.
* @param initBatchSize the mini batch size to initial the clusters centers,
* batchSize * 3 is suitable for most case.
* @param maxNoImprovementTimes the max iterations times when the square distance has no improvement,
* 10 is suitable for most case.
* @param measure the distance measure to use, EuclideanDistance is recommended.
* @param random random generator to use for choosing initial centers
* may appear during algorithm iterations
* @param emptyStrategy Strategy to use for handling empty clusters that
* may appear during algorithm iterations.
* @param k Number of clusters to split the data into.
* @param maxIterations Maximum number of iterations to run the algorithm for all the points,
* The actual number of iterationswill be smaller than {@code maxIterations * size / batchSize},
* where {@code size} is the number of points to cluster.
* Disabled if negative.
* @param batchSize Batch size for training iterations.
* @param initIterations Number of iterations allowed in order to find out the best initial centers.
* @param initBatchSize Batch size for initializing the clusters centers.
* A value of {@code 3 * batchSize} should be suitable in most cases.
* @param maxNoImprovementTimes Maximum number of iterations during which no improvement is occuring.
* A value of 10 is suitable in most cases.
* @param measure Distance measure.
* @param random Random generator.
* @param emptyStrategy Strategy for handling empty clusters that may appear during algorithm iterations.
*/
public MiniBatchKMeansClusterer(final int k, final int maxIterations, final int batchSize, final int initIterations,
final int initBatchSize, final int maxNoImprovementTimes,
@ -95,10 +83,10 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
/**
* Runs the MiniBatch K-means clustering algorithm.
*
* @param points the points to cluster
* @return a list of clusters containing the points
* @throws MathIllegalArgumentException if the data points are null or the number
* of clusters is larger than the number of data points
* @param points Points to cluster (cannot be {@code null}).
* @return the clusters.
* @throws MathIllegalArgumentException if number of points is
* smaller than the number of clusters.
*/
@Override
public List<CentroidCluster<T>> cluster(final Collection<T> points) throws MathIllegalArgumentException, ConvergenceException {
@ -138,9 +126,9 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
}
/**
* clear clustered points
* Helper method.
*
* @param clusters The clusters to clear
* @param clusters Clusters to clear.
*/
private void clearClustersPoints(final List<CentroidCluster<T>> clusters) {
for (CentroidCluster<T> cluster : clusters) {
@ -149,16 +137,16 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
}
/**
* Mini batch iteration step
* Mini batch iteration step.
*
* @param batchPoints The mini batch points.
* @param clusters The cluster centers.
* @return Square distance of all the batch points to the nearest center, and newly clusters.
* @param batchPoints Points selected for this batch.
* @param clusters Centers of the clusters.
* @return the squared distance of all the batch points to the nearest center.
*/
private Pair<Double, List<CentroidCluster<T>>> step(
final List<T> batchPoints,
final List<CentroidCluster<T>> clusters) {
//Add every mini batch points to their nearest cluster.
// Add every mini batch points to their nearest cluster.
for (final T point : batchPoints) {
addToNearestCentroidCluster(point, clusters);
}
@ -173,23 +161,24 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
}
/**
* Get a mini batch of points
* Extract a random subset of the given {@code points}.
*
* @param points all the points
* @param batchSize the mini batch size
* @return mini batch of all the points
* @param points All the points to cluster.
* @param batchSize Mini batch size.
* @return mini batch of all the points.
*/
private List<T> randomMiniBatch(final Collection<T> points, final int batchSize) {
private List<T> randomMiniBatch(final Collection<T> points,
final int batchSize) {
final ArrayList<T> list = new ArrayList<>(points);
ListSampler.shuffle(getRandomGenerator(), list);
return list.subList(0, batchSize);
}
/**
* Initial cluster centers with multiply iterations, find out the best.
* Initializes the clusters centers.
*
* @param points Points use to initial the cluster centers.
* @return Clusters with center
* @param points Points used to initialize the centers.
* @return clusters with their center initialized.
*/
private List<CentroidCluster<T>> initialCenters(final Collection<T> points) {
final List<T> validPoints = initBatchSize < points.size() ?
@ -213,12 +202,11 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
}
/**
* Add a point to the cluster which the closest center belong to
* and return the distance between point and the closest center.
* Adds a point to the cluster whose center is closest.
*
* @param point The point to add.
* @param clusters The clusters to add to.
* @return The distance between point and the closest center.
* @param point Point to add.
* @param clusters Clusters.
* @return the distance between point and the closest center.
*/
private double addToNearestCentroidCluster(final T point, final List<CentroidCluster<T>> clusters) {
double minDistance = Double.POSITIVE_INFINITY;
@ -237,11 +225,16 @@ public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusP
}
/**
* The Evaluator to evaluate whether the iteration should finish where square has no improvement for appointed times.
* Stopping criterion.
* The evaluator checks whether improvement occurred during the
* {@link #maxNoImprovementTimes allowed number of successive iterations}.
*/
class MiniBatchImprovementEvaluator {
/** Missing doc. */
private Double ewaInertia = null;
/** Missing doc. */
private double ewaInertiaMin = Double.POSITIVE_INFINITY;
/** Missing doc. */
private int noImprovementTimes = 0;
/**