added a clustering package with an implementation of k-means++
JIRA: MATH-266 git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@770979 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
d94d0a556a
commit
28257de180
|
@ -0,0 +1,74 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math.stat.clustering;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Cluster holding a set of {@link Clusterable} points.
|
||||
* @param <T> the type of points that can be clustered
|
||||
* @version $Revision$ $Date$
|
||||
* @since 2.0
|
||||
*/
|
||||
public class Cluster<T extends Clusterable<T>> implements Serializable {
|
||||
|
||||
/** Serializable version identifier. */
|
||||
private static final long serialVersionUID = -1741417096265465690L;
|
||||
|
||||
/** The points contained in this cluster. */
|
||||
final List<T> points;
|
||||
|
||||
/** Center of the cluster. */
|
||||
final T center;
|
||||
|
||||
/**
|
||||
* Build a cluster centered at a specified point.
|
||||
* @param center the point which is to be the center of this cluster
|
||||
*/
|
||||
public Cluster(final T center) {
|
||||
this.center = center;
|
||||
points = new ArrayList<T>();
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a point to this cluster.
|
||||
* @param point point to add
|
||||
*/
|
||||
public void addPoint(final T point) {
|
||||
points.add(point);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the points contained in the cluster.
|
||||
* @return points contained in the cluster
|
||||
*/
|
||||
public List<T> getPoints() {
|
||||
return points;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the point chosen to be the center of this cluster.
|
||||
* @return chosen cluster center
|
||||
*/
|
||||
public T getCenter() {
|
||||
return center;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math.stat.clustering;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Collection;
|
||||
|
||||
/**
|
||||
* Interface for points that can be clustered together.
|
||||
* @param <T> the type of point that can be clustered
|
||||
* @version $Revision$ $Date$
|
||||
* @since 2.0
|
||||
*/
|
||||
public interface Clusterable<T> extends Serializable {
|
||||
|
||||
/**
|
||||
* Returns the distance from the given point.
|
||||
*
|
||||
* @param p the point to compute the distance from
|
||||
* @return the distance from the given point
|
||||
*/
|
||||
double distanceFrom(T p);
|
||||
|
||||
/**
|
||||
* Returns the centroid of the given Collection of points.
|
||||
*
|
||||
* @param p the Collection of points to compute the centroid of
|
||||
* @return the centroid of the given Collection of Points
|
||||
*/
|
||||
T centroidOf(Collection<T> p);
|
||||
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math.stat.clustering;
|
||||
|
||||
import java.util.Collection;
|
||||
|
||||
import org.apache.commons.math.util.MathUtils;
|
||||
|
||||
/**
|
||||
* A simple implementation of {@link Clusterable} for points with integer coordinates.
|
||||
* @version $Revision$ $Date$
|
||||
* @since 2.0
|
||||
*/
|
||||
public class EuclideanIntegerPoint implements Clusterable<EuclideanIntegerPoint> {
|
||||
|
||||
/** Serializable version identifier. */
|
||||
private static final long serialVersionUID = 3946024775784901369L;
|
||||
|
||||
/** Point coordinates. */
|
||||
private final int[] point;
|
||||
|
||||
/**
|
||||
* @param point the n-dimensional point in integer space
|
||||
*/
|
||||
public EuclideanIntegerPoint(final int[] point) {
|
||||
this.point = point;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the n-dimensional point in integer space
|
||||
*/
|
||||
public int[] getPoint() {
|
||||
return point;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double distanceFrom(final EuclideanIntegerPoint p) {
|
||||
return MathUtils.distance(point, p.getPoint());
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public EuclideanIntegerPoint centroidOf(final Collection<EuclideanIntegerPoint> points) {
|
||||
int[] centroid = new int[getPoint().length];
|
||||
for (EuclideanIntegerPoint p : points) {
|
||||
for (int i = 0; i < centroid.length; i++) {
|
||||
centroid[i] += p.getPoint()[i];
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < centroid.length; i++) {
|
||||
centroid[i] /= points.size();
|
||||
}
|
||||
return new EuclideanIntegerPoint(centroid);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public boolean equals(final Object other) {
|
||||
if (!(other instanceof EuclideanIntegerPoint)) {
|
||||
return false;
|
||||
}
|
||||
final int[] otherPoint = ((EuclideanIntegerPoint) other).getPoint();
|
||||
if (point.length != otherPoint.length) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < point.length; i++) {
|
||||
if (point[i] != otherPoint[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int hashCode = 0;
|
||||
for (Integer i : point) {
|
||||
hashCode += i.hashCode() * 13 + 7;
|
||||
}
|
||||
return hashCode;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,161 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math.stat.clustering;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
|
||||
* @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
|
||||
* @version $Revision$ $Date$
|
||||
* @since 2.0
|
||||
*/
|
||||
public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
|
||||
|
||||
/** Random generator for choosing initial centers. */
|
||||
private final Random random;
|
||||
|
||||
/** Build a clusterer.
|
||||
* @param random random generator to use for choosing initial centers
|
||||
*/
|
||||
public KMeansPlusPlusClusterer(final Random random) {
|
||||
this.random = random;
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the K-means++ clustering algorithm.
|
||||
*
|
||||
* @param points the points to cluster
|
||||
* @param k the number of clusters to split the data into
|
||||
* @param maxIterations the maximum number of iterations to run the algorithm
|
||||
* for. If negative, no maximum will be used
|
||||
* @return a list of clusters containing the points
|
||||
*/
|
||||
public List<Cluster<T>> cluster(final Collection<T> points,
|
||||
final int k, final int maxIterations) {
|
||||
// create the initial clusters
|
||||
List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
|
||||
assignPointsToClusters(clusters, points);
|
||||
|
||||
// iterate through updating the centers until we're done
|
||||
final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
|
||||
for (int count = 0; count < max; count++) {
|
||||
boolean clusteringChanged = false;
|
||||
List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
|
||||
for (final Cluster<T> cluster : clusters) {
|
||||
final T newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
|
||||
if (!newCenter.equals(cluster.getCenter())) {
|
||||
clusteringChanged = true;
|
||||
}
|
||||
newClusters.add(new Cluster<T>(newCenter));
|
||||
}
|
||||
if (!clusteringChanged) {
|
||||
return clusters;
|
||||
}
|
||||
assignPointsToClusters(newClusters, points);
|
||||
clusters = newClusters;
|
||||
}
|
||||
return clusters;
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds the given points to the closest {@link Cluster}.
|
||||
*
|
||||
* @param clusters the {@link Cluster}s to add the points to
|
||||
* @param points the points to add to the given {@link Cluster}s
|
||||
*/
|
||||
private static <T extends Clusterable<T>> void
|
||||
assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) {
|
||||
for (final T p : points) {
|
||||
Cluster<T> cluster = getNearestCluster(clusters, p);
|
||||
cluster.addPoint(p);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Use K-means++ to choose the initial centers.
|
||||
*
|
||||
* @param points the points to choose the initial centers from
|
||||
* @param k the number of centers to choose
|
||||
* @param random random generator to use
|
||||
* @return the initial centers
|
||||
*/
|
||||
private static <T extends Clusterable<T>> List<Cluster<T>>
|
||||
chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
|
||||
|
||||
final List<T> pointSet = new ArrayList<T>(points);
|
||||
final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
|
||||
|
||||
// Choose one center uniformly at random from among the data points.
|
||||
final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
|
||||
resultSet.add(new Cluster<T>(firstPoint));
|
||||
|
||||
final double[] dx2 = new double[pointSet.size()];
|
||||
while (resultSet.size() < k) {
|
||||
// For each data point x, compute D(x), the distance between x and
|
||||
// the nearest center that has already been chosen.
|
||||
int sum = 0;
|
||||
for (int i = 0; i < pointSet.size(); i++) {
|
||||
final T p = pointSet.get(i);
|
||||
final Cluster<T> nearest = getNearestCluster(resultSet, p);
|
||||
final double d = p.distanceFrom(nearest.getCenter());
|
||||
sum += d * d;
|
||||
dx2[i] = sum;
|
||||
}
|
||||
|
||||
// Add one new data point as a center. Each point x is chosen with
|
||||
// probability proportional to D(x)2
|
||||
final double r = random.nextDouble() * sum;
|
||||
for (int i = 0 ; i < dx2.length; i++) {
|
||||
if (dx2[i] >= r) {
|
||||
final T p = pointSet.remove(i);
|
||||
resultSet.add(new Cluster<T>(p));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return resultSet;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the nearest {@link Cluster} to the given point
|
||||
*
|
||||
* @param clusters the {@link Cluster}s to search
|
||||
* @param point the point to find the nearest {@link Cluster} for
|
||||
* @return the nearest {@link Cluster} to the given point
|
||||
*/
|
||||
private static <T extends Clusterable<T>> Cluster<T>
|
||||
getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
|
||||
double minDistance = Double.MAX_VALUE;
|
||||
Cluster<T> minCluster = null;
|
||||
for (final Cluster<T> c : clusters) {
|
||||
final double distance = point.distanceFrom(c.getCenter());
|
||||
if (distance < minDistance) {
|
||||
minDistance = distance;
|
||||
minCluster = c;
|
||||
}
|
||||
}
|
||||
return minCluster;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
<html>
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
contributor license agreements. See the NOTICE file distributed with
|
||||
this work for additional information regarding copyright ownership.
|
||||
The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
(the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
<!-- $Revision$ $Date$ -->
|
||||
<body>Clustering algorithms</body>
|
||||
</html>
|
|
@ -39,6 +39,9 @@ The <action> type attribute can be add,update,fix,remove.
|
|||
</properties>
|
||||
<body>
|
||||
<release version="2.0" date="TBD" description="TBD">
|
||||
<action dev="luc" type="add" issue="MATH-266" due-to="Benjamin McCann">
|
||||
Added a clustering package with an implementation of the k-means++ algorithm
|
||||
</action>
|
||||
<action dev="luc" type="fix" issue="MATH-265" due-to="Benjamin McCann">
|
||||
Added distance1, distance and distanceInf utility methods for double and
|
||||
int arrays in MathUtils
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math.stat.clustering;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public class KMeansPlusPlusClustererTest {
|
||||
|
||||
@Test
|
||||
public void dimension2() {
|
||||
KMeansPlusPlusClusterer<EuclideanIntegerPoint> transformer =
|
||||
new KMeansPlusPlusClusterer<EuclideanIntegerPoint>(new Random(1746432956321l));
|
||||
EuclideanIntegerPoint[] points = new EuclideanIntegerPoint[] {
|
||||
|
||||
// first expected cluster
|
||||
new EuclideanIntegerPoint(new int[] { -15, 3 }),
|
||||
new EuclideanIntegerPoint(new int[] { -15, 4 }),
|
||||
new EuclideanIntegerPoint(new int[] { -15, 5 }),
|
||||
new EuclideanIntegerPoint(new int[] { -14, 3 }),
|
||||
new EuclideanIntegerPoint(new int[] { -14, 5 }),
|
||||
new EuclideanIntegerPoint(new int[] { -13, 3 }),
|
||||
new EuclideanIntegerPoint(new int[] { -13, 4 }),
|
||||
new EuclideanIntegerPoint(new int[] { -13, 5 }),
|
||||
|
||||
// second expected cluster
|
||||
new EuclideanIntegerPoint(new int[] { -1, 0 }),
|
||||
new EuclideanIntegerPoint(new int[] { -1, -1 }),
|
||||
new EuclideanIntegerPoint(new int[] { 0, -1 }),
|
||||
new EuclideanIntegerPoint(new int[] { 1, -1 }),
|
||||
new EuclideanIntegerPoint(new int[] { 1, -2 }),
|
||||
|
||||
// third expected cluster
|
||||
new EuclideanIntegerPoint(new int[] { 13, 3 }),
|
||||
new EuclideanIntegerPoint(new int[] { 13, 4 }),
|
||||
new EuclideanIntegerPoint(new int[] { 14, 4 }),
|
||||
new EuclideanIntegerPoint(new int[] { 14, 7 }),
|
||||
new EuclideanIntegerPoint(new int[] { 16, 5 }),
|
||||
new EuclideanIntegerPoint(new int[] { 16, 6 }),
|
||||
new EuclideanIntegerPoint(new int[] { 17, 4 }),
|
||||
new EuclideanIntegerPoint(new int[] { 17, 7 })
|
||||
|
||||
};
|
||||
List<Cluster<EuclideanIntegerPoint>> clusters =
|
||||
transformer.cluster(Arrays.asList(points), 3, 10);
|
||||
|
||||
assertEquals(3, clusters.size());
|
||||
boolean cluster1Found = false;
|
||||
boolean cluster2Found = false;
|
||||
boolean cluster3Found = false;
|
||||
for (Cluster<EuclideanIntegerPoint> cluster : clusters) {
|
||||
int[] center = cluster.getCenter().getPoint();
|
||||
if (center[0] < 0) {
|
||||
cluster1Found = true;
|
||||
assertEquals(8, cluster.getPoints().size());
|
||||
assertEquals(-14, center[0]);
|
||||
assertEquals( 4, center[1]);
|
||||
} else if (center[1] < 0) {
|
||||
cluster2Found = true;
|
||||
assertEquals(5, cluster.getPoints().size());
|
||||
assertEquals( 0, center[0]);
|
||||
assertEquals(-1, center[1]);
|
||||
} else {
|
||||
cluster3Found = true;
|
||||
assertEquals(8, cluster.getPoints().size());
|
||||
assertEquals(15, center[0]);
|
||||
assertEquals(5, center[1]);
|
||||
}
|
||||
}
|
||||
assertTrue(cluster1Found);
|
||||
assertTrue(cluster2Found);
|
||||
assertTrue(cluster3Found);
|
||||
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue