added a clustering package with an implementation of k-means++

JIRA: MATH-266

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@770979 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2009-05-02 19:34:51 +00:00
parent d94d0a556a
commit 28257de180
7 changed files with 500 additions and 0 deletions

View File

@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.stat.clustering;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
/**
 * Cluster holding a set of {@link Clusterable} points.
 * <p>A cluster is built around a center that never changes afterwards and
 * starts out with no points; points are appended one at a time with
 * {@link #addPoint}.</p>
 * @param <T> the type of points that can be clustered
 * @version $Revision$ $Date$
 * @since 2.0
 */
public class Cluster<T extends Clusterable<T>> implements Serializable {

    /** Serializable version identifier. */
    private static final long serialVersionUID = -1741417096265465690L;

    /** The points contained in this cluster. */
    private final List<T> points;

    /** Center of the cluster. */
    private final T center;

    /**
     * Build a cluster centered at a specified point.
     * <p>The center is only recorded as the center: it is not added to the
     * list of points held by the cluster.</p>
     * @param center the point which is to be the center of this cluster
     */
    public Cluster(final T center) {
        this.center = center;
        points = new ArrayList<T>();
    }

    /**
     * Add a point to this cluster.
     * @param point point to add
     */
    public void addPoint(final T point) {
        points.add(point);
    }

    /**
     * Get the points contained in the cluster.
     * <p>The returned list is the list used internally by the cluster, not a
     * copy: changes made to it are seen by the cluster and vice versa.</p>
     * @return points contained in the cluster
     */
    public List<T> getPoints() {
        return points;
    }

    /**
     * Get the point chosen to be the center of this cluster.
     * @return chosen cluster center
     */
    public T getCenter() {
        return center;
    }

}

View File

@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.stat.clustering;
import java.io.Serializable;
import java.util.Collection;
/**
 * Interface for points that can be clustered together.
 * <p>An implementation has to provide two operations: measuring the distance
 * to another point of the same type and building the centroid of a
 * collection of such points.</p>
 * @param <T> the type of point that can be clustered
 * @version $Revision$ $Date$
 * @since 2.0
 */
public interface Clusterable<T> extends Serializable {
/**
 * Returns the distance from the given point.
 *
 * @param p the point to compute the distance from
 * @return the distance between this point and <code>p</code>
 */
double distanceFrom(T p);
/**
 * Returns the centroid of the given Collection of points.
 * <p>The instance the method is called on is not stated to be part of the
 * computation; how an empty collection is handled is left to the
 * implementation.</p>
 *
 * @param p the Collection of points to compute the centroid of
 * @return the centroid of the given Collection of Points
 */
T centroidOf(Collection<T> p);
}

View File

@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.stat.clustering;
import java.util.Collection;
import org.apache.commons.math.util.MathUtils;
/**
 * A simple implementation of {@link Clusterable} for points with integer coordinates.
 * <p>The coordinates array given at construction is stored as is and handed
 * back by {@link #getPoint()} without copying. Since {@link #equals(Object)}
 * and {@link #hashCode()} depend on the coordinates, the array should not be
 * modified while the point is used as a key in a hash-based collection.</p>
 * @version $Revision$ $Date$
 * @since 2.0
 */
public class EuclideanIntegerPoint implements Clusterable<EuclideanIntegerPoint> {

    /** Serializable version identifier. */
    private static final long serialVersionUID = 3946024775784901369L;

    /** Point coordinates. */
    private final int[] point;

    /**
     * Build a point from its coordinates.
     * @param point the n-dimensional point in integer space
     * (the array is referenced, not copied)
     */
    public EuclideanIntegerPoint(final int[] point) {
        this.point = point;
    }

    /**
     * Get the n-dimensional point in integer space.
     * @return a reference to the coordinates array of this point
     */
    public int[] getPoint() {
        return point;
    }

    /** {@inheritDoc} */
    public double distanceFrom(final EuclideanIntegerPoint p) {
        return MathUtils.distance(point, p.getPoint());
    }

    /**
     * {@inheritDoc}
     * <p>Each coordinate of the centroid is the mean of the corresponding
     * coordinates of the given points, truncated toward zero. The dimension
     * of the result is the dimension of this instance. A division by zero
     * ({@link ArithmeticException}) occurs if the collection is empty and the
     * dimension is not zero.</p>
     */
    public EuclideanIntegerPoint centroidOf(final Collection<EuclideanIntegerPoint> points) {
        // accumulate in long: the sum of many int coordinates may exceed the int range
        final long[] sums = new long[getPoint().length];
        for (final EuclideanIntegerPoint p : points) {
            final int[] coordinates = p.getPoint();
            for (int i = 0; i < sums.length; i++) {
                sums[i] += coordinates[i];
            }
        }
        final int count = points.size();
        final int[] centroid = new int[sums.length];
        for (int i = 0; i < centroid.length; i++) {
            // the mean of int values always fits in an int, so the cast is safe
            centroid[i] = (int) (sums[i] / count);
        }
        return new EuclideanIntegerPoint(centroid);
    }

