MATH-1524 Move "chooseInitialCenters" out of the KMeansPlusPlusClusterer

This commit is contained in:
CT 2020-03-11 01:48:26 +08:00 committed by Gilles
parent aeca88c72d
commit 84102c0c4c
5 changed files with 347 additions and 126 deletions

View File

@ -26,6 +26,8 @@ import org.apache.commons.math4.exception.ConvergenceException;
import org.apache.commons.math4.exception.MathIllegalArgumentException; import org.apache.commons.math4.exception.MathIllegalArgumentException;
import org.apache.commons.math4.exception.NumberIsTooSmallException; import org.apache.commons.math4.exception.NumberIsTooSmallException;
import org.apache.commons.math4.exception.util.LocalizedFormats; import org.apache.commons.math4.exception.util.LocalizedFormats;
import org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer;
import org.apache.commons.math4.ml.clustering.initialization.KMeansPlusPlusCentroidInitializer;
import org.apache.commons.math4.ml.distance.DistanceMeasure; import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.math4.ml.distance.EuclideanDistance; import org.apache.commons.math4.ml.distance.EuclideanDistance;
import org.apache.commons.rng.simple.RandomSource; import org.apache.commons.rng.simple.RandomSource;
@ -70,6 +72,9 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
/** Selected strategy for empty clusters. */ /** Selected strategy for empty clusters. */
private final EmptyClusterStrategy emptyStrategy; private final EmptyClusterStrategy emptyStrategy;
/** Clusters centroids initializer. */
private final CentroidInitializer centroidInitializer;
/** Build a clusterer. /** Build a clusterer.
* <p> * <p>
* The default strategy for handling empty clusters that may appear during * The default strategy for handling empty clusters that may appear during
@ -148,6 +153,8 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
this.maxIterations = maxIterations; this.maxIterations = maxIterations;
this.random = random; this.random = random;
this.emptyStrategy = emptyStrategy; this.emptyStrategy = emptyStrategy;
// Use K-means++ to choose the initial centers.
this.centroidInitializer = new KMeansPlusPlusCentroidInitializer(measure, random);
} }
/** /**
@ -203,7 +210,7 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
} }
// create the initial clusters // create the initial clusters
List<CentroidCluster<T>> clusters = chooseInitialCenters(points); List<CentroidCluster<T>> clusters = centroidInitializer.selectCentroids(points, k);
// create an array containing the latest assignment of a point to a cluster // create an array containing the latest assignment of a point to a cluster
// no need to initialize the array, as it will be filled with the first assignment // no need to initialize the array, as it will be filled with the first assignment
@ -276,131 +283,6 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
return assignedDifferently; return assignedDifferently;
} }
/**
* Use K-means++ to choose the initial centers.
*
* @param points the points to choose the initial centers from
* @return the initial centers
*/
private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
// Convert to list for indexed access. Make it unmodifiable, since removal of items
// would screw up the logic of this method.
final List<T> pointList = Collections.unmodifiableList(new ArrayList<> (points));
// The number of points in the list.
final int numPoints = pointList.size();
// Set the corresponding element in this array to indicate when
// elements of pointList are no longer available.
final boolean[] taken = new boolean[numPoints];
// The resulting list of initial centers.
final List<CentroidCluster<T>> resultSet = new ArrayList<>();
// Choose one center uniformly at random from among the data points.
final int firstPointIndex = random.nextInt(numPoints);
final T firstPoint = pointList.get(firstPointIndex);
resultSet.add(new CentroidCluster<T>(firstPoint));
// Must mark it as taken
taken[firstPointIndex] = true;
// To keep track of the minimum distance squared of elements of
// pointList to elements of resultSet.
final double[] minDistSquared = new double[numPoints];
// Initialize the elements. Since the only point in resultSet is firstPoint,
// this is very easy.
for (int i = 0; i < numPoints; i++) {
if (i != firstPointIndex) { // That point isn't considered
double d = distance(firstPoint, pointList.get(i));
minDistSquared[i] = d*d;
}
}
while (resultSet.size() < k) {
// Sum up the squared distances for the points in pointList not
// already taken.
double distSqSum = 0.0;
for (int i = 0; i < numPoints; i++) {
if (!taken[i]) {
distSqSum += minDistSquared[i];
}
}
// Add one new data point as a center. Each point x is chosen with
// probability proportional to D(x)2
final double r = random.nextDouble() * distSqSum;
// The index of the next point to be added to the resultSet.
int nextPointIndex = -1;
// Sum through the squared min distances again, stopping when
// sum >= r.
double sum = 0.0;
for (int i = 0; i < numPoints; i++) {
if (!taken[i]) {
sum += minDistSquared[i];
if (sum >= r) {
nextPointIndex = i;
break;
}
}
}
// If it's not set to >= 0, the point wasn't found in the previous
// for loop, probably because distances are extremely small. Just pick
// the last available point.
if (nextPointIndex == -1) {
for (int i = numPoints - 1; i >= 0; i--) {
if (!taken[i]) {
nextPointIndex = i;
break;
}
}
}
// We found one.
if (nextPointIndex >= 0) {
final T p = pointList.get(nextPointIndex);
resultSet.add(new CentroidCluster<T> (p));
// Mark it as taken.
taken[nextPointIndex] = true;
if (resultSet.size() < k) {
// Now update elements of minDistSquared. We only have to compute
// the distance to the new center to do this.
for (int j = 0; j < numPoints; j++) {
// Only have to worry about the points still not taken.
if (!taken[j]) {
double d = distance(p, pointList.get(j));
double d2 = d * d;
if (d2 < minDistSquared[j]) {
minDistSquared[j] = d2;
}
}
}
}
} else {
// None found --
// Break from the while loop to prevent
// an infinite loop.
break;
}
}
return resultSet;
}
/** /**
* Get a random point from the {@link Cluster} with the largest distance variance. * Get a random point from the {@link Cluster} with the largest distance variance.
* *

View File

@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.ml.clustering.initialization;
import org.apache.commons.math4.ml.clustering.CentroidCluster;
import org.apache.commons.math4.ml.clustering.Clusterable;
import java.util.Collection;
import java.util.List;
/**
* Interface abstract the algorithm for clusterer to choose the initial centers.
*/
public interface CentroidInitializer {
/**
* Choose the initial centers.
*
* @param points the points to choose the initial centers from
* @param k The number of clusters
* @return the initial centers
*/
<T extends Clusterable> List<CentroidCluster<T>> selectCentroids(final Collection<T> points, final int k);
}

