MATH-1524: Remove package initialization and reuse chooseInitialCenters as package-private
This commit is contained in:
parent
7ae8c7ac46
commit
f2eebe68d8
|
@ -17,21 +17,20 @@
|
||||||
|
|
||||||
package org.apache.commons.math4.ml.clustering;
|
package org.apache.commons.math4.ml.clustering;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import org.apache.commons.math4.exception.ConvergenceException;
|
import org.apache.commons.math4.exception.ConvergenceException;
|
||||||
import org.apache.commons.math4.exception.NumberIsTooSmallException;
|
import org.apache.commons.math4.exception.NumberIsTooSmallException;
|
||||||
import org.apache.commons.math4.exception.util.LocalizedFormats;
|
import org.apache.commons.math4.exception.util.LocalizedFormats;
|
||||||
import org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer;
|
|
||||||
import org.apache.commons.math4.ml.clustering.initialization.KMeansPlusPlusCentroidInitializer;
|
|
||||||
import org.apache.commons.math4.ml.distance.DistanceMeasure;
|
import org.apache.commons.math4.ml.distance.DistanceMeasure;
|
||||||
import org.apache.commons.math4.ml.distance.EuclideanDistance;
|
import org.apache.commons.math4.ml.distance.EuclideanDistance;
|
||||||
import org.apache.commons.math4.stat.descriptive.moment.Variance;
|
import org.apache.commons.math4.stat.descriptive.moment.Variance;
|
||||||
import org.apache.commons.math4.util.MathUtils;
|
import org.apache.commons.math4.util.MathUtils;
|
||||||
import org.apache.commons.rng.simple.RandomSource;
|
|
||||||
import org.apache.commons.rng.UniformRandomProvider;
|
import org.apache.commons.rng.UniformRandomProvider;
|
||||||
|
import org.apache.commons.rng.simple.RandomSource;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
|
* Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
|
||||||
|
@ -70,9 +69,6 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
|
||||||
/** Selected strategy for empty clusters. */
|
/** Selected strategy for empty clusters. */
|
||||||
private final EmptyClusterStrategy emptyStrategy;
|
private final EmptyClusterStrategy emptyStrategy;
|
||||||
|
|
||||||
/** Clusters centroids initializer. */
|
|
||||||
private final CentroidInitializer centroidInitializer;
|
|
||||||
|
|
||||||
/** Build a clusterer.
|
/** Build a clusterer.
|
||||||
* <p>
|
* <p>
|
||||||
* The default strategy for handling empty clusters that may appear during
|
* The default strategy for handling empty clusters that may appear during
|
||||||
|
@ -151,8 +147,6 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
|
||||||
this.maxIterations = maxIterations;
|
this.maxIterations = maxIterations;
|
||||||
this.random = random;
|
this.random = random;
|
||||||
this.emptyStrategy = emptyStrategy;
|
this.emptyStrategy = emptyStrategy;
|
||||||
// Use K-means++ to choose the initial centers.
|
|
||||||
this.centroidInitializer = new KMeansPlusPlusCentroidInitializer(measure, random);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -193,7 +187,7 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
// create the initial clusters
|
// create the initial clusters
|
||||||
List<CentroidCluster<T>> clusters = centroidInitializer.selectCentroids(points, k);
|
List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
|
||||||
|
|
||||||
// create an array containing the latest assignment of a point to a cluster
|
// create an array containing the latest assignment of a point to a cluster
|
||||||
// no need to initialize the array, as it will be filled with the first assignment
|
// no need to initialize the array, as it will be filled with the first assignment
|
||||||
|
@ -231,13 +225,6 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
|
||||||
return emptyStrategy;
|
return emptyStrategy;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @return the CentroidInitializer
|
|
||||||
*/
|
|
||||||
CentroidInitializer getCentroidInitializer() {
|
|
||||||
return centroidInitializer;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Adjust the clusters's centers with means of points
|
* Adjust the clusters's centers with means of points
|
||||||
* @param clusters the origin clusters
|
* @param clusters the origin clusters
|
||||||
|
@ -296,6 +283,131 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
|
||||||
return assignedDifferently;
|
return assignedDifferently;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use K-means++ to choose the initial centers.
|
||||||
|
*
|
||||||
|
* @param points the points to choose the initial centers from
|
||||||
|
* @return the initial centers
|
||||||
|
*/
|
||||||
|
List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
|
||||||
|
|
||||||
|
// Convert to list for indexed access. Make it unmodifiable, since removal of items
|
||||||
|
// would screw up the logic of this method.
|
||||||
|
final List<T> pointList = Collections.unmodifiableList(new ArrayList<> (points));
|
||||||
|
|
||||||
|
// The number of points in the list.
|
||||||
|
final int numPoints = pointList.size();
|
||||||
|
|
||||||
|
// Set the corresponding element in this array to indicate when
|
||||||
|
// elements of pointList are no longer available.
|
||||||
|
final boolean[] taken = new boolean[numPoints];
|
||||||
|
|
||||||
|
// The resulting list of initial centers.
|
||||||
|
final List<CentroidCluster<T>> resultSet = new ArrayList<>();
|
||||||
|
|
||||||
|
// Choose one center uniformly at random from among the data points.
|
||||||
|
final int firstPointIndex = random.nextInt(numPoints);
|
||||||
|
|
||||||
|
final T firstPoint = pointList.get(firstPointIndex);
|
||||||
|
|
||||||
|
resultSet.add(new CentroidCluster<T>(firstPoint));
|
||||||
|
|
||||||
|
// Must mark it as taken
|
||||||
|
taken[firstPointIndex] = true;
|
||||||
|
|
||||||
|
// To keep track of the minimum distance squared of elements of
|
||||||
|
// pointList to elements of resultSet.
|
||||||
|
final double[] minDistSquared = new double[numPoints];
|
||||||
|
|
||||||
|
// Initialize the elements. Since the only point in resultSet is firstPoint,
|
||||||
|
// this is very easy.
|
||||||
|
for (int i = 0; i < numPoints; i++) {
|
||||||
|
if (i != firstPointIndex) { // That point isn't considered
|
||||||
|
double d = distance(firstPoint, pointList.get(i));
|
||||||
|
minDistSquared[i] = d*d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while (resultSet.size() < k) {
|
||||||
|
|
||||||
|
// Sum up the squared distances for the points in pointList not
|
||||||
|
// already taken.
|
||||||
|
double distSqSum = 0.0;
|
||||||
|
|
||||||
|
for (int i = 0; i < numPoints; i++) {
|
||||||
|
if (!taken[i]) {
|
||||||
|
distSqSum += minDistSquared[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add one new data point as a center. Each point x is chosen with
|
||||||
|
// probability proportional to D(x)2
|
||||||
|
final double r = random.nextDouble() * distSqSum;
|
||||||
|
|
||||||
|
// The index of the next point to be added to the resultSet.
|
||||||
|
int nextPointIndex = -1;
|
||||||
|
|
||||||
|
// Sum through the squared min distances again, stopping when
|
||||||
|
// sum >= r.
|
||||||
|
double sum = 0.0;
|
||||||
|
for (int i = 0; i < numPoints; i++) {
|
||||||
|
if (!taken[i]) {
|
||||||
|
sum += minDistSquared[i];
|
||||||
|
if (sum >= r) {
|
||||||
|
nextPointIndex = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's not set to >= 0, the point wasn't found in the previous
|
||||||
|
// for loop, probably because distances are extremely small. Just pick
|
||||||
|
// the last available point.
|
||||||
|
if (nextPointIndex == -1) {
|
||||||
|
for (int i = numPoints - 1; i >= 0; i--) {
|
||||||
|
if (!taken[i]) {
|
||||||
|
nextPointIndex = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We found one.
|
||||||
|
if (nextPointIndex >= 0) {
|
||||||
|
|
||||||
|
final T p = pointList.get(nextPointIndex);
|
||||||
|
|
||||||
|
resultSet.add(new CentroidCluster<T> (p));
|
||||||
|
|
||||||
|
// Mark it as taken.
|
||||||
|
taken[nextPointIndex] = true;
|
||||||
|
|
||||||
|
if (resultSet.size() < k) {
|
||||||
|
// Now update elements of minDistSquared. We only have to compute
|
||||||
|
// the distance to the new center to do this.
|
||||||
|
for (int j = 0; j < numPoints; j++) {
|
||||||
|
// Only have to worry about the points still not taken.
|
||||||
|
if (!taken[j]) {
|
||||||
|
double d = distance(p, pointList.get(j));
|
||||||
|
double d2 = d * d;
|
||||||
|
if (d2 < minDistSquared[j]) {
|
||||||
|
minDistSquared[j] = d2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// None found --
|
||||||
|
// Break from the while loop to prevent
|
||||||
|
// an infinite loop.
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resultSet;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get a random point from the {@link Cluster} with the largest distance variance.
|
* Get a random point from the {@link Cluster} with the largest distance variance.
|
||||||
*
|
*
|
||||||
|
|
|
@ -195,7 +195,7 @@ public class MiniBatchKMeansClusterer<T extends Clusterable>
|
||||||
final List<T> initialPoints = (initBatchSize < points.size()) ?
|
final List<T> initialPoints = (initBatchSize < points.size()) ?
|
||||||
ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
|
ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
|
||||||
new ArrayList<>(points);
|
new ArrayList<>(points);
|
||||||
final List<CentroidCluster<T>> clusters = getCentroidInitializer().selectCentroids(initialPoints, getK());
|
final List<CentroidCluster<T>> clusters = chooseInitialCenters(initialPoints);
|
||||||
final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
|
final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
|
||||||
final double squareDistance = pair.getFirst();
|
final double squareDistance = pair.getFirst();
|
||||||
final List<CentroidCluster<T>> newClusters = pair.getSecond();
|
final List<CentroidCluster<T>> newClusters = pair.getSecond();
|
||||||
|
|
|
@ -1,40 +0,0 @@
|
||||||
/*
|
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
||||||
* contributor license agreements. See the NOTICE file distributed with
|
|
||||||
* this work for additional information regarding copyright ownership.
|
|
||||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
||||||
* (the "License"); you may not use this file except in compliance with
|
|
||||||
* the License. You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.apache.commons.math4.ml.clustering.initialization;
|
|
||||||
|
|
||||||
import org.apache.commons.math4.ml.clustering.CentroidCluster;
|
|
||||||
import org.apache.commons.math4.ml.clustering.Clusterable;
|
|
||||||
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Interface abstract the algorithm for clusterer to choose the initial centers.
|
|
||||||
*/
|
|
||||||
public interface CentroidInitializer {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Choose the initial centers.
|
|
||||||
*
|
|
||||||
* @param <T> Type of points to cluster.
|
|
||||||
* @param points the points to choose the initial centers from
|
|
||||||
* @param k The number of clusters
|
|
||||||
* @return the initial centers
|
|
||||||
*/
|
|
||||||
<T extends Clusterable> List<CentroidCluster<T>> selectCentroids(final Collection<T> points, final int k);
|
|
||||||
}
|
|
|
@ -1,186 +0,0 @@
|
||||||
/*
|
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
||||||
* contributor license agreements. See the NOTICE file distributed with
|
|
||||||
* this work for additional information regarding copyright ownership.
|
|
||||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
||||||
* (the "License"); you may not use this file except in compliance with
|
|
||||||
* the License. You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.apache.commons.math4.ml.clustering.initialization;
|
|
||||||
|
|
||||||
import org.apache.commons.math4.ml.clustering.CentroidCluster;
|
|
||||||
import org.apache.commons.math4.ml.clustering.Clusterable;
|
|
||||||
import org.apache.commons.math4.ml.distance.DistanceMeasure;
|
|
||||||
import org.apache.commons.rng.UniformRandomProvider;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Use K-means++ to choose the initial centers.
|
|
||||||
*
|
|
||||||
* @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
|
|
||||||
*/
|
|
||||||
public class KMeansPlusPlusCentroidInitializer implements CentroidInitializer {
|
|
||||||
private final DistanceMeasure measure;
|
|
||||||
private final UniformRandomProvider random;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build a K-means++ CentroidInitializer
|
|
||||||
* @param measure the distance measure to use
|
|
||||||
* @param random the random to use.
|
|
||||||
*/
|
|
||||||
public KMeansPlusPlusCentroidInitializer(final DistanceMeasure measure, final UniformRandomProvider random) {
|
|
||||||
this.measure = measure;
|
|
||||||
this.random = random;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Use K-means++ to choose the initial centers.
|
|
||||||
*
|
|
||||||
* @param points the points to choose the initial centers from
|
|
||||||
* @param k The number of clusters
|
|
||||||
* @return the initial centers
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public <T extends Clusterable> List<CentroidCluster<T>> selectCentroids(final Collection<T> points, final int k) {
|
|
||||||
// Convert to list for indexed access. Make it unmodifiable, since removal of items
|
|
||||||
// would screw up the logic of this method.
|
|
||||||
final List<T> pointList = Collections.unmodifiableList(new ArrayList<>(points));
|
|
||||||
|
|
||||||
// The number of points in the list.
|
|
||||||
final int numPoints = pointList.size();
|
|
||||||
|
|
||||||
// Set the corresponding element in this array to indicate when
|
|
||||||
// elements of pointList are no longer available.
|
|
||||||
final boolean[] taken = new boolean[numPoints];
|
|
||||||
|
|
||||||
// The resulting list of initial centers.
|
|
||||||
final List<CentroidCluster<T>> resultSet = new ArrayList<>();
|
|
||||||
|
|
||||||
// Choose one center uniformly at random from among the data points.
|
|
||||||
final int firstPointIndex = random.nextInt(numPoints);
|
|
||||||
|
|
||||||
final T firstPoint = pointList.get(firstPointIndex);
|
|
||||||
|
|
||||||
resultSet.add(new CentroidCluster<>(firstPoint));
|
|
||||||
|
|
||||||
// Must mark it as taken
|
|
||||||
taken[firstPointIndex] = true;
|
|
||||||
|
|
||||||
// To keep track of the minimum distance squared of elements of
|
|
||||||
// pointList to elements of resultSet.
|
|
||||||
final double[] minDistSquared = new double[numPoints];
|
|
||||||
|
|
||||||
// Initialize the elements. Since the only point in resultSet is firstPoint,
|
|
||||||
// this is very easy.
|
|
||||||
for (int i = 0; i < numPoints; i++) {
|
|
||||||
if (i != firstPointIndex) { // That point isn't considered
|
|
||||||
double d = distance(firstPoint, pointList.get(i));
|
|
||||||
minDistSquared[i] = d * d;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
while (resultSet.size() < k) {
|
|
||||||
|
|
||||||
// Sum up the squared distances for the points in pointList not
|
|
||||||
// already taken.
|
|
||||||
double distSqSum = 0.0;
|
|
||||||
|
|
||||||
for (int i = 0; i < numPoints; i++) {
|
|
||||||
if (!taken[i]) {
|
|
||||||
distSqSum += minDistSquared[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add one new data point as a center. Each point x is chosen with
|
|
||||||
// probability proportional to D(x)2
|
|
||||||
final double r = random.nextDouble() * distSqSum;
|
|
||||||
|
|
||||||
// The index of the next point to be added to the resultSet.
|
|
||||||
int nextPointIndex = -1;
|
|
||||||
|
|
||||||
// Sum through the squared min distances again, stopping when
|
|
||||||
// sum >= r.
|
|
||||||
double sum = 0.0;
|
|
||||||
for (int i = 0; i < numPoints; i++) {
|
|
||||||
if (!taken[i]) {
|
|
||||||
sum += minDistSquared[i];
|
|
||||||
if (sum >= r) {
|
|
||||||
nextPointIndex = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If it's not set to >= 0, the point wasn't found in the previous
|
|
||||||
// for loop, probably because distances are extremely small. Just pick
|
|
||||||
// the last available point.
|
|
||||||
if (nextPointIndex == -1) {
|
|
||||||
for (int i = numPoints - 1; i >= 0; i--) {
|
|
||||||
if (!taken[i]) {
|
|
||||||
nextPointIndex = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We found one.
|
|
||||||
if (nextPointIndex >= 0) {
|
|
||||||
|
|
||||||
final T p = pointList.get(nextPointIndex);
|
|
||||||
|
|
||||||
resultSet.add(new CentroidCluster<>(p));
|
|
||||||
|
|
||||||
// Mark it as taken.
|
|
||||||
taken[nextPointIndex] = true;
|
|
||||||
|
|
||||||
if (resultSet.size() < k) {
|
|
||||||
// Now update elements of minDistSquared. We only have to compute
|
|
||||||
// the distance to the new center to do this.
|
|
||||||
for (int j = 0; j < numPoints; j++) {
|
|
||||||
// Only have to worry about the points still not taken.
|
|
||||||
if (!taken[j]) {
|
|
||||||
double d = distance(p, pointList.get(j));
|
|
||||||
double d2 = d * d;
|
|
||||||
if (d2 < minDistSquared[j]) {
|
|
||||||
minDistSquared[j] = d2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
// None found --
|
|
||||||
// Break from the while loop to prevent
|
|
||||||
// an infinite loop.
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return resultSet;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Calculates the distance between two {@link Clusterable} instances
|
|
||||||
* with the configured {@link DistanceMeasure}.
|
|
||||||
*
|
|
||||||
* @param p1 the first clusterable
|
|
||||||
* @param p2 the second clusterable
|
|
||||||
* @return the distance between the two clusterables
|
|
||||||
*/
|
|
||||||
protected double distance(final Clusterable p1, final Clusterable p2) {
|
|
||||||
return measure.compute(p1.getPoint(), p2.getPoint());
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,65 +0,0 @@
|
||||||
/*
|
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
||||||
* contributor license agreements. See the NOTICE file distributed with
|
|
||||||
* this work for additional information regarding copyright ownership.
|
|
||||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
||||||
* (the "License"); you may not use this file except in compliance with
|
|
||||||
* the License. You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.apache.commons.math4.ml.clustering.initialization;
|
|
||||||
|
|
||||||
import org.apache.commons.math4.ml.clustering.CentroidCluster;
|
|
||||||
import org.apache.commons.math4.ml.clustering.Clusterable;
|
|
||||||
import org.apache.commons.rng.UniformRandomProvider;
|
|
||||||
import org.apache.commons.rng.sampling.ListSampler;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Random choose the initial centers.
|
|
||||||
*/
|
|
||||||
public class RandomCentroidInitializer implements CentroidInitializer {
|
|
||||||
private final UniformRandomProvider random;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build a random RandomCentroidInitializer
|
|
||||||
*
|
|
||||||
* @param random the random to use.
|
|
||||||
*/
|
|
||||||
public RandomCentroidInitializer(final UniformRandomProvider random) {
|
|
||||||
this.random = random;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Random choose the initial centers.
|
|
||||||
*
|
|
||||||
* @param points the points to choose the initial centers from
|
|
||||||
* @param k The number of clusters
|
|
||||||
* @return the initial centers
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public <T extends Clusterable> List<CentroidCluster<T>> selectCentroids(final Collection<T> points, final int k) {
|
|
||||||
if (k < 1) {
|
|
||||||
return Collections.emptyList();
|
|
||||||
}
|
|
||||||
final ArrayList<T> list = new ArrayList<>(points);
|
|
||||||
ListSampler.shuffle(random, list);
|
|
||||||
final List<CentroidCluster<T>> result = new ArrayList<>(k);
|
|
||||||
for (int i = 0; i < k; i++) {
|
|
||||||
result.add(new CentroidCluster<>(list.get(i)));
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,24 +0,0 @@
|
||||||
/*
|
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
||||||
* contributor license agreements. See the NOTICE file distributed with
|
|
||||||
* this work for additional information regarding copyright ownership.
|
|
||||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
||||||
* (the "License"); you may not use this file except in compliance with
|
|
||||||
* the License. You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
/**
|
|
||||||
* Clusters initialization methods.
|
|
||||||
*
|
|
||||||
* Centroid cluster initializers should implement the
|
|
||||||
* {@link org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer CentroidInitializer}
|
|
||||||
* interface.
|
|
||||||
*/
|
|
||||||
package org.apache.commons.math4.ml.clustering.initialization;
|
|
|
@ -1,65 +0,0 @@
|
||||||
/*
|
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
||||||
* contributor license agreements. See the NOTICE file distributed with
|
|
||||||
* this work for additional information regarding copyright ownership.
|
|
||||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
||||||
* (the "License"); you may not use this file except in compliance with
|
|
||||||
* the License. You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
package org.apache.commons.math4.ml.clustering.initialization;
|
|
||||||
|
|
||||||
import org.apache.commons.math4.ml.clustering.CentroidCluster;
|
|
||||||
import org.apache.commons.math4.ml.clustering.DoublePoint;
|
|
||||||
import org.apache.commons.math4.ml.distance.EuclideanDistance;
|
|
||||||
import org.apache.commons.rng.UniformRandomProvider;
|
|
||||||
import org.apache.commons.rng.simple.RandomSource;
|
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class CentroidInitializerTest {
|
|
||||||
private void test_generate_appropriate_number_of_cluster(
|
|
||||||
final CentroidInitializer initializer) {
|
|
||||||
// Generate some data
|
|
||||||
final List<DoublePoint> points = new ArrayList<>();
|
|
||||||
final UniformRandomProvider rnd = RandomSource.create(RandomSource.MT_64);
|
|
||||||
for (int i = 0; i < 500; i++) {
|
|
||||||
double[] p = new double[2];
|
|
||||||
p[0] = rnd.nextDouble();
|
|
||||||
p[1] = rnd.nextDouble();
|
|
||||||
points.add(new DoublePoint(p));
|
|
||||||
}
|
|
||||||
// We can only assert that the centroid initializer
|
|
||||||
// implementation generate appropriate number of cluster
|
|
||||||
for (int k = 1; k < 50; k++) {
|
|
||||||
final List<CentroidCluster<DoublePoint>> centroidClusters =
|
|
||||||
initializer.selectCentroids(points, k);
|
|
||||||
Assert.assertEquals(k, centroidClusters.size());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test_RandomCentroidInitializer() {
|
|
||||||
final CentroidInitializer initializer =
|
|
||||||
new RandomCentroidInitializer(RandomSource.create(RandomSource.MT_64));
|
|
||||||
test_generate_appropriate_number_of_cluster(initializer);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test_KMeanPlusPlusCentroidInitializer() {
|
|
||||||
final CentroidInitializer initializer =
|
|
||||||
new KMeansPlusPlusCentroidInitializer(new EuclideanDistance(),
|
|
||||||
RandomSource.create(RandomSource.MT_64));
|
|
||||||
test_generate_appropriate_number_of_cluster(initializer);
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue