Use "sample" functionality from "ListSampler".

This change should improve performance by
 * extracting only the required number of items (instead of shuffling the whole list),
 * converting the "Collection" to a "List" only once (instead of on every sampling call).
This commit is contained in:
Gilles Sadowski 2020-03-22 14:46:58 +01:00
parent 844ffbeeee
commit ab33121f7b
1 changed file with 6 additions and 21 deletions

View File

@ -116,12 +116,13 @@ public class MiniBatchKMeansClusterer<T extends Clusterable>
Integer.MAX_VALUE :
getMaxIterations() * batchCount;
List<CentroidCluster<T>> clusters = initialCenters(points);
final List<T> pointList = new ArrayList<>(points);
List<CentroidCluster<T>> clusters = initialCenters(pointList);
final MiniBatchImprovementEvaluator evaluator = new MiniBatchImprovementEvaluator();
for (int i = 0; i < max; i++) {
clearClustersPoints(clusters);
final List<T> batchPoints = randomMiniBatch(points, batchSize, getRandomGenerator());
final List<T> batchPoints = ListSampler.sample(getRandomGenerator(), pointList, batchSize);
// Training step.
final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
final double squareDistance = pair.getFirst();
@ -176,38 +177,22 @@ public class MiniBatchKMeansClusterer<T extends Clusterable>
return new Pair<>(squareDistance, newClusters);
}
/**
* Extract a random subset of the given {@code points}.
* <p>
* On every call, the whole collection is copied into a new list and that
* copy is shuffled; the returned value is the
* {@link List#subList(int, int) sub-list view} of the first {@code n}
* elements of the shuffled copy.
* The shuffle is applied to the copy, not to {@code points}.
*
* @param points All the points to cluster.
* @param n Mini batch size.
* @param rng Random generator.
* @return {@code n} points.
* @throws IndexOutOfBoundsException if {@code n} is negative or larger than
* the number of {@code points} (raised by {@code subList}).
*/
private List<T> randomMiniBatch(final Collection<T> points,
final int n,
final UniformRandomProvider rng) {
// Work on a private copy: the shuffle below reorders "list", not "points".
final ArrayList<T> list = new ArrayList<>(points);
ListSampler.shuffle(rng, list);
// The first n elements of the shuffled copy form the mini batch.
return list.subList(0, n);
}
/**
* Initializes the clusters centers.
*
* @param points Points used to initialize the centers.
* @return clusters with their center initialized.
*/
private List<CentroidCluster<T>> initialCenters(final Collection<T> points) {
private List<CentroidCluster<T>> initialCenters(final List<T> points) {
final List<T> validPoints = initBatchSize < points.size() ?
randomMiniBatch(points, initBatchSize, getRandomGenerator()) :
ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
new ArrayList<>(points);
double nearestSquareDistance = Double.POSITIVE_INFINITY;
List<CentroidCluster<T>> bestCenters = null;
for (int i = 0; i < initIterations; i++) {
final List<T> initialPoints = (initBatchSize < points.size()) ?
randomMiniBatch(points, initBatchSize, getRandomGenerator()) :
ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
new ArrayList<>(points);
final List<CentroidCluster<T>> clusters = getCentroidInitializer().selectCentroids(initialPoints, getK());
final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);