Use "sample" functionality from "ListSampler".
Change should improve performance by * extracting the number of required items (instead of shuffling the whole list), * converting the "Collection" to a "List" once.
This commit is contained in:
parent
844ffbeeee
commit
ab33121f7b
|
@ -116,12 +116,13 @@ public class MiniBatchKMeansClusterer<T extends Clusterable>
|
|||
Integer.MAX_VALUE :
|
||||
getMaxIterations() * batchCount;
|
||||
|
||||
List<CentroidCluster<T>> clusters = initialCenters(points);
|
||||
final List<T> pointList = new ArrayList<>(points);
|
||||
List<CentroidCluster<T>> clusters = initialCenters(pointList);
|
||||
|
||||
final MiniBatchImprovementEvaluator evaluator = new MiniBatchImprovementEvaluator();
|
||||
for (int i = 0; i < max; i++) {
|
||||
clearClustersPoints(clusters);
|
||||
final List<T> batchPoints = randomMiniBatch(points, batchSize, getRandomGenerator());
|
||||
final List<T> batchPoints = ListSampler.sample(getRandomGenerator(), pointList, batchSize);
|
||||
// Training step.
|
||||
final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
|
||||
final double squareDistance = pair.getFirst();
|
||||
|
@ -176,38 +177,22 @@ public class MiniBatchKMeansClusterer<T extends Clusterable>
|
|||
return new Pair<>(squareDistance, newClusters);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract a random subset of the given {@code points}.
|
||||
*
|
||||
* @param points All the points to cluster.
|
||||
* @param n Mini batch size.
|
||||
* @param rng Random generator.
|
||||
* @return {@code n} points.
|
||||
*/
|
||||
private List<T> randomMiniBatch(final Collection<T> points,
|
||||
final int n,
|
||||
final UniformRandomProvider rng) {
|
||||
final ArrayList<T> list = new ArrayList<>(points);
|
||||
ListSampler.shuffle(rng, list);
|
||||
return list.subList(0, n);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the clusters centers.
|
||||
*
|
||||
* @param points Points used to initialize the centers.
|
||||
* @return clusters with their center initialized.
|
||||
*/
|
||||
private List<CentroidCluster<T>> initialCenters(final Collection<T> points) {
|
||||
private List<CentroidCluster<T>> initialCenters(final List<T> points) {
|
||||
final List<T> validPoints = initBatchSize < points.size() ?
|
||||
randomMiniBatch(points, initBatchSize, getRandomGenerator()) :
|
||||
ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
|
||||
new ArrayList<>(points);
|
||||
double nearestSquareDistance = Double.POSITIVE_INFINITY;
|
||||
List<CentroidCluster<T>> bestCenters = null;
|
||||
|
||||
for (int i = 0; i < initIterations; i++) {
|
||||
final List<T> initialPoints = (initBatchSize < points.size()) ?
|
||||
randomMiniBatch(points, initBatchSize, getRandomGenerator()) :
|
||||
ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
|
||||
new ArrayList<>(points);
|
||||
final List<CentroidCluster<T>> clusters = getCentroidInitializer().selectCentroids(initialPoints, getK());
|
||||
final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
|
||||
|
|
Loading…
Reference in New Issue