View File

@ -0,0 +1,186 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.ml.clustering.initialization;
import org.apache.commons.math4.ml.clustering.CentroidCluster;
import org.apache.commons.math4.ml.clustering.Clusterable;
import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.rng.UniformRandomProvider;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
/**
* Use K-means++ to choose the initial centers.
*
* @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
*/
public class KMeansPlusPlusCentroidInitializer implements CentroidInitializer {
private final DistanceMeasure measure;
private final UniformRandomProvider random;
/**
* Build a K-means++ CentroidInitializer
* @param measure the distance measure to use
* @param random the random to use.
*/
public KMeansPlusPlusCentroidInitializer(final DistanceMeasure measure, final UniformRandomProvider random) {
this.measure = measure;
this.random = random;
}
/**
* Use K-means++ to choose the initial centers.
*
* @param points the points to choose the initial centers from
* @param k The number of clusters
* @return the initial centers
*/
@Override
public <T extends Clusterable> List<CentroidCluster<T>> selectCentroids(final Collection<T> points, final int k) {
// Convert to list for indexed access. Make it unmodifiable, since removal of items
// would screw up the logic of this method.
final List<T> pointList = Collections.unmodifiableList(new ArrayList<>(points));
// The number of points in the list.
final int numPoints = pointList.size();
// Set the corresponding element in this array to indicate when
// elements of pointList are no longer available.
final boolean[] taken = new boolean[numPoints];
// The resulting list of initial centers.
final List<CentroidCluster<T>> resultSet = new ArrayList<>();
// Choose one center uniformly at random from among the data points.
final int firstPointIndex = random.nextInt(numPoints);
final T firstPoint = pointList.get(firstPointIndex);
resultSet.add(new CentroidCluster<>(firstPoint));
// Must mark it as taken
taken[firstPointIndex] = true;
// To keep track of the minimum distance squared of elements of
// pointList to elements of resultSet.
final double[] minDistSquared = new double[numPoints];
// Initialize the elements. Since the only point in resultSet is firstPoint,
// this is very easy.
for (int i = 0; i < numPoints; i++) {
if (i != firstPointIndex) { // That point isn't considered
double d = distance(firstPoint, pointList.get(i));
minDistSquared[i] = d * d;
}
}
while (resultSet.size() < k) {
// Sum up the squared distances for the points in pointList not
// already taken.
double distSqSum = 0.0;
for (int i = 0; i < numPoints; i++) {
if (!taken[i]) {
distSqSum += minDistSquared[i];
}
}
// Add one new data point as a center. Each point x is chosen with
// probability proportional to D(x)2
final double r = random.nextDouble() * distSqSum;
// The index of the next point to be added to the resultSet.
int nextPointIndex = -1;
// Sum through the squared min distances again, stopping when
// sum >= r.
double sum = 0.0;
for (int i = 0; i < numPoints; i++) {
if (!taken[i]) {
sum += minDistSquared[i];
if (sum >= r) {
nextPointIndex = i;
break;
}
}
}
// If it's not set to >= 0, the point wasn't found in the previous
// for loop, probably because distances are extremely small. Just pick
// the last available point.
if (nextPointIndex == -1) {
for (int i = numPoints - 1; i >= 0; i--) {
if (!taken[i]) {
nextPointIndex = i;
break;
}
}
}
// We found one.
if (nextPointIndex >= 0) {
final T p = pointList.get(nextPointIndex);
resultSet.add(new CentroidCluster<>(p));
// Mark it as taken.
taken[nextPointIndex] = true;
if (resultSet.size() < k) {
// Now update elements of minDistSquared. We only have to compute
// the distance to the new center to do this.
for (int j = 0; j < numPoints; j++) {
// Only have to worry about the points still not taken.
if (!taken[j]) {
double d = distance(p, pointList.get(j));
double d2 = d * d;
if (d2 < minDistSquared[j]) {
minDistSquared[j] = d2;
}
}
}
}
} else {
// None found --
// Break from the while loop to prevent
// an infinite loop.
break;
}
}
return resultSet;
}
/**
* Calculates the distance between two {@link Clusterable} instances
* with the configured {@link DistanceMeasure}.
*
* @param p1 the first clusterable
* @param p2 the second clusterable
* @return the distance between the two clusterables
*/
protected double distance(final Clusterable p1, final Clusterable p2) {
return measure.compute(p1.getPoint(), p2.getPoint());
}
}