    /** {@inheritDoc} */
    @Override
    public boolean equals(final Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof EuclideanIntegerPoint)) {
            return false;
        }
        final int[] otherPoint = ((EuclideanIntegerPoint) other).getPoint();
        if (point.length != otherPoint.length) {
            return false;
        }
        for (int i = 0; i < point.length; i++) {
            if (point[i] != otherPoint[i]) {
                return false;
            }
        }
        return true;
    }

    /** {@inheritDoc} */
    @Override
    public int hashCode() {
        int hashCode = 0;
        // plain int arithmetic: the hash of an Integer is its value, so the
        // result is the same as with boxed coordinates, without the boxing
        for (final int coordinate : point) {
            hashCode += coordinate * 13 + 7;
        }
        return hashCode;
    }

}

View File

@ -0,0 +1,161 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.stat.clustering;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;
/**
 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
 * @param <T> type of the points to cluster
 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
 * @version $Revision$ $Date$
 * @since 2.0
 */
public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {

    /** Random generator for choosing initial centers. */
    private final Random random;

    /** Build a clusterer.
     * @param random random generator to use for choosing initial centers
     */
    public KMeansPlusPlusClusterer(final Random random) {
        this.random = random;
    }

    /**
     * Runs the K-means++ clustering algorithm.
     * <p>Initial centers are drawn among the points, then centers are
     * recomputed as the centroid of their cluster and points reassigned to
     * the nearest center until no center changes or the maximum number of
     * iterations is reached. A cluster that ends up with no points keeps its
     * previous center.</p>
     *
     * @param points the points to cluster
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm
     *     for.  If negative, no maximum will be used
     * @return a list of clusters containing the points
     * @throws IllegalArgumentException if <code>points</code> is empty or
     *     contains fewer than <code>k</code> points
     */
    public List<Cluster<T>> cluster(final Collection<T> points,
                                    final int k, final int maxIterations) {

        if (points.isEmpty()) {
            throw new IllegalArgumentException("no points to cluster");
        }
        if (k > points.size()) {
            throw new IllegalArgumentException("cannot split " + points.size() +
                                               " points into " + k + " clusters");
        }

        // create the initial clusters
        List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
        assignPointsToClusters(clusters, points);

        // iterate through updating the centers until we're done
        final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
        for (int count = 0; count < max; count++) {
            boolean clusteringChanged = false;
            final List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>(clusters.size());
            for (final Cluster<T> cluster : clusters) {
                final T oldCenter = cluster.getCenter();
                final List<T> members = cluster.getPoints();
                // the centroid of no points is undefined: leave the center where it is
                final T newCenter = members.isEmpty() ? oldCenter : oldCenter.centroidOf(members);
                if (!newCenter.equals(oldCenter)) {
                    clusteringChanged = true;
                }
                newClusters.add(new Cluster<T>(newCenter));
            }
            if (!clusteringChanged) {
                return clusters;
            }
            assignPointsToClusters(newClusters, points);
            clusters = newClusters;
        }
        return clusters;

    }

    /**
     * Adds the given points to the closest {@link Cluster}.
     *
     * @param <T> type of the points to cluster
     * @param clusters the {@link Cluster}s to add the points to
     * @param points the points to add to the given {@link Cluster}s
     */
    private static <T extends Clusterable<T>> void
        assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) {
        for (final T p : points) {
            final Cluster<T> cluster = getNearestCluster(clusters, p);
            cluster.addPoint(p);
        }
    }

    /**
     * Use K-means++ to choose the initial centers.
     * <p>The caller must provide at least one point and at least
     * <code>k</code> points. At least one center is always chosen, even if
     * <code>k</code> is smaller than one.</p>
     *
     * @param <T> type of the points to cluster
     * @param points the points to choose the initial centers from
     * @param k the number of centers to choose
     * @param random random generator to use
     * @return the initial centers
     */
    private static <T extends Clusterable<T>> List<Cluster<T>>
        chooseInitialCenters(final Collection<T> points, final int k, final Random random) {

        // working copy: points already chosen as centers are removed from it
        final List<T> pointSet = new ArrayList<T>(points);
        final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();

        // Choose one center uniformly at random from among the data points.
        final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
        resultSet.add(new Cluster<T>(firstPoint));

        // cumulative sums of squared distances, only the first
        // pointSet.size() entries are meaningful at each pass
        final double[] dx2 = new double[pointSet.size()];
        while (resultSet.size() < k) {

            // For each data point x, compute D(x), the distance between x and
            // the nearest center that has already been chosen.
            // The sum is kept as a double: an integer accumulator would
            // truncate each squared distance and bias the selection.
            final int remaining = pointSet.size();
            double sum = 0;
            for (int i = 0; i < remaining; i++) {
                final T p = pointSet.get(i);
                final Cluster<T> nearest = getNearestCluster(resultSet, p);
                final double d = p.distanceFrom(nearest.getCenter());
                sum += d * d;
                dx2[i] = sum;
            }

            // Add one new data point as a center. Each point x is chosen with
            // probability proportional to D(x)^2
            final double r = random.nextDouble() * sum;
            // falling back on the last point guarantees progress even if no
            // cumulative sum compares as >= r (e.g. NaN distances)
            int chosen = remaining - 1;
            for (int i = 0; i < remaining; i++) {
                if (dx2[i] >= r) {
                    chosen = i;
                    break;
                }
            }
            resultSet.add(new Cluster<T>(pointSet.remove(chosen)));

        }

        return resultSet;

    }

    /**
     * Returns the nearest {@link Cluster} to the given point
     * <p>When several clusters are at the same distance, the first one in
     * iteration order is returned.</p>
     *
     * @param <T> type of the points to cluster
     * @param clusters the {@link Cluster}s to search
     * @param point the point to find the nearest {@link Cluster} for
     * @return the nearest {@link Cluster} to the given point, or null if
     *     <code>clusters</code> is empty
     */
    private static <T extends Clusterable<T>> Cluster<T>
        getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
        double minDistance = Double.MAX_VALUE;
        Cluster<T> minCluster = null;
        for (final Cluster<T> c : clusters) {
            final double distance = point.distanceFrom(c.getCenter());
            // always accept the first cluster so that huge or NaN distances
            // cannot leave the result null
            if (minCluster == null || distance < minDistance) {
                minDistance = distance;
                minCluster = c;
            }
        }
        return minCluster;
    }

}

