MATH-1509: Implement the MiniBatchKMeansClusterer.
This commit is contained in:
@ -0,0 +1,276 @@
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
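* the License. You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.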
* See the License for the specific language governing permissions and
* limitations under the License.
import org.apache.commons.math4.exception.ConvergenceException;
import org.apache.commons.math4.exception.MathIllegalArgumentException;
import org.apache.commons.math4.exception.NumberIsTooSmallException;
import org.apache.commons.math4.util.MathUtils;
import org.apache.commons.math4.util.Pair;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.ListSampler;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
* A fast clustering algorithm based on k-means (modeled on Python's sklearn.cluster.MiniBatchKMeans).
* It uses a random subset of the points to initialize the cluster centers, and a random
* mini batch of points in each training iteration.
* It can finish within a few seconds when clustering millions of points, and its results
* differ only slightly from those of standard k-means.
* See D. Sculley, "Web-Scale K-Means Clustering" (WWW 2010).
* @param <T> Type of the points to cluster
public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusPlusClusterer<T> {
/** Number of points drawn for each mini-batch training step. */
private final int batchSize;
/** Number of candidate initializations tried when choosing the starting centers. */
private final int initIterations;
/** Number of points sampled for each center-initialization attempt. */
private final int initBatchSize;
/** Number of consecutive non-improving steps after which training stops early. */
private final int maxNoImprovementTimes;
/**
 * Creates a mini-batch k-means clusterer.
 *
 * @param k the number of clusters to split the data into
 * @param maxIterations the maximum number of passes over all the points; the actual number
 *     of mini-batch steps is at most {@code maxIterations * ceil(points.size() / batchSize)}.
 *     If negative, no maximum will be used.
 * @param batchSize the mini batch size for training iterations.
 * @param initIterations the number of attempts made to find the best initial cluster centers.
 * @param initBatchSize the mini batch size used to initialize the cluster centers,
 *     {@code batchSize * 3} is suitable for most cases.
 * @param maxNoImprovementTimes the maximum number of steps without improvement of the square
 *     distance before stopping, 10 is suitable for most cases.
 * @param measure the distance measure to use, EuclideanDistance is recommended.
 * @param random random generator to use for choosing initial centers
 * @param emptyStrategy strategy to use for handling empty clusters that
 *     may appear during algorithm iterations
 * @throws NumberIsTooSmallException if {@code batchSize}, {@code initIterations},
 *     {@code initBatchSize} or {@code maxNoImprovementTimes} is less than 1
 */
public MiniBatchKMeansClusterer(final int k, final int maxIterations, final int batchSize, final int initIterations,
                                final int initBatchSize, final int maxNoImprovementTimes,
                                final DistanceMeasure measure, final UniformRandomProvider random,
                                final EmptyClusterStrategy emptyStrategy) {
    super(k, maxIterations, measure, random, emptyStrategy);

    // Validate in declaration order so the first offending argument is the one reported.
    if (batchSize < 1) {
        throw new NumberIsTooSmallException(batchSize, 1, true);
    }
    if (initIterations < 1) {
        throw new NumberIsTooSmallException(initIterations, 1, true);
    }
    if (initBatchSize < 1) {
        throw new NumberIsTooSmallException(initBatchSize, 1, true);
    }
    if (maxNoImprovementTimes < 1) {
        throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
    }

    this.batchSize = batchSize;
    this.initIterations = initIterations;
    this.initBatchSize = initBatchSize;
    this.maxNoImprovementTimes = maxNoImprovementTimes;
}
* Runs the MiniBatch K-means clustering algorithm.
* @param points the points to cluster
* @return a list of clusters containing the points
* @throws MathIllegalArgumentException if the data points are null or the number
* of clusters is larger than the number of data points
public List<CentroidCluster<T>> cluster(final Collection<T> points) throws MathIllegalArgumentException, ConvergenceException {
// sanity checks
// number of clusters has to be smaller or equal the number of data points
if (points.size() < getK()) {
throw new NumberIsTooSmallException(points.size(), getK(), false);
final int pointSize = points.size();
final int batchCount = pointSize / batchSize + ((pointSize % batchSize > 0) ? 1 : 0);
final int max = this.getMaxIterations() < 0 ? Integer.MAX_VALUE : (this.getMaxIterations() * batchCount);
List<CentroidCluster<T>> clusters = initialCenters(points);
// Loop execute the mini batch steps until reach the max loop times, or cannot improvement anymore.
final MiniBatchImprovementEvaluator evaluator = new MiniBatchImprovementEvaluator();
for (int i = 0; i < max; i++) {
//Clear points in clusters
//Random sampling a mini batch of points.
final List<T> batchPoints = randomMiniBatch(points, batchSize);
// Processing the mini batch training step
final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
final double squareDistance = pair.getFirst();
clusters = pair.getSecond();
// Evaluate the training can finished early.
if (evaluator.convergence(squareDistance, pointSize)) break;
//Add every mini batch points to their nearest cluster.
for (final T point : points) {
addToNearestCentroidCluster(point, clusters);
return clusters;
* clear clustered points
* @param clusters The clusters to clear
private void clearClustersPoints(final List<CentroidCluster<T>> clusters) {
for (CentroidCluster<T> cluster : clusters) {
* Mini batch iteration step
* @param batchPoints The mini batch points.
* @param clusters The cluster centers.
* @return Square distance of all the batch points to the nearest center, and newly clusters.
private Pair<Double, List<CentroidCluster<T>>> step(
final List<T> batchPoints,
final List<CentroidCluster<T>> clusters) {
//Add every mini batch points to their nearest cluster.
for (final T point : batchPoints) {
addToNearestCentroidCluster(point, clusters);
final List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
// Add every mini batch points to their nearest cluster again.
double squareDistance = 0.0;
for (T point : batchPoints) {
final double d = addToNearestCentroidCluster(point, newClusters);
squareDistance += d * d;
return new Pair<>(squareDistance, newClusters);
* Get a mini batch of points
* @param points all the points
* @param batchSize the mini batch size
* @return mini batch of all the points
private List<T> randomMiniBatch(final Collection<T> points, final int batchSize) {
final ArrayList<T> list = new ArrayList<>(points);
ListSampler.shuffle(getRandomGenerator(), list);
return list.subList(0, batchSize);
* Initial cluster centers with multiply iterations, find out the best.
* @param points Points use to initial the cluster centers.
* @return Clusters with center
private List<CentroidCluster<T>> initialCenters(final Collection<T> points) {
final List<T> validPoints = initBatchSize < points.size() ?
randomMiniBatch(points, initBatchSize) : new ArrayList<>(points);
double nearestSquareDistance = Double.POSITIVE_INFINITY;
List<CentroidCluster<T>> bestCenters = null;
for (int i = 0; i < initIterations; i++) {
final List<T> initialPoints = (initBatchSize < points.size()) ?
randomMiniBatch(points, initBatchSize) : new ArrayList<>(points);
final List<CentroidCluster<T>> clusters = getCentroidInitializer().selectCentroids(initialPoints, getK());
final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
final double squareDistance = pair.getFirst();
final List<CentroidCluster<T>> newClusters = pair.getSecond();
//Find out a best centers that has the nearest total square distance.
if (squareDistance < nearestSquareDistance) {
nearestSquareDistance = squareDistance;
bestCenters = newClusters;
return bestCenters;
* Add a point to the cluster which the closest center belong to
* and return the distance between point and the closest center.
* @param point The point to add.
* @param clusters The clusters to add to.
* @return The distance between point and the closest center.
private double addToNearestCentroidCluster(final T point, final List<CentroidCluster<T>> clusters) {
double minDistance = Double.POSITIVE_INFINITY;
CentroidCluster<T> closestCentroidCluster = null;
// Iterate clusters and find out closest cluster to the point
for (CentroidCluster<T> centroidCluster : clusters) {
final double distance = distance(point, centroidCluster.getCenter());
if (distance < minDistance) {
minDistance = distance;
closestCentroidCluster = centroidCluster;
assert closestCentroidCluster != null;
return minDistance;
* The Evaluator to evaluate whether the iteration should finish where square has no improvement for appointed times.
class MiniBatchImprovementEvaluator {
private Double ewaInertia = null;
private double ewaInertiaMin = Double.POSITIVE_INFINITY;
private int noImprovementTimes = 0;
* Evaluate whether the iteration should finish where square has no improvement for appointed times
* @param squareDistance the total square distance of the mini batch points to their nearest center.
* @param pointSize size of the the data points.
* @return true if no improvement for appointed times, otherwise false
public boolean convergence(final double squareDistance, final int pointSize) {
final double batchInertia = squareDistance / batchSize;
if (ewaInertia == null) {
ewaInertia = batchInertia;
} else {
// Refer to sklearn, pointSize+1 maybe intent to avoid the div/0 error,
// but java double does not have a div/0 error
double alpha = batchSize * 2.0 / (pointSize + 1);
alpha = Math.min(alpha, 1.0);
ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha;
// Improved
if (ewaInertia < ewaInertiaMin) {
noImprovementTimes = 0;
ewaInertiaMin = ewaInertia;
} else {
// No improvement
// Has no improvement continuous for many times
return noImprovementTimes >= maxNoImprovementTimes;
@ -0,0 +1,181 @@
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
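* the License. You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.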
* See the License for the specific language governing permissions and
* limitations under the License.
import org.apache.commons.math4.exception.NumberIsTooSmallException;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
public class MiniBatchKMeansClustererTest {
* Assert the illegal parameter throws proper Exceptions.
public void testConstructorParameterChecks() {
expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, -1, 3, 300, 10, null, null, null));
expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, 100, -2, 300, 10, null, null, null));
expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, 100, 3, -300, 10, null, null, null));
expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, 100, 3, 300, -10, null, null, null));
* Expects block throws NumberIsTooSmallException.
* @param block the block need to run.
private void expectNumberIsTooSmallException(Runnable block) {
assertException(block, NumberIsTooSmallException.class);
* Compare the result to KMeansPlusPlusClusterer
public void testCompareToKMeans() {
//Generate 4 cluster
int randomSeed = 0;
List<DoublePoint> data = generateCircles(randomSeed);
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(4, -1, DEFAULT_MEASURE,
RandomSource.create(RandomSource.MT_64, randomSeed));
MiniBatchKMeansClusterer<DoublePoint> miniBatchKMeans = new MiniBatchKMeansClusterer<>(4, -1, 100, 3, 300, 10,
DEFAULT_MEASURE, RandomSource.create(RandomSource.MT_64, randomSeed), KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE);
// Test 100 times between KMeansPlusPlusClusterer and MiniBatchKMeansClusterer
for (int i = 0; i < 100; i++) {
List<CentroidCluster<DoublePoint>> kMeansClusters = kMeans.cluster(data);
List<CentroidCluster<DoublePoint>> miniBatchKMeansClusters = miniBatchKMeans.cluster(data);
// Assert cluster result has proper clusters count.
Assert.assertEquals(4, kMeansClusters.size());
Assert.assertEquals(kMeansClusters.size(), miniBatchKMeansClusters.size());
int totalDiffCount = 0;
for (CentroidCluster<DoublePoint> kMeanCluster : kMeansClusters) {
// Find out most similar cluster between two clusters, and summary the points count variances.
CentroidCluster<DoublePoint> miniBatchCluster = predict(miniBatchKMeansClusters, kMeanCluster.getCenter());
totalDiffCount += Math.abs(kMeanCluster.getPoints().size() - miniBatchCluster.getPoints().size());
// Statistic points different ratio.
double diffPointsRatio = totalDiffCount * 1.0 / data.size();
// Evaluator score different ratio by "CalinskiHarabasz" algorithm.
ClusterEvaluator clusterEvaluator = new CalinskiHarabasz();
double kMeansScore = clusterEvaluator.score(kMeansClusters);
double miniBatchKMeansScore = clusterEvaluator.score(miniBatchKMeansClusters);
double scoreDiffRatio = (kMeansScore - miniBatchKMeansScore) /
// MiniBatchKMeansClusterer has few score differences between KMeansClusterer(less then 10%).
Assert.assertTrue(String.format("Different score ratio %f%%!, diff points ratio: %f%%", scoreDiffRatio * 100, diffPointsRatio * 100),
scoreDiffRatio < 0.1);
* Generate points around 4 circles.
* @param randomSeed Random seed
* @return Generated points.
private List<DoublePoint> generateCircles(int randomSeed) {
List<DoublePoint> data = new ArrayList<>();
Random random = new Random(randomSeed);
data.addAll(generateCircle(250, new double[]{-1.0, -1.0}, 1.0, random));
data.addAll(generateCircle(260, new double[]{0.0, 0.0}, 0.7, random));
data.addAll(generateCircle(270, new double[]{1.0, 1.0}, 0.7, random));
data.addAll(generateCircle(280, new double[]{2.0, 2.0}, 0.7, random));
return data;
* Generate points as circles.
* @param count total points count.
* @param center circle center point.
* @param radius the circle radius points around.
* @param random the Random source.
* @return Generated points.
List<DoublePoint> generateCircle(int count, double[] center, double radius, Random random) {
double x0 = center[0];
double y0 = center[1];
ArrayList<DoublePoint> list = new ArrayList<>(count);
for (int i = 0; i < count; i++) {
double ao = random.nextDouble() * 720 - 360;
double r = random.nextDouble() * radius * 2 - radius;
double x1 = x0 + r * Math.cos(ao * Math.PI / 180);
double y1 = y0 + r * Math.sin(ao * Math.PI / 180);
list.add(new DoublePoint(new double[]{x1, y1}));
return list;
* Assert there should be a exception.
* @param block The code block need to assert.
* @param exceptionClass A exception class.
public static void assertException(Runnable block, Class<? extends Throwable> exceptionClass) {
try {
|"Expects %s", exceptionClass.getSimpleName()));
} catch (Throwable e) {
if (!exceptionClass.isInstance(e)) throw e;
/** Distance measure used by default in these tests: an EuclideanDistance instance. */
public static final DistanceMeasure DEFAULT_MEASURE = new EuclideanDistance();
* Predict which cluster is best for the point
* @param clusters cluster to predict into
* @param point point to predict
* @param measure distance measurer
* @param <T> type of cluster point
* @return the cluster which has nearest center to the point
public static <T extends Clusterable> CentroidCluster<T> predict(List<CentroidCluster<T>> clusters, Clusterable point, DistanceMeasure measure) {
double minDistance = Double.POSITIVE_INFINITY;
CentroidCluster<T> nearestCluster = null;
for (CentroidCluster<T> cluster : clusters) {
double distance = measure.compute(point.getPoint(), cluster.getCenter().getPoint());
if (distance < minDistance) {
minDistance = distance;
nearestCluster = cluster;
return nearestCluster;
* Predict which cluster is best for the point
* @param clusters cluster to predict into
* @param point point to predict
* @param <T> type of cluster point
* @return the cluster which has nearest center to the point
public static <T extends Clusterable> CentroidCluster<T> predict(List<CentroidCluster<T>> clusters, Clusterable point) {
return predict(clusters, point, DEFAULT_MEASURE);
Reference in New Issue