View File

@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.ml.clustering.initialization;
import org.apache.commons.math4.ml.clustering.CentroidCluster;
import org.apache.commons.math4.ml.clustering.Clusterable;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.ListSampler;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
/**
* Random choose the initial centers.
*/
public class RandomCentroidInitializer implements CentroidInitializer {
private final UniformRandomProvider random;
/**
* Build a random RandomCentroidInitializer
*
* @param random the random to use.
*/
public RandomCentroidInitializer(final UniformRandomProvider random) {
this.random = random;
}
/**
* Random choose the initial centers.
*
* @param points the points to choose the initial centers from
* @param k The number of clusters
* @return the initial centers
*/
@Override
public <T extends Clusterable> List<CentroidCluster<T>> selectCentroids(final Collection<T> points, final int k) {
if (k < 1) {
return Collections.emptyList();
}
final ArrayList<T> list = new ArrayList<>(points);
ListSampler.shuffle(random, list);
final List<CentroidCluster<T>> result = new ArrayList<>(k);
for (int i = 0; i < k; i++) {
result.add(new CentroidCluster<>(list.get(i)));
}
return result;
}
}

View File

@ -0,0 +1,49 @@
package org.apache.commons.math4.ml.clustering.initialization;
import org.apache.commons.math4.ml.clustering.CentroidCluster;
import org.apache.commons.math4.ml.clustering.DoublePoint;
import org.apache.commons.math4.ml.distance.EuclideanDistance;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
public class CentroidInitializerTest {
private void test_generate_appropriate_number_of_cluster(
final CentroidInitializer initializer) {
// Generate some data
final List<DoublePoint> points = new ArrayList<>();
final UniformRandomProvider rnd = RandomSource.create(RandomSource.MT_64);
for (int i = 0; i < 500; i++) {
double[] p = new double[2];
p[0] = rnd.nextDouble();
p[1] = rnd.nextDouble();
points.add(new DoublePoint(p));
}
// We can only assert that the centroid initializer
// implementation generate appropriate number of cluster
for (int k = 1; k < 50; k++) {
final List<CentroidCluster<DoublePoint>> centroidClusters =
initializer.selectCentroids(points, k);
Assert.assertEquals(k, centroidClusters.size());
}
}
@Test
public void test_RandomCentroidInitializer() {
final CentroidInitializer initializer =
new RandomCentroidInitializer(RandomSource.create(RandomSource.MT_64));
test_generate_appropriate_number_of_cluster(initializer);
}
@Test
public void test_KMeanPlusPlusCentroidInitializer() {
final CentroidInitializer initializer =
new KMeansPlusPlusCentroidInitializer(new EuclideanDistance(),
RandomSource.create(RandomSource.MT_64));
test_generate_appropriate_number_of_cluster(initializer);
}
}