View File

@ -0,0 +1,20 @@
<html>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<!-- $Revision$ $Date$ -->
<body>Clustering algorithms</body>
</html>

View File

@ -39,6 +39,9 @@ The <action> type attribute can be add,update,fix,remove.
</properties>
<body>
<release version="2.0" date="TBD" description="TBD">
<action dev="luc" type="add" issue="MATH-266" due-to="Benjamin McCann">
Added a clustering package with an implementation of the k-means++ algorithm
</action>
<action dev="luc" type="fix" issue="MATH-265" due-to="Benjamin McCann">
Added distance1, distance and distanceInf utility methods for double and
int arrays in MathUtils

View File

@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.stat.clustering;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.junit.Test;
public class KMeansPlusPlusClustererTest {

    /**
     * Clusters three well separated groups of 2D integer points into three
     * clusters and checks the size and center of each resulting cluster.
     */
    @Test
    public void dimension2() {

        final int[][] coordinates = {

            // first expected cluster
            { -15,  3 }, { -15,  4 }, { -15,  5 }, { -14,  3 },
            { -14,  5 }, { -13,  3 }, { -13,  4 }, { -13,  5 },

            // second expected cluster
            { -1,  0 }, { -1, -1 }, {  0, -1 }, {  1, -1 }, {  1, -2 },

            // third expected cluster
            { 13,  3 }, { 13,  4 }, { 14,  4 }, { 14,  7 },
            { 16,  5 }, { 16,  6 }, { 17,  4 }, { 17,  7 }

        };
        final EuclideanIntegerPoint[] points = new EuclideanIntegerPoint[coordinates.length];
        for (int i = 0; i < coordinates.length; i++) {
            points[i] = new EuclideanIntegerPoint(coordinates[i]);
        }

        final Random generator = new Random(1746432956321L);
        final KMeansPlusPlusClusterer<EuclideanIntegerPoint> transformer =
            new KMeansPlusPlusClusterer<EuclideanIntegerPoint>(generator);
        final List<Cluster<EuclideanIntegerPoint>> clusters =
            transformer.cluster(Arrays.asList(points), 3, 10);

        assertEquals(3, clusters.size());

        // each expected cluster is recognized by the sign of its center coordinates
        boolean leftFound   = false;
        boolean middleFound = false;
        boolean rightFound  = false;
        for (final Cluster<EuclideanIntegerPoint> cluster : clusters) {
            final int[] center = cluster.getCenter().getPoint();
            if (center[0] < 0) {
                leftFound = true;
                checkCluster(cluster, 8, -14, 4);
            } else if (center[1] < 0) {
                middleFound = true;
                checkCluster(cluster, 5, 0, -1);
            } else {
                rightFound = true;
                checkCluster(cluster, 8, 15, 5);
            }
        }
        assertTrue(leftFound);
        assertTrue(middleFound);
        assertTrue(rightFound);

    }

    /**
     * Check the number of points and the center of a 2D cluster.
     * @param cluster cluster to check
     * @param size expected number of points in the cluster
     * @param x expected first coordinate of the center
     * @param y expected second coordinate of the center
     */
    private static void checkCluster(final Cluster<EuclideanIntegerPoint> cluster,
                                     final int size, final int x, final int y) {
        final int[] center = cluster.getCenter().getPoint();
        assertEquals(size, cluster.getPoints().size());
        assertEquals(x, center[0]);
        assertEquals(y, center[1]);
    }

}