/** Selected strategy for empty clusters. */
private final EmptyClusterStrategy emptyStrategy;
- /** Clusters centroids initializer. */
- private final CentroidInitializer centroidInitializer;
-
/** Build a clusterer.
*
* The default strategy for handling empty clusters that may appear during
@@ -147,20 +143,18 @@ public class KMeansPlusPlusClusterer extends Clusterer
final UniformRandomProvider random,
final EmptyClusterStrategy emptyStrategy) {
super(measure);
- this.k = k;
+ this.numberOfClusters = k;
this.maxIterations = maxIterations;
this.random = random;
this.emptyStrategy = emptyStrategy;
- // Use K-means++ to choose the initial centers.
- this.centroidInitializer = new KMeansPlusPlusCentroidInitializer(measure, random);
}
/**
* Return the number of clusters this instance will use.
* @return the number of clusters
*/
- public int getK() {
- return k;
+ public int getNumberOfClusters() {
+ return numberOfClusters;
}
/**
@@ -188,12 +182,12 @@ public class KMeansPlusPlusClusterer extends Clusterer
MathUtils.checkNotNull(points);
// number of clusters has to be smaller or equal the number of data points
- if (points.size() < k) {
- throw new NumberIsTooSmallException(points.size(), k, false);
+ if (points.size() < numberOfClusters) {
+ throw new NumberIsTooSmallException(points.size(), numberOfClusters, false);
}
// create the initial clusters
- List> clusters = centroidInitializer.selectCentroids(points, k);
+ List> clusters = chooseInitialCenters(points);
// create an array containing the latest assignment of a point to a cluster
// no need to initialize the array, as it will be filled with the first assignment
@@ -231,13 +225,6 @@ public class KMeansPlusPlusClusterer extends Clusterer
return emptyStrategy;
}
- /**
- * @return the CentroidInitializer
- */
- CentroidInitializer getCentroidInitializer() {
- return centroidInitializer;
- }
-
/**
* Adjust the clusters's centers with means of points
* @param clusters the origin clusters
@@ -296,6 +283,131 @@ public class KMeansPlusPlusClusterer extends Clusterer
return assignedDifferently;
}
+ /**
+ * Use K-means++ to choose the initial centers.
+ *
+ * @param points the points to choose the initial centers from
+ * @return the initial centers
+ */
+ List> chooseInitialCenters(final Collection points) {
+
+ // Convert to list for indexed access. Make it unmodifiable, since removal of items
+ // would screw up the logic of this method.
+ final List pointList = Collections.unmodifiableList(new ArrayList<> (points));
+
+ // The number of points in the list.
+ final int numPoints = pointList.size();
+
+ // Set the corresponding element in this array to indicate when
+ // elements of pointList are no longer available.
+ final boolean[] taken = new boolean[numPoints];
+
+ // The resulting list of initial centers.
+ final List> resultSet = new ArrayList<>();
+
+ // Choose one center uniformly at random from among the data points.
+ final int firstPointIndex = random.nextInt(numPoints);
+
+ final T firstPoint = pointList.get(firstPointIndex);
+
+ resultSet.add(new CentroidCluster(firstPoint));
+
+ // Must mark it as taken
+ taken[firstPointIndex] = true;
+
+ // To keep track of the minimum distance squared of elements of
+ // pointList to elements of resultSet.
+ final double[] minDistSquared = new double[numPoints];
+
+ // Initialize the elements. Since the only point in resultSet is firstPoint,
+ // this is very easy.
+ for (int i = 0; i < numPoints; i++) {
+ if (i != firstPointIndex) { // That point isn't considered
+ double d = distance(firstPoint, pointList.get(i));
+ minDistSquared[i] = d*d;
+ }
+ }
+
+ while (resultSet.size() < numberOfClusters) {
+
+ // Sum up the squared distances for the points in pointList not
+ // already taken.
+ double distSqSum = 0.0;
+
+ for (int i = 0; i < numPoints; i++) {
+ if (!taken[i]) {
+ distSqSum += minDistSquared[i];
+ }
+ }
+
+ // Add one new data point as a center. Each point x is chosen with
+ // probability proportional to D(x)2
+ final double r = random.nextDouble() * distSqSum;
+
+ // The index of the next point to be added to the resultSet.
+ int nextPointIndex = -1;
+
+ // Sum through the squared min distances again, stopping when
+ // sum >= r.
+ double sum = 0.0;
+ for (int i = 0; i < numPoints; i++) {
+ if (!taken[i]) {
+ sum += minDistSquared[i];
+ if (sum >= r) {
+ nextPointIndex = i;
+ break;
+ }
+ }
+ }
+
+ // If it's not set to >= 0, the point wasn't found in the previous
+ // for loop, probably because distances are extremely small. Just pick
+ // the last available point.
+ if (nextPointIndex == -1) {
+ for (int i = numPoints - 1; i >= 0; i--) {
+ if (!taken[i]) {
+ nextPointIndex = i;
+ break;
+ }
+ }
+ }
+
+ // We found one.
+ if (nextPointIndex >= 0) {
+
+ final T p = pointList.get(nextPointIndex);
+
+ resultSet.add(new CentroidCluster (p));
+
+ // Mark it as taken.
+ taken[nextPointIndex] = true;
+
+ if (resultSet.size() < numberOfClusters) {
+ // Now update elements of minDistSquared. We only have to compute
+ // the distance to the new center to do this.
+ for (int j = 0; j < numPoints; j++) {
+ // Only have to worry about the points still not taken.
+ if (!taken[j]) {
+ double d = distance(p, pointList.get(j));
+ double d2 = d * d;
+ if (d2 < minDistSquared[j]) {
+ minDistSquared[j] = d2;
+ }
+ }
+ }
+ }
+
+ } else {
+ // None found --
+ // Break from the while loop to prevent
+ // an infinite loop.
+ break;
+ }
+ }
+
+ return resultSet;
+ }
+
/**
* Get a random point from the {@link Cluster} with the largest distance variance.
*
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java b/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java
index e131e0a4d..a563fda3d 100644
--- a/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java
+++ b/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java
@@ -106,8 +106,8 @@ public class MiniBatchKMeansClusterer
public List> cluster(final Collection points) {
// Sanity check.
MathUtils.checkNotNull(points);
- if (points.size() < getK()) {
- throw new NumberIsTooSmallException(points.size(), getK(), false);
+ if (points.size() < getNumberOfClusters()) {
+ throw new NumberIsTooSmallException(points.size(), getNumberOfClusters(), false);
}
final int pointSize = points.size();
@@ -195,7 +195,7 @@ public class MiniBatchKMeansClusterer
final List initialPoints = (initBatchSize < points.size()) ?
ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
new ArrayList<>(points);
- final List> clusters = getCentroidInitializer().selectCentroids(initialPoints, getK());
+ final List> clusters = chooseInitialCenters(initialPoints);
final Pair>> pair = step(validPoints, clusters);
final double squareDistance = pair.getFirst();
final List> newClusters = pair.getSecond();
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java
deleted file mode 100644
index 0378f5516..000000000
--- a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.commons.math4.ml.clustering.initialization;
-
-import org.apache.commons.math4.ml.clustering.CentroidCluster;
-import org.apache.commons.math4.ml.clustering.Clusterable;
-
-import java.util.Collection;
-import java.util.List;
-
-/**
- * Interface abstract the algorithm for clusterer to choose the initial centers.
- */
-public interface CentroidInitializer {
-
- /**
- * Choose the initial centers.
- *
- * @param Type of points to cluster.
- * @param points the points to choose the initial centers from
- * @param k The number of clusters
- * @return the initial centers
- */
- List> selectCentroids(final Collection points, final int k);
-}
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java
deleted file mode 100644
index f0ab288c2..000000000
--- a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java
+++ /dev/null
@@ -1,186 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.commons.math4.ml.clustering.initialization;
-
-import org.apache.commons.math4.ml.clustering.CentroidCluster;
-import org.apache.commons.math4.ml.clustering.Clusterable;
-import org.apache.commons.math4.ml.distance.DistanceMeasure;
-import org.apache.commons.rng.UniformRandomProvider;
-
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.List;
-
-/**
- * Use K-means++ to choose the initial centers.
- *
- * @see K-means++ (wikipedia)
- */
-public class KMeansPlusPlusCentroidInitializer implements CentroidInitializer {
- private final DistanceMeasure measure;
- private final UniformRandomProvider random;
-
- /**
- * Build a K-means++ CentroidInitializer
- * @param measure the distance measure to use
- * @param random the random to use.
- */
- public KMeansPlusPlusCentroidInitializer(final DistanceMeasure measure, final UniformRandomProvider random) {
- this.measure = measure;
- this.random = random;
- }
-
- /**
- * Use K-means++ to choose the initial centers.
- *
- * @param points the points to choose the initial centers from
- * @param k The number of clusters
- * @return the initial centers
- */
- @Override
- public List> selectCentroids(final Collection points, final int k) {
- // Convert to list for indexed access. Make it unmodifiable, since removal of items
- // would screw up the logic of this method.
- final List pointList = Collections.unmodifiableList(new ArrayList<>(points));
-
- // The number of points in the list.
- final int numPoints = pointList.size();
-
- // Set the corresponding element in this array to indicate when
- // elements of pointList are no longer available.
- final boolean[] taken = new boolean[numPoints];
-
- // The resulting list of initial centers.
- final List> resultSet = new ArrayList<>();
-
- // Choose one center uniformly at random from among the data points.
- final int firstPointIndex = random.nextInt(numPoints);
-
- final T firstPoint = pointList.get(firstPointIndex);
-
- resultSet.add(new CentroidCluster<>(firstPoint));
-
- // Must mark it as taken
- taken[firstPointIndex] = true;
-
- // To keep track of the minimum distance squared of elements of
- // pointList to elements of resultSet.
- final double[] minDistSquared = new double[numPoints];
-
- // Initialize the elements. Since the only point in resultSet is firstPoint,
- // this is very easy.
- for (int i = 0; i < numPoints; i++) {
- if (i != firstPointIndex) { // That point isn't considered
- double d = distance(firstPoint, pointList.get(i));
- minDistSquared[i] = d * d;
- }
- }
-
- while (resultSet.size() < k) {
-
- // Sum up the squared distances for the points in pointList not
- // already taken.
- double distSqSum = 0.0;
-
- for (int i = 0; i < numPoints; i++) {
- if (!taken[i]) {
- distSqSum += minDistSquared[i];
- }
- }
-
- // Add one new data point as a center. Each point x is chosen with
- // probability proportional to D(x)2
- final double r = random.nextDouble() * distSqSum;
-
- // The index of the next point to be added to the resultSet.
- int nextPointIndex = -1;
-
- // Sum through the squared min distances again, stopping when
- // sum >= r.
- double sum = 0.0;
- for (int i = 0; i < numPoints; i++) {
- if (!taken[i]) {
- sum += minDistSquared[i];
- if (sum >= r) {
- nextPointIndex = i;
- break;
- }
- }
- }
-
- // If it's not set to >= 0, the point wasn't found in the previous
- // for loop, probably because distances are extremely small. Just pick
- // the last available point.
- if (nextPointIndex == -1) {
- for (int i = numPoints - 1; i >= 0; i--) {
- if (!taken[i]) {
- nextPointIndex = i;
- break;
- }
- }
- }
-
- // We found one.
- if (nextPointIndex >= 0) {
-
- final T p = pointList.get(nextPointIndex);
-
- resultSet.add(new CentroidCluster<>(p));
-
- // Mark it as taken.
- taken[nextPointIndex] = true;
-
- if (resultSet.size() < k) {
- // Now update elements of minDistSquared. We only have to compute
- // the distance to the new center to do this.
- for (int j = 0; j < numPoints; j++) {
- // Only have to worry about the points still not taken.
- if (!taken[j]) {
- double d = distance(p, pointList.get(j));
- double d2 = d * d;
- if (d2 < minDistSquared[j]) {
- minDistSquared[j] = d2;
- }
- }
- }
- }
-
- } else {
- // None found --
- // Break from the while loop to prevent
- // an infinite loop.
- break;
- }
- }
-
- return resultSet;
- }
-
- /**
- * Calculates the distance between two {@link Clusterable} instances
- * with the configured {@link DistanceMeasure}.
- *
- * @param p1 the first clusterable
- * @param p2 the second clusterable
- * @return the distance between the two clusterables
- */
- protected double distance(final Clusterable p1, final Clusterable p2) {
- return measure.compute(p1.getPoint(), p2.getPoint());
- }
-}
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java
deleted file mode 100644
index f3f561d15..000000000
--- a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.commons.math4.ml.clustering.initialization;
-
-import org.apache.commons.math4.ml.clustering.CentroidCluster;
-import org.apache.commons.math4.ml.clustering.Clusterable;
-import org.apache.commons.rng.UniformRandomProvider;
-import org.apache.commons.rng.sampling.ListSampler;
-
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.List;
-
-/**
- * Random choose the initial centers.
- */
-public class RandomCentroidInitializer implements CentroidInitializer {
- private final UniformRandomProvider random;
-
- /**
- * Build a random RandomCentroidInitializer
- *
- * @param random the random to use.
- */
- public RandomCentroidInitializer(final UniformRandomProvider random) {
- this.random = random;
- }
-
- /**
- * Random choose the initial centers.
- *
- * @param points the points to choose the initial centers from
- * @param k The number of clusters
- * @return the initial centers
- */
- @Override
- public List> selectCentroids(final Collection points, final int k) {
- if (k < 1) {
- return Collections.emptyList();
- }
- final ArrayList list = new ArrayList<>(points);
- ListSampler.shuffle(random, list);
- final List> result = new ArrayList<>(k);
- for (int i = 0; i < k; i++) {
- result.add(new CentroidCluster<>(list.get(i)));
- }
- return result;
- }
-}
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/package-info.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/package-info.java
deleted file mode 100644
index 3c335ae8e..000000000
--- a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/package-info.java
+++ /dev/null
@@ -1,24 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-/**
- * Clusters initialization methods.
- *
- * Centroid cluster initializers should implement the
- * {@link org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer CentroidInitializer}
- * interface.
- */
-package org.apache.commons.math4.ml.clustering.initialization;
diff --git a/src/test/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializerTest.java b/src/test/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializerTest.java
deleted file mode 100644
index 329dce71c..000000000
--- a/src/test/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializerTest.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.commons.math4.ml.clustering.initialization;
-
-import org.apache.commons.math4.ml.clustering.CentroidCluster;
-import org.apache.commons.math4.ml.clustering.DoublePoint;
-import org.apache.commons.math4.ml.distance.EuclideanDistance;
-import org.apache.commons.rng.UniformRandomProvider;
-import org.apache.commons.rng.simple.RandomSource;
-import org.junit.Assert;
-import org.junit.Test;
-
-import java.util.ArrayList;
-import java.util.List;
-
-public class CentroidInitializerTest {
- private void test_generate_appropriate_number_of_cluster(
- final CentroidInitializer initializer) {
- // Generate some data
- final List points = new ArrayList<>();
- final UniformRandomProvider rnd = RandomSource.create(RandomSource.MT_64);
- for (int i = 0; i < 500; i++) {
- double[] p = new double[2];
- p[0] = rnd.nextDouble();
- p[1] = rnd.nextDouble();
- points.add(new DoublePoint(p));
- }
- // We can only assert that the centroid initializer
- // implementation generate appropriate number of cluster
- for (int k = 1; k < 50; k++) {
- final List> centroidClusters =
- initializer.selectCentroids(points, k);
- Assert.assertEquals(k, centroidClusters.size());
- }
- }
-
- @Test
- public void test_RandomCentroidInitializer() {
- final CentroidInitializer initializer =
- new RandomCentroidInitializer(RandomSource.create(RandomSource.MT_64));
- test_generate_appropriate_number_of_cluster(initializer);
- }
-
- @Test
- public void test_KMeanPlusPlusCentroidInitializer() {
- final CentroidInitializer initializer =
- new KMeansPlusPlusCentroidInitializer(new EuclideanDistance(),
- RandomSource.create(RandomSource.MT_64));
- test_generate_appropriate_number_of_cluster(initializer);
- }
-}