MATH-1509: Implement the MiniBatchKMeansClusterer.

CT 2020-03-22 12:07:54 +08:00 committed by Gilles Sadowski
parent 22373aeb76
commit da455397c2
2 changed files with 457 additions and 0 deletions


@@ -0,0 +1,276 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.ml.clustering;
import org.apache.commons.math4.exception.ConvergenceException;
import org.apache.commons.math4.exception.MathIllegalArgumentException;
import org.apache.commons.math4.exception.NumberIsTooSmallException;
import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.math4.util.MathUtils;
import org.apache.commons.math4.util.Pair;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.ListSampler;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
/**
* A very fast clustering algorithm based on KMeans (similar to Python's sklearn.cluster.MiniBatchKMeans).
* It uses only a subset of the points to initialize the cluster centers, and a mini batch of the points in each iteration.
* It finishes within a few seconds when clustering millions of points, and its result differs only slightly from that of KMeans.
* See https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf
*
* @param <T> Type of the points to cluster
*/
public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusPlusClusterer<T> {
/**
* Batch size of the points used in each training iteration.
*/
private final int batchSize;
/**
* Number of iterations used to initialize the centers.
*/
private final int initIterations;
/**
* Batch size of the points used to initialize the centers, typically {@code 3 * k}.
*/
private final int initBatchSize;
/**
* Maximum number of consecutive iterations without improvement before the algorithm stops.
*/
private final int maxNoImprovementTimes;
/**
* Build a clusterer.
*
* @param k the number of clusters to split the data into
* @param maxIterations the maximum number of full passes over all the points; for the mini batch
* algorithm the actual number of iterations is at most {@code maxIterations * points.size() / batchSize}.
* If negative, no maximum is applied.
* @param batchSize the mini batch size used in the training iterations.
* @param initIterations the number of mini batch iterations used to find the best initial cluster centers.
* @param initBatchSize the mini batch size used to initialize the cluster centers;
* {@code 3 * batchSize} is suitable for most cases.
* @param maxNoImprovementTimes the maximum number of consecutive iterations allowed without improvement
* of the squared distance; 10 is suitable for most cases.
* @param measure the distance measure to use; EuclideanDistance is recommended.
* @param random random generator to use for choosing the initial centers.
* @param emptyStrategy strategy for handling empty clusters that may appear during the algorithm iterations.
*/
public MiniBatchKMeansClusterer(final int k, final int maxIterations, final int batchSize, final int initIterations,
final int initBatchSize, final int maxNoImprovementTimes,
final DistanceMeasure measure, final UniformRandomProvider random,
final EmptyClusterStrategy emptyStrategy) {
super(k, maxIterations, measure, random, emptyStrategy);
if (batchSize < 1) {
throw new NumberIsTooSmallException(batchSize, 1, true);
}
if (initIterations < 1) {
throw new NumberIsTooSmallException(initIterations, 1, true);
}
if (initBatchSize < 1) {
throw new NumberIsTooSmallException(initBatchSize, 1, true);
}
if (maxNoImprovementTimes < 1) {
throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
}
this.batchSize = batchSize;
this.initIterations = initIterations;
this.initBatchSize = initBatchSize;
this.maxNoImprovementTimes = maxNoImprovementTimes;
}
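// Illustrative construction example (values mirror those used in the unit test; suitable values
// depend on the data set, and EuclideanDistance / RandomSource are assumed to be imported):
//
// new MiniBatchKMeansClusterer<DoublePoint>(4, -1, 100, 3, 300, 10,
// new EuclideanDistance(), RandomSource.create(RandomSource.MT_64),
// EmptyClusterStrategy.LARGEST_VARIANCE);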
/**
* Runs the MiniBatch K-means clustering algorithm.
*
* @param points the points to cluster
* @return a list of clusters containing the points
* @throws MathIllegalArgumentException if the data points are null or the number
* of clusters is larger than the number of data points
*/
@Override
public List<CentroidCluster<T>> cluster(final Collection<T> points) throws MathIllegalArgumentException, ConvergenceException {
// sanity checks
MathUtils.checkNotNull(points);
// number of clusters has to be smaller or equal the number of data points
if (points.size() < getK()) {
throw new NumberIsTooSmallException(points.size(), getK(), false);
}
final int pointSize = points.size();
final int batchCount = pointSize / batchSize + ((pointSize % batchSize > 0) ? 1 : 0);
final int max = this.getMaxIterations() < 0 ? Integer.MAX_VALUE : (this.getMaxIterations() * batchCount);
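// For example (illustrative numbers): 10000 points with batchSize = 100 give batchCount = 100,
// so maxIterations = 5 would allow at most 5 * 100 = 500 mini batch iterations.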
List<CentroidCluster<T>> clusters = initialCenters(points);
// Execute the mini batch steps until the maximum number of iterations is reached or no further improvement can be made.
final MiniBatchImprovementEvaluator evaluator = new MiniBatchImprovementEvaluator();
for (int i = 0; i < max; i++) {
// Clear the points currently assigned to the clusters.
clearClustersPoints(clusters);
// Randomly sample a mini batch of points.
final List<T> batchPoints = randomMiniBatch(points, batchSize);
// Perform the mini batch training step.
final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
final double squareDistance = pair.getFirst();
clusters = pair.getSecond();
// Check whether the training can finish early.
if (evaluator.convergence(squareDistance, pointSize)) {
break;
}
}
// Finally, assign every point to its nearest cluster.
clearClustersPoints(clusters);
for (final T point : points) {
addToNearestCentroidCluster(point, clusters);
}
return clusters;
}
/**
* Remove all the points from the given clusters.
*
* @param clusters The clusters to clear
*/
private void clearClustersPoints(final List<CentroidCluster<T>> clusters) {
for (CentroidCluster<T> cluster : clusters) {
cluster.getPoints().clear();
}
}
/**
* Mini batch iteration step
*
* @param batchPoints The mini batch points.
* @param clusters The cluster centers.
* @return the sum of squared distances from the batch points to their nearest centers, and the updated clusters.
*/
private Pair<Double, List<CentroidCluster<T>>> step(
final List<T> batchPoints,
final List<CentroidCluster<T>> clusters) {
// Assign each mini batch point to its nearest cluster.
for (final T point : batchPoints) {
addToNearestCentroidCluster(point, clusters);
}
final List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
// Assign each mini batch point to its nearest (adjusted) cluster again and accumulate the squared distances.
double squareDistance = 0.0;
for (T point : batchPoints) {
final double d = addToNearestCentroidCluster(point, newClusters);
squareDistance += d * d;
}
return new Pair<>(squareDistance, newClusters);
}
/**
* Get a mini batch of points
*
* @param points all the points
* @param batchSize the mini batch size
* @return a random mini batch of the points
*/
private List<T> randomMiniBatch(final Collection<T> points, final int batchSize) {
final ArrayList<T> list = new ArrayList<>(points);
ListSampler.shuffle(getRandomGenerator(), list);
return list.subList(0, batchSize);
}
/**
* Initialize the cluster centers with multiple iterations, keeping the best result.
*
* @param points the points used to initialize the cluster centers.
* @return the clusters with the best initial centers.
*/
private List<CentroidCluster<T>> initialCenters(final Collection<T> points) {
final List<T> validPoints = initBatchSize < points.size() ?
randomMiniBatch(points, initBatchSize) : new ArrayList<>(points);
double nearestSquareDistance = Double.POSITIVE_INFINITY;
List<CentroidCluster<T>> bestCenters = null;
for (int i = 0; i < initIterations; i++) {
final List<T> initialPoints = (initBatchSize < points.size()) ?
randomMiniBatch(points, initBatchSize) : new ArrayList<>(points);
final List<CentroidCluster<T>> clusters = getCentroidInitializer().selectCentroids(initialPoints, getK());
final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
final double squareDistance = pair.getFirst();
final List<CentroidCluster<T>> newClusters = pair.getSecond();
// Keep the centers with the smallest total squared distance.
if (squareDistance < nearestSquareDistance) {
nearestSquareDistance = squareDistance;
bestCenters = newClusters;
}
}
return bestCenters;
}
/**
* Add a point to the cluster whose center is closest,
* and return the distance between the point and that center.
*
* @param point The point to add.
* @param clusters The clusters to add to.
* @return The distance between point and the closest center.
*/
private double addToNearestCentroidCluster(final T point, final List<CentroidCluster<T>> clusters) {
double minDistance = Double.POSITIVE_INFINITY;
CentroidCluster<T> closestCentroidCluster = null;
// Iterate over the clusters to find the one whose center is closest to the point.
for (CentroidCluster<T> centroidCluster : clusters) {
final double distance = distance(point, centroidCluster.getCenter());
if (distance < minDistance) {
minDistance = distance;
closestCentroidCluster = centroidCluster;
}
}
assert closestCentroidCluster != null;
closestCentroidCluster.addPoint(point);
return minDistance;
}
/**
* Evaluator that decides whether the iteration should stop once the squared distance has shown no improvement for the configured number of consecutive times.
*/
class MiniBatchImprovementEvaluator {
/** Exponentially weighted average (EWA) of the batch inertia. */
private Double ewaInertia = null;
/** Minimum EWA inertia observed so far. */
private double ewaInertiaMin = Double.POSITIVE_INFINITY;
/** Number of consecutive iterations without improvement of the EWA inertia. */
private int noImprovementTimes = 0;
/**
* Evaluate whether the iteration should stop because the squared distance has shown no improvement
* for the configured number of consecutive times.
*
* @param squareDistance the sum of squared distances from the mini batch points to their nearest centers.
* @param pointSize the number of data points.
* @return true if there has been no improvement for the configured number of times, false otherwise.
*/
public boolean convergence(final double squareDistance, final int pointSize) {
final double batchInertia = squareDistance / batchSize;
if (ewaInertia == null) {
ewaInertia = batchInertia;
} else {
// As in sklearn, the pointSize + 1 term presumably guards against division by zero,
// although dividing a Java double by zero would not throw anyway.
double alpha = batchSize * 2.0 / (pointSize + 1);
alpha = Math.min(alpha, 1.0);
ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha;
}
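// Worked example (illustrative numbers): with batchSize = 100 and pointSize = 10000,
// alpha = 2 * 100 / 10001, roughly 0.02, so each new batch inertia contributes about 2% to the
// exponentially weighted average while about 98% of the previous estimate is kept.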
// Improved
if (ewaInertia < ewaInertiaMin) {
noImprovementTimes = 0;
ewaInertiaMin = ewaInertia;
} else {
// No improvement
noImprovementTimes++;
}
// Converged: no improvement for the maximum allowed number of consecutive iterations.
return noImprovementTimes >= maxNoImprovementTimes;
}
}
}
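For reference, a minimal usage sketch of the new clusterer (illustrative only, not part of the
committed sources; the parameter values mirror those used in the test below, and the imports are
assumed to be the same ones already used in the two files of this commit):

    UniformRandomProvider rng = RandomSource.create(RandomSource.MT_64, 0);
    List<DoublePoint> points = new ArrayList<>();
    // ... fill "points" with the data to cluster, e.g. new DoublePoint(new double[] {x, y}) ...
    MiniBatchKMeansClusterer<DoublePoint> clusterer =
        new MiniBatchKMeansClusterer<>(4, -1, 100, 3, 300, 10,
            new EuclideanDistance(), rng,
            KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE);
    List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
    for (CentroidCluster<DoublePoint> cluster : clusters) {
        double[] center = cluster.getCenter().getPoint();
        int assigned = cluster.getPoints().size();
        // ... use the center coordinates and the points assigned to each cluster ...
    }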


@@ -0,0 +1,181 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.ml.clustering;
import org.apache.commons.math4.exception.NumberIsTooSmallException;
import org.apache.commons.math4.ml.clustering.evaluation.CalinskiHarabasz;
import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.math4.ml.distance.EuclideanDistance;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
public class MiniBatchKMeansClustererTest {
/**
* Assert that illegal parameters throw the proper exceptions.
*/
@Test
public void testConstructorParameterChecks() {
expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, -1, 3, 300, 10, null, null, null));
expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, 100, -2, 300, 10, null, null, null));
expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, 100, 3, -300, 10, null, null, null));
expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, 100, 3, 300, -10, null, null, null));
}
/**
* Expect that the given block throws a NumberIsTooSmallException.
* @param block the block to run.
*/
private void expectNumberIsTooSmallException(Runnable block) {
assertException(block, NumberIsTooSmallException.class);
}
/**
* Compare the result to KMeansPlusPlusClusterer
*/
@Test
public void testCompareToKMeans() {
// Generate points around 4 cluster centers.
int randomSeed = 0;
List<DoublePoint> data = generateCircles(randomSeed);
KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(4, -1, DEFAULT_MEASURE,
RandomSource.create(RandomSource.MT_64, randomSeed));
MiniBatchKMeansClusterer<DoublePoint> miniBatchKMeans = new MiniBatchKMeansClusterer<>(4, -1, 100, 3, 300, 10,
DEFAULT_MEASURE, RandomSource.create(RandomSource.MT_64, randomSeed), KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE);
// Compare KMeansPlusPlusClusterer and MiniBatchKMeansClusterer over 100 runs.
for (int i = 0; i < 100; i++) {
List<CentroidCluster<DoublePoint>> kMeansClusters = kMeans.cluster(data);
List<CentroidCluster<DoublePoint>> miniBatchKMeansClusters = miniBatchKMeans.cluster(data);
// Assert cluster result has proper clusters count.
Assert.assertEquals(4, kMeansClusters.size());
Assert.assertEquals(kMeansClusters.size(), miniBatchKMeansClusters.size());
int totalDiffCount = 0;
for (CentroidCluster<DoublePoint> kMeanCluster : kMeansClusters) {
// Find the most similar MiniBatchKMeans cluster and sum up the differences in point counts.
CentroidCluster<DoublePoint> miniBatchCluster = predict(miniBatchKMeansClusters, kMeanCluster.getCenter());
totalDiffCount += Math.abs(kMeanCluster.getPoints().size() - miniBatchCluster.getPoints().size());
}
// Ratio of points assigned to different clusters.
double diffPointsRatio = totalDiffCount * 1.0 / data.size();
// Compare the evaluation scores using the CalinskiHarabasz evaluator.
ClusterEvaluator clusterEvaluator = new CalinskiHarabasz();
double kMeansScore = clusterEvaluator.score(kMeansClusters);
double miniBatchKMeansScore = clusterEvaluator.score(miniBatchKMeansClusters);
double scoreDiffRatio = (kMeansScore - miniBatchKMeansScore) / kMeansScore;
// The MiniBatchKMeansClusterer score should differ from the KMeansPlusPlusClusterer score by less than 10%.
Assert.assertTrue(String.format("Different score ratio %f%%!, diff points ratio: %f%%", scoreDiffRatio * 100, diffPointsRatio * 100),
scoreDiffRatio < 0.1);
}
}
/**
* Generate points around 4 circles.
* @param randomSeed Random seed
* @return Generated points.
*/
private List<DoublePoint> generateCircles(int randomSeed) {
List<DoublePoint> data = new ArrayList<>();
Random random = new Random(randomSeed);
data.addAll(generateCircle(250, new double[]{-1.0, -1.0}, 1.0, random));
data.addAll(generateCircle(260, new double[]{0.0, 0.0}, 0.7, random));
data.addAll(generateCircle(270, new double[]{1.0, 1.0}, 0.7, random));
data.addAll(generateCircle(280, new double[]{2.0, 2.0}, 0.7, random));
return data;
}
/**
* Generate random points within a circle.
* @param count total points count.
* @param center circle center point.
* @param radius the radius around the center within which the points are generated.
* @param random the Random source.
* @return Generated points.
*/
List<DoublePoint> generateCircle(int count, double[] center, double radius, Random random) {
double x0 = center[0];
double y0 = center[1];
ArrayList<DoublePoint> list = new ArrayList<>(count);
for (int i = 0; i < count; i++) {
double ao = random.nextDouble() * 720 - 360;
double r = random.nextDouble() * radius * 2 - radius;
double x1 = x0 + r * Math.cos(ao * Math.PI / 180);
double y1 = y0 + r * Math.sin(ao * Math.PI / 180);
list.add(new DoublePoint(new double[]{x1, y1}));
}
return list;
}
/**
* Assert that the block throws the expected exception.
*
* @param block the code block to run.
* @param exceptionClass the expected exception class.
*/
public static void assertException(Runnable block, Class<? extends Throwable> exceptionClass) {
try {
block.run();
Assert.fail(String.format("Expects %s", exceptionClass.getSimpleName()));
} catch (Throwable e) {
if (!exceptionClass.isInstance(e)) {
throw e;
}
}
}
/**
* Use EuclideanDistance as default DistanceMeasure
*/
public static final DistanceMeasure DEFAULT_MEASURE = new EuclideanDistance();
/**
* Predict which cluster is best for the point
*
* @param clusters the clusters to predict into
* @param point the point to predict
* @param measure the distance measure
* @param <T> the type of the cluster points
* @return the cluster whose center is nearest to the point
*/
public static <T extends Clusterable> CentroidCluster<T> predict(List<CentroidCluster<T>> clusters, Clusterable point, DistanceMeasure measure) {
double minDistance = Double.POSITIVE_INFINITY;
CentroidCluster<T> nearestCluster = null;
for (CentroidCluster<T> cluster : clusters) {
double distance = measure.compute(point.getPoint(), cluster.getCenter().getPoint());
if (distance < minDistance) {
minDistance = distance;
nearestCluster = cluster;
}
}
return nearestCluster;
}
/**
* Predict which cluster is best for the point
*
* @param clusters the clusters to predict into
* @param point the point to predict
* @param <T> the type of the cluster points
* @return the cluster whose center is nearest to the point
*/
public static <T extends Clusterable> CentroidCluster<T> predict(List<CentroidCluster<T>> clusters, Clusterable point) {
return predict(clusters, point, DEFAULT_MEASURE);
}
}