[MATH-917] Refactored clustering package to include more distance measures.
git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1461862 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
2b852d79dc
commit
363e116e34
|
@ -0,0 +1,38 @@
|
||||||
|
package org.apache.commons.math3.ml.clustering;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A Cluster used by centroid-based clustering algorithms.
|
||||||
|
* <p>
|
||||||
|
* Defines additionally a cluster center which may not necessarily be a member
|
||||||
|
* of the original data set.
|
||||||
|
*
|
||||||
|
* @param <T> the type of points that can be clustered
|
||||||
|
* @version $Id $
|
||||||
|
* @since 3.2
|
||||||
|
*/
|
||||||
|
public class CentroidCluster<T extends Clusterable> extends Cluster<T> {
|
||||||
|
|
||||||
|
/** Serializable version identifier. */
|
||||||
|
private static final long serialVersionUID = -3075288519071812288L;
|
||||||
|
|
||||||
|
/** Center of the cluster. */
|
||||||
|
private final Clusterable center;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build a cluster centered at a specified point.
|
||||||
|
* @param center the point which is to be the center of this cluster
|
||||||
|
*/
|
||||||
|
public CentroidCluster(final Clusterable center) {
|
||||||
|
super();
|
||||||
|
this.center = center;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the point chosen to be the center of this cluster.
|
||||||
|
* @return chosen cluster center
|
||||||
|
*/
|
||||||
|
public Clusterable getCenter() {
|
||||||
|
return center;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,61 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.commons.math3.ml.clustering;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cluster holding a set of {@link Clusterable} points.
|
||||||
|
* @param <T> the type of points that can be clustered
|
||||||
|
* @version $Id$
|
||||||
|
* @since 3.2
|
||||||
|
*/
|
||||||
|
public class Cluster<T extends Clusterable> implements Serializable {
|
||||||
|
|
||||||
|
/** Serializable version identifier. */
|
||||||
|
private static final long serialVersionUID = -3442297081515880464L;
|
||||||
|
|
||||||
|
/** The points contained in this cluster. */
|
||||||
|
private final List<T> points;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build a cluster centered at a specified point.
|
||||||
|
*/
|
||||||
|
public Cluster() {
|
||||||
|
points = new ArrayList<T>();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add a point to this cluster.
|
||||||
|
* @param point point to add
|
||||||
|
*/
|
||||||
|
public void addPoint(final T point) {
|
||||||
|
points.add(point);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the points contained in the cluster.
|
||||||
|
* @return points contained in the cluster
|
||||||
|
*/
|
||||||
|
public List<T> getPoints() {
|
||||||
|
return points;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.commons.math3.ml.clustering;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interface for n-dimensional points that can be clustered together.
|
||||||
|
* @version $Id$
|
||||||
|
* @since 3.2
|
||||||
|
*/
|
||||||
|
public interface Clusterable {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the n-dimensional point.
|
||||||
|
*
|
||||||
|
* @return the point array
|
||||||
|
*/
|
||||||
|
double[] getPoint();
|
||||||
|
}
|
|
@ -0,0 +1,65 @@
|
||||||
|
package org.apache.commons.math3.ml.clustering;
|
||||||
|
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.exception.ConvergenceException;
|
||||||
|
import org.apache.commons.math3.exception.MathIllegalArgumentException;
|
||||||
|
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Base class for clustering algorithms.
|
||||||
|
*
|
||||||
|
* @param <T> the type of points that can be clustered
|
||||||
|
* @version $Id $
|
||||||
|
* @since 3.2
|
||||||
|
*/
|
||||||
|
public abstract class Clusterer<T extends Clusterable> {
|
||||||
|
|
||||||
|
/** The distance measure to use. */
|
||||||
|
private DistanceMeasure measure;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build a new clusterer with the given {@link DistanceMeasure}.
|
||||||
|
*
|
||||||
|
* @param measure the distance measure to use
|
||||||
|
*/
|
||||||
|
protected Clusterer(final DistanceMeasure measure) {
|
||||||
|
this.measure = measure;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform a cluster analysis on the given set of {@link Clusterable} instances.
|
||||||
|
*
|
||||||
|
* @param points the set of {@link Clusterable} instances
|
||||||
|
* @return a {@link List} of clusters
|
||||||
|
* @throws MathIllegalArgumentException if points are null or the number of
|
||||||
|
* data points is not compatible with this clusterer
|
||||||
|
* @throws ConvergenceException if the algorithm has not yet converged after
|
||||||
|
* the maximum number of iterations has been exceeded
|
||||||
|
*/
|
||||||
|
public abstract List<? extends Cluster<T>> cluster(Collection<T> points)
|
||||||
|
throws MathIllegalArgumentException, ConvergenceException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the {@link DistanceMeasure} instance used by this clusterer.
|
||||||
|
*
|
||||||
|
* @return the distance measure
|
||||||
|
*/
|
||||||
|
public DistanceMeasure getDistanceMeasure() {
|
||||||
|
return measure;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculates the distance between two {@link Clusterable} instances
|
||||||
|
* with the configured {@link DistanceMeasure}.
|
||||||
|
*
|
||||||
|
* @param p1 the first clusterable
|
||||||
|
* @param p2 the second clusterable
|
||||||
|
* @return the distance between the two clusterables
|
||||||
|
*/
|
||||||
|
protected double distance(final Clusterable p1, final Clusterable p2) {
|
||||||
|
return measure.compute(p1.getPoint(), p2.getPoint());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,231 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.commons.math3.ml.clustering;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.exception.NotPositiveException;
|
||||||
|
import org.apache.commons.math3.exception.NullArgumentException;
|
||||||
|
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
||||||
|
import org.apache.commons.math3.ml.distance.EuclideanDistance;
|
||||||
|
import org.apache.commons.math3.util.MathUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* DBSCAN (density-based spatial clustering of applications with noise) algorithm.
|
||||||
|
* <p>
|
||||||
|
* The DBSCAN algorithm forms clusters based on the idea of density connectivity, i.e.
|
||||||
|
* a point p is density connected to another point q, if there exists a chain of
|
||||||
|
* points p<sub>i</sub>, with i = 1 .. n and p<sub>1</sub> = p and p<sub>n</sub> = q,
|
||||||
|
* such that each pair <p<sub>i</sub>, p<sub>i+1</sub>> is directly density-reachable.
|
||||||
|
* A point q is directly density-reachable from point p if it is in the ε-neighborhood
|
||||||
|
* of this point.
|
||||||
|
* <p>
|
||||||
|
* Any point that is not density-reachable from a formed cluster is treated as noise, and
|
||||||
|
* will thus not be present in the result.
|
||||||
|
* <p>
|
||||||
|
* The algorithm requires two parameters:
|
||||||
|
* <ul>
|
||||||
|
* <li>eps: the distance that defines the ε-neighborhood of a point
|
||||||
|
* <li>minPoints: the minimum number of density-connected points required to form a cluster
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param <T> type of the points to cluster
|
||||||
|
* @see <a href="http://en.wikipedia.org/wiki/DBSCAN">DBSCAN (wikipedia)</a>
|
||||||
|
* @see <a href="http://www.dbs.ifi.lmu.de/Publikationen/Papers/KDD-96.final.frame.pdf">
|
||||||
|
* A Density-Based Algorithm for Discovering Clusters in Large Spatial Databases with Noise</a>
|
||||||
|
* @version $Id$
|
||||||
|
* @since 3.2
|
||||||
|
*/
|
||||||
|
public class DBSCANClusterer<T extends Clusterable> extends Clusterer<T> {
|
||||||
|
|
||||||
|
/** Maximum radius of the neighborhood to be considered. */
|
||||||
|
private final double eps;
|
||||||
|
|
||||||
|
/** Minimum number of points needed for a cluster. */
|
||||||
|
private final int minPts;
|
||||||
|
|
||||||
|
/** Status of a point during the clustering process. */
|
||||||
|
private enum PointStatus {
|
||||||
|
/** The point has is considered to be noise. */
|
||||||
|
NOISE,
|
||||||
|
/** The point is already part of a cluster. */
|
||||||
|
PART_OF_CLUSTER
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new instance of a DBSCANClusterer.
|
||||||
|
*
|
||||||
|
* @param eps maximum radius of the neighborhood to be considered
|
||||||
|
* @param minPts minimum number of points needed for a cluster
|
||||||
|
* @throws NotPositiveException if {@code eps < 0.0} or {@code minPts < 0}
|
||||||
|
*/
|
||||||
|
public DBSCANClusterer(final double eps, final int minPts)
|
||||||
|
throws NotPositiveException {
|
||||||
|
this(eps, minPts, new EuclideanDistance());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new instance of a DBSCANClusterer.
|
||||||
|
*
|
||||||
|
* @param eps maximum radius of the neighborhood to be considered
|
||||||
|
* @param minPts minimum number of points needed for a cluster
|
||||||
|
* @param measure the distance measure to use
|
||||||
|
* @throws NotPositiveException if {@code eps < 0.0} or {@code minPts < 0}
|
||||||
|
*/
|
||||||
|
public DBSCANClusterer(final double eps, final int minPts, final DistanceMeasure measure)
|
||||||
|
throws NotPositiveException {
|
||||||
|
super(measure);
|
||||||
|
|
||||||
|
if (eps < 0.0d) {
|
||||||
|
throw new NotPositiveException(eps);
|
||||||
|
}
|
||||||
|
if (minPts < 0) {
|
||||||
|
throw new NotPositiveException(minPts);
|
||||||
|
}
|
||||||
|
this.eps = eps;
|
||||||
|
this.minPts = minPts;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the maximum radius of the neighborhood to be considered.
|
||||||
|
* @return maximum radius of the neighborhood
|
||||||
|
*/
|
||||||
|
public double getEps() {
|
||||||
|
return eps;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the minimum number of points needed for a cluster.
|
||||||
|
* @return minimum number of points needed for a cluster
|
||||||
|
*/
|
||||||
|
public int getMinPts() {
|
||||||
|
return minPts;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs DBSCAN cluster analysis.
|
||||||
|
*
|
||||||
|
* @param points the points to cluster
|
||||||
|
* @return the list of clusters
|
||||||
|
* @throws NullArgumentException if the data points are null
|
||||||
|
*/
|
||||||
|
public List<Cluster<T>> cluster(final Collection<T> points) throws NullArgumentException {
|
||||||
|
|
||||||
|
// sanity checks
|
||||||
|
MathUtils.checkNotNull(points);
|
||||||
|
|
||||||
|
final List<Cluster<T>> clusters = new ArrayList<Cluster<T>>();
|
||||||
|
final Map<Clusterable, PointStatus> visited = new HashMap<Clusterable, PointStatus>();
|
||||||
|
|
||||||
|
for (final T point : points) {
|
||||||
|
if (visited.get(point) != null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
final List<T> neighbors = getNeighbors(point, points);
|
||||||
|
if (neighbors.size() >= minPts) {
|
||||||
|
// DBSCAN does not care about center points
|
||||||
|
final Cluster<T> cluster = new Cluster<T>();
|
||||||
|
clusters.add(expandCluster(cluster, point, neighbors, points, visited));
|
||||||
|
} else {
|
||||||
|
visited.put(point, PointStatus.NOISE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return clusters;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Expands the cluster to include density-reachable items.
|
||||||
|
*
|
||||||
|
* @param cluster Cluster to expand
|
||||||
|
* @param point Point to add to cluster
|
||||||
|
* @param neighbors List of neighbors
|
||||||
|
* @param points the data set
|
||||||
|
* @param visited the set of already visited points
|
||||||
|
* @return the expanded cluster
|
||||||
|
*/
|
||||||
|
private Cluster<T> expandCluster(final Cluster<T> cluster,
|
||||||
|
final T point,
|
||||||
|
final List<T> neighbors,
|
||||||
|
final Collection<T> points,
|
||||||
|
final Map<Clusterable, PointStatus> visited) {
|
||||||
|
cluster.addPoint(point);
|
||||||
|
visited.put(point, PointStatus.PART_OF_CLUSTER);
|
||||||
|
|
||||||
|
List<T> seeds = new ArrayList<T>(neighbors);
|
||||||
|
int index = 0;
|
||||||
|
while (index < seeds.size()) {
|
||||||
|
final T current = seeds.get(index);
|
||||||
|
PointStatus pStatus = visited.get(current);
|
||||||
|
// only check non-visited points
|
||||||
|
if (pStatus == null) {
|
||||||
|
final List<T> currentNeighbors = getNeighbors(current, points);
|
||||||
|
if (currentNeighbors.size() >= minPts) {
|
||||||
|
seeds = merge(seeds, currentNeighbors);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pStatus != PointStatus.PART_OF_CLUSTER) {
|
||||||
|
visited.put(current, PointStatus.PART_OF_CLUSTER);
|
||||||
|
cluster.addPoint(current);
|
||||||
|
}
|
||||||
|
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
return cluster;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a list of density-reachable neighbors of a {@code point}.
|
||||||
|
*
|
||||||
|
* @param point the point to look for
|
||||||
|
* @param points possible neighbors
|
||||||
|
* @return the List of neighbors
|
||||||
|
*/
|
||||||
|
private List<T> getNeighbors(final T point, final Collection<T> points) {
|
||||||
|
final List<T> neighbors = new ArrayList<T>();
|
||||||
|
for (final T neighbor : points) {
|
||||||
|
if (point != neighbor && distance(neighbor, point) <= eps) {
|
||||||
|
neighbors.add(neighbor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return neighbors;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Merges two lists together.
|
||||||
|
*
|
||||||
|
* @param one first list
|
||||||
|
* @param two second list
|
||||||
|
* @return merged lists
|
||||||
|
*/
|
||||||
|
private List<T> merge(final List<T> one, final List<T> two) {
|
||||||
|
final Set<T> oneSet = new HashSet<T>(one);
|
||||||
|
for (T item : two) {
|
||||||
|
if (!oneSet.contains(item)) {
|
||||||
|
one.add(item);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return one;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,87 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.commons.math3.ml.clustering;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A simple implementation of {@link Clusterable} for points with double coordinates.
|
||||||
|
* @version $Id$
|
||||||
|
* @since 3.2
|
||||||
|
*/
|
||||||
|
public class DoublePoint implements Clusterable, Serializable {
|
||||||
|
|
||||||
|
/** Serializable version identifier. */
|
||||||
|
private static final long serialVersionUID = 3946024775784901369L;
|
||||||
|
|
||||||
|
/** Point coordinates. */
|
||||||
|
private final double[] point;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build an instance wrapping an double array.
|
||||||
|
* <p>
|
||||||
|
* The wrapped array is referenced, it is <em>not</em> copied.
|
||||||
|
*
|
||||||
|
* @param point the n-dimensional point in double space
|
||||||
|
*/
|
||||||
|
public DoublePoint(final double[] point) {
|
||||||
|
this.point = point;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build an instance wrapping an integer array.
|
||||||
|
* <p>
|
||||||
|
* The wrapped array is copied to an internal double array.
|
||||||
|
*
|
||||||
|
* @param point the n-dimensional point in integer space
|
||||||
|
*/
|
||||||
|
public DoublePoint(final int[] point) {
|
||||||
|
this.point = new double[point.length];
|
||||||
|
for ( int i = 0; i < point.length; i++) {
|
||||||
|
this.point[i] = point[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** {@inheritDoc} */
|
||||||
|
public double[] getPoint() {
|
||||||
|
return point;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** {@inheritDoc} */
|
||||||
|
@Override
|
||||||
|
public boolean equals(final Object other) {
|
||||||
|
if (!(other instanceof DoublePoint)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return Arrays.equals(point, ((DoublePoint) other).point);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** {@inheritDoc} */
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Arrays.hashCode(point);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** {@inheritDoc} */
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return Arrays.toString(point);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,561 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.commons.math3.ml.clustering;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.exception.ConvergenceException;
|
||||||
|
import org.apache.commons.math3.exception.MathIllegalArgumentException;
|
||||||
|
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||||
|
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||||
|
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
||||||
|
import org.apache.commons.math3.ml.distance.EuclideanDistance;
|
||||||
|
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.apache.commons.math3.stat.descriptive.moment.Variance;
|
||||||
|
import org.apache.commons.math3.util.MathUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
|
||||||
|
* @param <T> type of the points to cluster
|
||||||
|
* @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
|
||||||
|
* @version $Id$
|
||||||
|
* @since 3.2
|
||||||
|
*/
|
||||||
|
public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
|
||||||
|
|
||||||
|
/** Strategies to use for replacing an empty cluster. */
|
||||||
|
public static enum EmptyClusterStrategy {
|
||||||
|
|
||||||
|
/** Split the cluster with largest distance variance. */
|
||||||
|
LARGEST_VARIANCE,
|
||||||
|
|
||||||
|
/** Split the cluster with largest number of points. */
|
||||||
|
LARGEST_POINTS_NUMBER,
|
||||||
|
|
||||||
|
/** Create a cluster around the point farthest from its centroid. */
|
||||||
|
FARTHEST_POINT,
|
||||||
|
|
||||||
|
/** Generate an error. */
|
||||||
|
ERROR
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/** The number of clusters. */
|
||||||
|
private final int k;
|
||||||
|
|
||||||
|
/** The maximum number of iterations. */
|
||||||
|
private final int maxIterations;
|
||||||
|
|
||||||
|
/** Random generator for choosing initial centers. */
|
||||||
|
private final RandomGenerator random;
|
||||||
|
|
||||||
|
/** Selected strategy for empty clusters. */
|
||||||
|
private final EmptyClusterStrategy emptyStrategy;
|
||||||
|
|
||||||
|
/** Build a clusterer.
|
||||||
|
* <p>
|
||||||
|
* The default strategy for handling empty clusters that may appear during
|
||||||
|
* algorithm iterations is to split the cluster with largest distance variance.
|
||||||
|
*
|
||||||
|
* @param k the number of clusters to split the data into
|
||||||
|
*/
|
||||||
|
public KMeansPlusPlusClusterer(final int k) {
|
||||||
|
this(k, -1);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Build a clusterer.
|
||||||
|
* <p>
|
||||||
|
* The default strategy for handling empty clusters that may appear during
|
||||||
|
* algorithm iterations is to split the cluster with largest distance variance.
|
||||||
|
*
|
||||||
|
* @param k the number of clusters to split the data into
|
||||||
|
* @param maxIterations the maximum number of iterations to run the algorithm for.
|
||||||
|
* If negative, no maximum will be used.
|
||||||
|
*/
|
||||||
|
public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
|
||||||
|
this(k, maxIterations, new EuclideanDistance());
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Build a clusterer.
|
||||||
|
* <p>
|
||||||
|
* The default strategy for handling empty clusters that may appear during
|
||||||
|
* algorithm iterations is to split the cluster with largest distance variance.
|
||||||
|
*
|
||||||
|
* @param k the number of clusters to split the data into
|
||||||
|
* @param maxIterations the maximum number of iterations to run the algorithm for.
|
||||||
|
* If negative, no maximum will be used.
|
||||||
|
* @param measure the distance measure to use
|
||||||
|
*/
|
||||||
|
public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
|
||||||
|
this(k, maxIterations, measure, new JDKRandomGenerator());
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Build a clusterer.
|
||||||
|
* <p>
|
||||||
|
* The default strategy for handling empty clusters that may appear during
|
||||||
|
* algorithm iterations is to split the cluster with largest distance variance.
|
||||||
|
*
|
||||||
|
* @param k the number of clusters to split the data into
|
||||||
|
* @param maxIterations the maximum number of iterations to run the algorithm for.
|
||||||
|
* If negative, no maximum will be used.
|
||||||
|
* @param measure the distance measure to use
|
||||||
|
* @param random random generator to use for choosing initial centers
|
||||||
|
*/
|
||||||
|
public KMeansPlusPlusClusterer(final int k, final int maxIterations,
|
||||||
|
final DistanceMeasure measure,
|
||||||
|
final RandomGenerator random) {
|
||||||
|
this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Build a clusterer.
|
||||||
|
*
|
||||||
|
* @param k the number of clusters to split the data into
|
||||||
|
* @param maxIterations the maximum number of iterations to run the algorithm for.
|
||||||
|
* If negative, no maximum will be used.
|
||||||
|
* @param measure the distance measure to use
|
||||||
|
* @param random random generator to use for choosing initial centers
|
||||||
|
* @param emptyStrategy strategy to use for handling empty clusters that
|
||||||
|
* may appear during algorithm iterations
|
||||||
|
*/
|
||||||
|
public KMeansPlusPlusClusterer(final int k, final int maxIterations,
|
||||||
|
final DistanceMeasure measure,
|
||||||
|
final RandomGenerator random,
|
||||||
|
final EmptyClusterStrategy emptyStrategy) {
|
||||||
|
super(measure);
|
||||||
|
this.k = k;
|
||||||
|
this.maxIterations = maxIterations;
|
||||||
|
this.random = random;
|
||||||
|
this.emptyStrategy = emptyStrategy;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the number of clusters this instance will use.
|
||||||
|
* @return the number of clusters
|
||||||
|
*/
|
||||||
|
public int getK() {
|
||||||
|
return k;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the maximum number of iterations this instance will use.
|
||||||
|
* @return the maximum number of iterations, or -1 if no maximum is set
|
||||||
|
*/
|
||||||
|
public int getMaxIterations() {
|
||||||
|
return maxIterations;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the random generator this instance will use.
|
||||||
|
* @return the random generator
|
||||||
|
*/
|
||||||
|
public RandomGenerator getRandomGenerator() {
|
||||||
|
return random;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the {@link EmptyClusterStrategy} used by this instance.
|
||||||
|
* @return the {@link EmptyClusterStrategy}
|
||||||
|
*/
|
||||||
|
public EmptyClusterStrategy getEmptyClusterStrategy() {
|
||||||
|
return emptyStrategy;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs the K-means++ clustering algorithm.
|
||||||
|
*
|
||||||
|
* @param points the points to cluster
|
||||||
|
* @return a list of clusters containing the points
|
||||||
|
* @throws MathIllegalArgumentException if the data points are null or the number
|
||||||
|
* of clusters is larger than the number of data points
|
||||||
|
* @throws ConvergenceException if an empty cluster is encountered and the
|
||||||
|
* {@link #emptyStrategy} is set to {@code ERROR}
|
||||||
|
*/
|
||||||
|
public List<CentroidCluster<T>> cluster(final Collection<T> points)
|
||||||
|
throws MathIllegalArgumentException, ConvergenceException {
|
||||||
|
|
||||||
|
// sanity checks
|
||||||
|
MathUtils.checkNotNull(points);
|
||||||
|
|
||||||
|
// number of clusters has to be smaller or equal the number of data points
|
||||||
|
if (points.size() < k) {
|
||||||
|
throw new NumberIsTooSmallException(points.size(), k, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// create the initial clusters
|
||||||
|
List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
|
||||||
|
|
||||||
|
// create an array containing the latest assignment of a point to a cluster
|
||||||
|
// no need to initialize the array, as it will be filled with the first assignment
|
||||||
|
int[] assignments = new int[points.size()];
|
||||||
|
assignPointsToClusters(clusters, points, assignments);
|
||||||
|
|
||||||
|
// iterate through updating the centers until we're done
|
||||||
|
final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
|
||||||
|
for (int count = 0; count < max; count++) {
|
||||||
|
boolean emptyCluster = false;
|
||||||
|
List<CentroidCluster<T>> newClusters = new ArrayList<CentroidCluster<T>>();
|
||||||
|
for (final CentroidCluster<T> cluster : clusters) {
|
||||||
|
final Clusterable newCenter;
|
||||||
|
if (cluster.getPoints().isEmpty()) {
|
||||||
|
switch (emptyStrategy) {
|
||||||
|
case LARGEST_VARIANCE :
|
||||||
|
newCenter = getPointFromLargestVarianceCluster(clusters);
|
||||||
|
break;
|
||||||
|
case LARGEST_POINTS_NUMBER :
|
||||||
|
newCenter = getPointFromLargestNumberCluster(clusters);
|
||||||
|
break;
|
||||||
|
case FARTHEST_POINT :
|
||||||
|
newCenter = getFarthestPoint(clusters);
|
||||||
|
break;
|
||||||
|
default :
|
||||||
|
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
|
||||||
|
}
|
||||||
|
emptyCluster = true;
|
||||||
|
} else {
|
||||||
|
newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
|
||||||
|
}
|
||||||
|
newClusters.add(new CentroidCluster<T>(newCenter));
|
||||||
|
}
|
||||||
|
int changes = assignPointsToClusters(newClusters, points, assignments);
|
||||||
|
clusters = newClusters;
|
||||||
|
|
||||||
|
// if there were no more changes in the point-to-cluster assignment
|
||||||
|
// and there are no empty clusters left, return the current clusters
|
||||||
|
if (changes == 0 && !emptyCluster) {
|
||||||
|
return clusters;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return clusters;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds the given points to the closest {@link Cluster}.
|
||||||
|
*
|
||||||
|
* @param clusters the {@link Cluster}s to add the points to
|
||||||
|
* @param points the points to add to the given {@link Cluster}s
|
||||||
|
* @param assignments points assignments to clusters
|
||||||
|
* @return the number of points assigned to different clusters as the iteration before
|
||||||
|
*/
|
||||||
|
private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
|
||||||
|
final Collection<T> points,
|
||||||
|
final int[] assignments) {
|
||||||
|
int assignedDifferently = 0;
|
||||||
|
int pointIndex = 0;
|
||||||
|
for (final T p : points) {
|
||||||
|
int clusterIndex = getNearestCluster(clusters, p);
|
||||||
|
if (clusterIndex != assignments[pointIndex]) {
|
||||||
|
assignedDifferently++;
|
||||||
|
}
|
||||||
|
|
||||||
|
CentroidCluster<T> cluster = clusters.get(clusterIndex);
|
||||||
|
cluster.addPoint(p);
|
||||||
|
assignments[pointIndex++] = clusterIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
return assignedDifferently;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use K-means++ to choose the initial centers.
|
||||||
|
*
|
||||||
|
* @param points the points to choose the initial centers from
|
||||||
|
* @return the initial centers
|
||||||
|
*/
|
||||||
|
private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
|
||||||
|
|
||||||
|
// Convert to list for indexed access. Make it unmodifiable, since removal of items
|
||||||
|
// would screw up the logic of this method.
|
||||||
|
final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
|
||||||
|
|
||||||
|
// The number of points in the list.
|
||||||
|
final int numPoints = pointList.size();
|
||||||
|
|
||||||
|
// Set the corresponding element in this array to indicate when
|
||||||
|
// elements of pointList are no longer available.
|
||||||
|
final boolean[] taken = new boolean[numPoints];
|
||||||
|
|
||||||
|
// The resulting list of initial centers.
|
||||||
|
final List<CentroidCluster<T>> resultSet = new ArrayList<CentroidCluster<T>>();
|
||||||
|
|
||||||
|
// Choose one center uniformly at random from among the data points.
|
||||||
|
final int firstPointIndex = random.nextInt(numPoints);
|
||||||
|
|
||||||
|
final T firstPoint = pointList.get(firstPointIndex);
|
||||||
|
|
||||||
|
resultSet.add(new CentroidCluster<T>(firstPoint));
|
||||||
|
|
||||||
|
// Must mark it as taken
|
||||||
|
taken[firstPointIndex] = true;
|
||||||
|
|
||||||
|
// To keep track of the minimum distance squared of elements of
|
||||||
|
// pointList to elements of resultSet.
|
||||||
|
final double[] minDistSquared = new double[numPoints];
|
||||||
|
|
||||||
|
// Initialize the elements. Since the only point in resultSet is firstPoint,
|
||||||
|
// this is very easy.
|
||||||
|
for (int i = 0; i < numPoints; i++) {
|
||||||
|
if (i != firstPointIndex) { // That point isn't considered
|
||||||
|
double d = distance(firstPoint, pointList.get(i));
|
||||||
|
minDistSquared[i] = d*d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while (resultSet.size() < k) {
|
||||||
|
|
||||||
|
// Sum up the squared distances for the points in pointList not
|
||||||
|
// already taken.
|
||||||
|
double distSqSum = 0.0;
|
||||||
|
|
||||||
|
for (int i = 0; i < numPoints; i++) {
|
||||||
|
if (!taken[i]) {
|
||||||
|
distSqSum += minDistSquared[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add one new data point as a center. Each point x is chosen with
|
||||||
|
// probability proportional to D(x)2
|
||||||
|
final double r = random.nextDouble() * distSqSum;
|
||||||
|
|
||||||
|
// The index of the next point to be added to the resultSet.
|
||||||
|
int nextPointIndex = -1;
|
||||||
|
|
||||||
|
// Sum through the squared min distances again, stopping when
|
||||||
|
// sum >= r.
|
||||||
|
double sum = 0.0;
|
||||||
|
for (int i = 0; i < numPoints; i++) {
|
||||||
|
if (!taken[i]) {
|
||||||
|
sum += minDistSquared[i];
|
||||||
|
if (sum >= r) {
|
||||||
|
nextPointIndex = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's not set to >= 0, the point wasn't found in the previous
|
||||||
|
// for loop, probably because distances are extremely small. Just pick
|
||||||
|
// the last available point.
|
||||||
|
if (nextPointIndex == -1) {
|
||||||
|
for (int i = numPoints - 1; i >= 0; i--) {
|
||||||
|
if (!taken[i]) {
|
||||||
|
nextPointIndex = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We found one.
|
||||||
|
if (nextPointIndex >= 0) {
|
||||||
|
|
||||||
|
final T p = pointList.get(nextPointIndex);
|
||||||
|
|
||||||
|
resultSet.add(new CentroidCluster<T> (p));
|
||||||
|
|
||||||
|
// Mark it as taken.
|
||||||
|
taken[nextPointIndex] = true;
|
||||||
|
|
||||||
|
if (resultSet.size() < k) {
|
||||||
|
// Now update elements of minDistSquared. We only have to compute
|
||||||
|
// the distance to the new center to do this.
|
||||||
|
for (int j = 0; j < numPoints; j++) {
|
||||||
|
// Only have to worry about the points still not taken.
|
||||||
|
if (!taken[j]) {
|
||||||
|
double d = distance(p, pointList.get(j));
|
||||||
|
double d2 = d * d;
|
||||||
|
if (d2 < minDistSquared[j]) {
|
||||||
|
minDistSquared[j] = d2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// None found --
|
||||||
|
// Break from the while loop to prevent
|
||||||
|
// an infinite loop.
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resultSet;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a random point from the {@link Cluster} with the largest distance variance.
|
||||||
|
*
|
||||||
|
* @param clusters the {@link Cluster}s to search
|
||||||
|
* @return a random point from the selected cluster
|
||||||
|
* @throws ConvergenceException if clusters are all empty
|
||||||
|
*/
|
||||||
|
private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters)
|
||||||
|
throws ConvergenceException {
|
||||||
|
|
||||||
|
double maxVariance = Double.NEGATIVE_INFINITY;
|
||||||
|
Cluster<T> selected = null;
|
||||||
|
for (final CentroidCluster<T> cluster : clusters) {
|
||||||
|
if (!cluster.getPoints().isEmpty()) {
|
||||||
|
|
||||||
|
// compute the distance variance of the current cluster
|
||||||
|
final Clusterable center = cluster.getCenter();
|
||||||
|
final Variance stat = new Variance();
|
||||||
|
for (final T point : cluster.getPoints()) {
|
||||||
|
stat.increment(distance(point, center));
|
||||||
|
}
|
||||||
|
final double variance = stat.getResult();
|
||||||
|
|
||||||
|
// select the cluster with the largest variance
|
||||||
|
if (variance > maxVariance) {
|
||||||
|
maxVariance = variance;
|
||||||
|
selected = cluster;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// did we find at least one non-empty cluster ?
|
||||||
|
if (selected == null) {
|
||||||
|
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract a random point from the cluster
|
||||||
|
final List<T> selectedPoints = selected.getPoints();
|
||||||
|
return selectedPoints.remove(random.nextInt(selectedPoints.size()));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a random point from the {@link Cluster} with the largest number of points
|
||||||
|
*
|
||||||
|
* @param clusters the {@link Cluster}s to search
|
||||||
|
* @return a random point from the selected cluster
|
||||||
|
* @throws ConvergenceException if clusters are all empty
|
||||||
|
*/
|
||||||
|
private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters)
|
||||||
|
throws ConvergenceException {
|
||||||
|
|
||||||
|
int maxNumber = 0;
|
||||||
|
Cluster<T> selected = null;
|
||||||
|
for (final Cluster<T> cluster : clusters) {
|
||||||
|
|
||||||
|
// get the number of points of the current cluster
|
||||||
|
final int number = cluster.getPoints().size();
|
||||||
|
|
||||||
|
// select the cluster with the largest number of points
|
||||||
|
if (number > maxNumber) {
|
||||||
|
maxNumber = number;
|
||||||
|
selected = cluster;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// did we find at least one non-empty cluster ?
|
||||||
|
if (selected == null) {
|
||||||
|
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract a random point from the cluster
|
||||||
|
final List<T> selectedPoints = selected.getPoints();
|
||||||
|
return selectedPoints.remove(random.nextInt(selectedPoints.size()));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the point farthest to its cluster center
|
||||||
|
*
|
||||||
|
* @param clusters the {@link Cluster}s to search
|
||||||
|
* @return point farthest to its cluster center
|
||||||
|
* @throws ConvergenceException if clusters are all empty
|
||||||
|
*/
|
||||||
|
private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws ConvergenceException {
|
||||||
|
|
||||||
|
double maxDistance = Double.NEGATIVE_INFINITY;
|
||||||
|
Cluster<T> selectedCluster = null;
|
||||||
|
int selectedPoint = -1;
|
||||||
|
for (final CentroidCluster<T> cluster : clusters) {
|
||||||
|
|
||||||
|
// get the farthest point
|
||||||
|
final Clusterable center = cluster.getCenter();
|
||||||
|
final List<T> points = cluster.getPoints();
|
||||||
|
for (int i = 0; i < points.size(); ++i) {
|
||||||
|
final double distance = distance(points.get(i), center);
|
||||||
|
if (distance > maxDistance) {
|
||||||
|
maxDistance = distance;
|
||||||
|
selectedCluster = cluster;
|
||||||
|
selectedPoint = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// did we find at least one non-empty cluster ?
|
||||||
|
if (selectedCluster == null) {
|
||||||
|
throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
|
||||||
|
}
|
||||||
|
|
||||||
|
return selectedCluster.getPoints().remove(selectedPoint);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the nearest {@link Cluster} to the given point
|
||||||
|
*
|
||||||
|
* @param clusters the {@link Cluster}s to search
|
||||||
|
* @param point the point to find the nearest {@link Cluster} for
|
||||||
|
* @return the index of the nearest {@link Cluster} to the given point
|
||||||
|
*/
|
||||||
|
private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
|
||||||
|
double minDistance = Double.MAX_VALUE;
|
||||||
|
int clusterIndex = 0;
|
||||||
|
int minCluster = 0;
|
||||||
|
for (final CentroidCluster<T> c : clusters) {
|
||||||
|
final double distance = distance(point, c.getCenter());
|
||||||
|
if (distance < minDistance) {
|
||||||
|
minDistance = distance;
|
||||||
|
minCluster = clusterIndex;
|
||||||
|
}
|
||||||
|
clusterIndex++;
|
||||||
|
}
|
||||||
|
return minCluster;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the centroid for a set of points.
|
||||||
|
*
|
||||||
|
* @param points the set of points
|
||||||
|
* @param dimension the point dimension
|
||||||
|
* @return the computed centroid for the set of points
|
||||||
|
*/
|
||||||
|
private Clusterable centroidOf(final Collection<T> points, final int dimension) {
|
||||||
|
final double[] centroid = new double[dimension];
|
||||||
|
for (final T p : points) {
|
||||||
|
final double[] point = p.getPoint();
|
||||||
|
for (int i = 0; i < centroid.length; i++) {
|
||||||
|
centroid[i] += point[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < centroid.length; i++) {
|
||||||
|
centroid[i] /= points.size();
|
||||||
|
}
|
||||||
|
return new DoublePoint(centroid);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,121 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.commons.math3.ml.clustering;
|
||||||
|
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.exception.ConvergenceException;
|
||||||
|
import org.apache.commons.math3.exception.MathIllegalArgumentException;
|
||||||
|
import org.apache.commons.math3.stat.descriptive.moment.Variance;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A wrapper around a k-means++ clustering algorithm which performs multiple trials
|
||||||
|
* and returns the best solution.
|
||||||
|
* @param <T> type of the points to cluster
|
||||||
|
* @version $Id$
|
||||||
|
* @since 3.2
|
||||||
|
*/
|
||||||
|
public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
|
||||||
|
|
||||||
|
/** The underlying k-means clusterer. */
|
||||||
|
private final KMeansPlusPlusClusterer<T> clusterer;
|
||||||
|
|
||||||
|
/** The number of trial runs. */
|
||||||
|
private final int numTrials;
|
||||||
|
|
||||||
|
/** Build a clusterer.
|
||||||
|
* @param clusterer the k-means clusterer to use
|
||||||
|
* @param numTrials number of trial runs
|
||||||
|
*/
|
||||||
|
public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
|
||||||
|
final int numTrials) {
|
||||||
|
super(clusterer.getDistanceMeasure());
|
||||||
|
this.clusterer = clusterer;
|
||||||
|
this.numTrials = numTrials;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the embedded k-means clusterer used by this instance.
|
||||||
|
* @return the embedded clusterer
|
||||||
|
*/
|
||||||
|
public KMeansPlusPlusClusterer<T> getClusterer() {
|
||||||
|
return clusterer;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the number of trials this instance will do.
|
||||||
|
* @return the number of trials
|
||||||
|
*/
|
||||||
|
public int getNumTrials() {
|
||||||
|
return numTrials;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs the K-means++ clustering algorithm.
|
||||||
|
*
|
||||||
|
* @param points the points to cluster
|
||||||
|
* @return a list of clusters containing the points
|
||||||
|
* @throws MathIllegalArgumentException if the data points are null or the number
|
||||||
|
* of clusters is larger than the number of data points
|
||||||
|
* @throws ConvergenceException if an empty cluster is encountered and the
|
||||||
|
* {@link #emptyStrategy} is set to {@code ERROR}
|
||||||
|
*/
|
||||||
|
public List<CentroidCluster<T>> cluster(final Collection<T> points)
|
||||||
|
throws MathIllegalArgumentException, ConvergenceException {
|
||||||
|
|
||||||
|
// at first, we have not found any clusters list yet
|
||||||
|
List<CentroidCluster<T>> best = null;
|
||||||
|
double bestVarianceSum = Double.POSITIVE_INFINITY;
|
||||||
|
|
||||||
|
// do several clustering trials
|
||||||
|
for (int i = 0; i < numTrials; ++i) {
|
||||||
|
|
||||||
|
// compute a clusters list
|
||||||
|
List<CentroidCluster<T>> clusters = clusterer.cluster(points);
|
||||||
|
|
||||||
|
// compute the variance of the current list
|
||||||
|
double varianceSum = 0.0;
|
||||||
|
for (final CentroidCluster<T> cluster : clusters) {
|
||||||
|
if (!cluster.getPoints().isEmpty()) {
|
||||||
|
|
||||||
|
// compute the distance variance of the current cluster
|
||||||
|
final Clusterable center = cluster.getCenter();
|
||||||
|
final Variance stat = new Variance();
|
||||||
|
for (final T point : cluster.getPoints()) {
|
||||||
|
stat.increment(distance(point, center));
|
||||||
|
}
|
||||||
|
varianceSum += stat.getResult();
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (varianceSum <= bestVarianceSum) {
|
||||||
|
// this one is the best we have found so far, remember it
|
||||||
|
best = clusters;
|
||||||
|
bestVarianceSum = varianceSum;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// return the best clusters list found
|
||||||
|
return best;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,20 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
/**
|
||||||
|
* Clustering algorithms.
|
||||||
|
*/
|
||||||
|
package org.apache.commons.math3.ml.clustering;
|
|
@ -0,0 +1,27 @@
|
||||||
|
package org.apache.commons.math3.ml.distance;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.util.FastMath;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculates the Canberra distance between two points.
|
||||||
|
*
|
||||||
|
* @version $Id $
|
||||||
|
* @since 3.2
|
||||||
|
*/
|
||||||
|
public class CanberraDistance implements DistanceMeasure {
|
||||||
|
|
||||||
|
/** Serializable version identifier. */
|
||||||
|
private static final long serialVersionUID = -6972277381587032228L;
|
||||||
|
|
||||||
|
/** {@inheritDoc} */
|
||||||
|
public double compute(double[] a, double[] b) {
|
||||||
|
double sum = 0;
|
||||||
|
for (int i = 0; i < a.length; i++) {
|
||||||
|
final double num = FastMath.abs(a[i] - b[i]);
|
||||||
|
final double denom = FastMath.abs(a[i]) + FastMath.abs(b[i]);
|
||||||
|
sum += num == 0.0 && denom == 0.0 ? 0.0 : num / denom;
|
||||||
|
}
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
package org.apache.commons.math3.ml.distance;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.util.MathArrays;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculates the L<sub>∞</sub> (max of abs) distance between two points.
|
||||||
|
*
|
||||||
|
* @version $Id $
|
||||||
|
* @since 3.2
|
||||||
|
*/
|
||||||
|
public class ChebyshevDistance implements DistanceMeasure {
|
||||||
|
|
||||||
|
/** Serializable version identifier. */
|
||||||
|
private static final long serialVersionUID = -4694868171115238296L;
|
||||||
|
|
||||||
|
/** {@inheritDoc} */
|
||||||
|
public double compute(double[] a, double[] b) {
|
||||||
|
return MathArrays.distanceInf(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,23 @@
|
||||||
|
package org.apache.commons.math3.ml.distance;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interface for distance measures of n-dimensional vectors.
|
||||||
|
*
|
||||||
|
* @version $Id $
|
||||||
|
* @since 3.2
|
||||||
|
*/
|
||||||
|
public interface DistanceMeasure extends Serializable {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compute the distance between two n-dimensional vectors.
|
||||||
|
* <p>
|
||||||
|
* The two vectors are required to have the same dimension.
|
||||||
|
*
|
||||||
|
* @param a the first vector
|
||||||
|
* @param b the second vector
|
||||||
|
* @return the distance between the two vectors
|
||||||
|
*/
|
||||||
|
double compute(double[] a, double[] b);
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
package org.apache.commons.math3.ml.distance;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.util.MathArrays;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculates the L<sub>2</sub> (Euclidean) distance between two points.
|
||||||
|
*
|
||||||
|
* @version $Id $
|
||||||
|
* @since 3.2
|
||||||
|
*/
|
||||||
|
public class EuclideanDistance implements DistanceMeasure {
|
||||||
|
|
||||||
|
/** Serializable version identifier. */
|
||||||
|
private static final long serialVersionUID = 1717556319784040040L;
|
||||||
|
|
||||||
|
/** {@inheritDoc} */
|
||||||
|
public double compute(double[] a, double[] b) {
|
||||||
|
return MathArrays.distance(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
package org.apache.commons.math3.ml.distance;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.util.MathArrays;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculates the L<sub>1</sub> (sum of abs) distance between two points.
|
||||||
|
*
|
||||||
|
* @version $Id $
|
||||||
|
* @since 3.2
|
||||||
|
*/
|
||||||
|
public class ManhattanDistance implements DistanceMeasure {
|
||||||
|
|
||||||
|
/** Serializable version identifier. */
|
||||||
|
private static final long serialVersionUID = -9108154600539125566L;
|
||||||
|
|
||||||
|
/** {@inheritDoc} */
|
||||||
|
public double compute(double[] a, double[] b) {
|
||||||
|
return MathArrays.distance1(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,20 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
/**
|
||||||
|
* Common distance measures.
|
||||||
|
*/
|
||||||
|
package org.apache.commons.math3.ml.distance;
|
|
@ -0,0 +1,20 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
/**
|
||||||
|
* Base package for machine learning algorithms.
|
||||||
|
*/
|
||||||
|
package org.apache.commons.math3.ml;
|
|
@ -0,0 +1,190 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.commons.math3.ml.clustering;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.exception.MathIllegalArgumentException;
|
||||||
|
import org.apache.commons.math3.exception.NullArgumentException;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class DBSCANClustererTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCluster() {
|
||||||
|
// Test data generated using: http://people.cs.nctu.edu.tw/~rsliang/dbscan/testdatagen.html
|
||||||
|
final DoublePoint[] points = new DoublePoint[] {
|
||||||
|
new DoublePoint(new double[] { 83.08303244924173, 58.83387754182331 }),
|
||||||
|
new DoublePoint(new double[] { 45.05445510940626, 23.469642649637535 }),
|
||||||
|
new DoublePoint(new double[] { 14.96417921432294, 69.0264096390456 }),
|
||||||
|
new DoublePoint(new double[] { 73.53189604333602, 34.896145021310076 }),
|
||||||
|
new DoublePoint(new double[] { 73.28498173551634, 33.96860806993209 }),
|
||||||
|
new DoublePoint(new double[] { 73.45828098873608, 33.92584423092194 }),
|
||||||
|
new DoublePoint(new double[] { 73.9657889183145, 35.73191006924026 }),
|
||||||
|
new DoublePoint(new double[] { 74.0074097183533, 36.81735596177168 }),
|
||||||
|
new DoublePoint(new double[] { 73.41247541410848, 34.27314856695011 }),
|
||||||
|
new DoublePoint(new double[] { 73.9156256353017, 36.83206791547127 }),
|
||||||
|
new DoublePoint(new double[] { 74.81499205809087, 37.15682749846019 }),
|
||||||
|
new DoublePoint(new double[] { 74.03144880081527, 37.57399178552441 }),
|
||||||
|
new DoublePoint(new double[] { 74.51870941207744, 38.674258946906775 }),
|
||||||
|
new DoublePoint(new double[] { 74.50754595105536, 35.58903978415765 }),
|
||||||
|
new DoublePoint(new double[] { 74.51322752749547, 36.030572259100154 }),
|
||||||
|
new DoublePoint(new double[] { 59.27900996617973, 46.41091720294207 }),
|
||||||
|
new DoublePoint(new double[] { 59.73744793841615, 46.20015558367595 }),
|
||||||
|
new DoublePoint(new double[] { 58.81134076672606, 45.71150126331486 }),
|
||||||
|
new DoublePoint(new double[] { 58.52225539437495, 47.416083617601544 }),
|
||||||
|
new DoublePoint(new double[] { 58.218626647023484, 47.36228902172297 }),
|
||||||
|
new DoublePoint(new double[] { 60.27139669447206, 46.606106348801404 }),
|
||||||
|
new DoublePoint(new double[] { 60.894962462363765, 46.976924697402865 }),
|
||||||
|
new DoublePoint(new double[] { 62.29048673878424, 47.66970563563518 }),
|
||||||
|
new DoublePoint(new double[] { 61.03857608977705, 46.212924720020965 }),
|
||||||
|
new DoublePoint(new double[] { 60.16916214139201, 45.18193661351688 }),
|
||||||
|
new DoublePoint(new double[] { 59.90036905976012, 47.555364347063005 }),
|
||||||
|
new DoublePoint(new double[] { 62.33003634144552, 47.83941489877179 }),
|
||||||
|
new DoublePoint(new double[] { 57.86035536718555, 47.31117930193432 }),
|
||||||
|
new DoublePoint(new double[] { 58.13715479685925, 48.985960494028404 }),
|
||||||
|
new DoublePoint(new double[] { 56.131923963548616, 46.8508904252667 }),
|
||||||
|
new DoublePoint(new double[] { 55.976329887053, 47.46384037658572 }),
|
||||||
|
new DoublePoint(new double[] { 56.23245975235477, 47.940035191131756 }),
|
||||||
|
new DoublePoint(new double[] { 58.51687048212625, 46.622885352699086 }),
|
||||||
|
new DoublePoint(new double[] { 57.85411081905477, 45.95394361577928 }),
|
||||||
|
new DoublePoint(new double[] { 56.445776311447844, 45.162093662656844 }),
|
||||||
|
new DoublePoint(new double[] { 57.36691949656233, 47.50097194337286 }),
|
||||||
|
new DoublePoint(new double[] { 58.243626387557015, 46.114052729681134 }),
|
||||||
|
new DoublePoint(new double[] { 56.27224595635198, 44.799080066150054 }),
|
||||||
|
new DoublePoint(new double[] { 57.606924816500396, 46.94291057763621 }),
|
||||||
|
new DoublePoint(new double[] { 30.18714230041951, 13.877149710431695 }),
|
||||||
|
new DoublePoint(new double[] { 30.449448810657486, 13.490778346545994 }),
|
||||||
|
new DoublePoint(new double[] { 30.295018390286714, 13.264889000216499 }),
|
||||||
|
new DoublePoint(new double[] { 30.160201832884923, 11.89278262341395 }),
|
||||||
|
new DoublePoint(new double[] { 31.341509791789576, 15.282655921997502 }),
|
||||||
|
new DoublePoint(new double[] { 31.68601630325429, 14.756873246748 }),
|
||||||
|
new DoublePoint(new double[] { 29.325963742565364, 12.097849250072613 }),
|
||||||
|
new DoublePoint(new double[] { 29.54820742388256, 13.613295356975868 }),
|
||||||
|
new DoublePoint(new double[] { 28.79359608888626, 10.36352064087987 }),
|
||||||
|
new DoublePoint(new double[] { 31.01284597092308, 12.788479208014905 }),
|
||||||
|
new DoublePoint(new double[] { 27.58509216737002, 11.47570110601373 }),
|
||||||
|
new DoublePoint(new double[] { 28.593799561727792, 10.780998203903437 }),
|
||||||
|
new DoublePoint(new double[] { 31.356105766724795, 15.080316198524088 }),
|
||||||
|
new DoublePoint(new double[] { 31.25948503636755, 13.674329151166603 }),
|
||||||
|
new DoublePoint(new double[] { 32.31590076372959, 14.95261758659035 }),
|
||||||
|
new DoublePoint(new double[] { 30.460413702763617, 15.88402809202671 }),
|
||||||
|
new DoublePoint(new double[] { 32.56178203062154, 14.586076852632686 }),
|
||||||
|
new DoublePoint(new double[] { 32.76138648530468, 16.239837325178087 }),
|
||||||
|
new DoublePoint(new double[] { 30.1829453331884, 14.709592407103628 }),
|
||||||
|
new DoublePoint(new double[] { 29.55088173528202, 15.0651247180067 }),
|
||||||
|
new DoublePoint(new double[] { 29.004155302187428, 14.089665298582986 }),
|
||||||
|
new DoublePoint(new double[] { 29.339624439831823, 13.29096065578051 }),
|
||||||
|
new DoublePoint(new double[] { 30.997460327576846, 14.551914158277214 }),
|
||||||
|
new DoublePoint(new double[] { 30.66784126125276, 16.269703107886016 })
|
||||||
|
};
|
||||||
|
|
||||||
|
final DBSCANClusterer<DoublePoint> transformer =
|
||||||
|
new DBSCANClusterer<DoublePoint>(2.0, 5);
|
||||||
|
final List<Cluster<DoublePoint>> clusters = transformer.cluster(Arrays.asList(points));
|
||||||
|
|
||||||
|
final List<DoublePoint> clusterOne =
|
||||||
|
Arrays.asList(points[3], points[4], points[5], points[6], points[7], points[8], points[9], points[10],
|
||||||
|
points[11], points[12], points[13], points[14]);
|
||||||
|
final List<DoublePoint> clusterTwo =
|
||||||
|
Arrays.asList(points[15], points[16], points[17], points[18], points[19], points[20], points[21],
|
||||||
|
points[22], points[23], points[24], points[25], points[26], points[27], points[28],
|
||||||
|
points[29], points[30], points[31], points[32], points[33], points[34], points[35],
|
||||||
|
points[36], points[37], points[38]);
|
||||||
|
final List<DoublePoint> clusterThree =
|
||||||
|
Arrays.asList(points[39], points[40], points[41], points[42], points[43], points[44], points[45],
|
||||||
|
points[46], points[47], points[48], points[49], points[50], points[51], points[52],
|
||||||
|
points[53], points[54], points[55], points[56], points[57], points[58], points[59],
|
||||||
|
points[60], points[61], points[62]);
|
||||||
|
|
||||||
|
boolean cluster1Found = false;
|
||||||
|
boolean cluster2Found = false;
|
||||||
|
boolean cluster3Found = false;
|
||||||
|
Assert.assertEquals(3, clusters.size());
|
||||||
|
for (final Cluster<DoublePoint> cluster : clusters) {
|
||||||
|
if (cluster.getPoints().containsAll(clusterOne)) {
|
||||||
|
cluster1Found = true;
|
||||||
|
}
|
||||||
|
if (cluster.getPoints().containsAll(clusterTwo)) {
|
||||||
|
cluster2Found = true;
|
||||||
|
}
|
||||||
|
if (cluster.getPoints().containsAll(clusterThree)) {
|
||||||
|
cluster3Found = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Assert.assertTrue(cluster1Found);
|
||||||
|
Assert.assertTrue(cluster2Found);
|
||||||
|
Assert.assertTrue(cluster3Found);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSingleLink() {
|
||||||
|
final DoublePoint[] points = {
|
||||||
|
new DoublePoint(new int[] {10, 10}), // A
|
||||||
|
new DoublePoint(new int[] {12, 9}),
|
||||||
|
new DoublePoint(new int[] {10, 8}),
|
||||||
|
new DoublePoint(new int[] {8, 8}),
|
||||||
|
new DoublePoint(new int[] {8, 6}),
|
||||||
|
new DoublePoint(new int[] {7, 7}),
|
||||||
|
new DoublePoint(new int[] {5, 6}), // B
|
||||||
|
new DoublePoint(new int[] {14, 8}), // C
|
||||||
|
new DoublePoint(new int[] {7, 15}), // N - Noise, should not be present
|
||||||
|
new DoublePoint(new int[] {17, 8}), // D - single-link connected to C should not be present
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
final DBSCANClusterer<DoublePoint> clusterer = new DBSCANClusterer<DoublePoint>(3, 3);
|
||||||
|
List<Cluster<DoublePoint>> clusters = clusterer.cluster(Arrays.asList(points));
|
||||||
|
|
||||||
|
Assert.assertEquals(1, clusters.size());
|
||||||
|
|
||||||
|
final List<DoublePoint> clusterOne =
|
||||||
|
Arrays.asList(points[0], points[1], points[2], points[3], points[4], points[5], points[6], points[7]);
|
||||||
|
Assert.assertTrue(clusters.get(0).getPoints().containsAll(clusterOne));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testGetEps() {
|
||||||
|
final DBSCANClusterer<DoublePoint> transformer = new DBSCANClusterer<DoublePoint>(2.0, 5);
|
||||||
|
Assert.assertEquals(2.0, transformer.getEps(), 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testGetMinPts() {
|
||||||
|
final DBSCANClusterer<DoublePoint> transformer = new DBSCANClusterer<DoublePoint>(2.0, 5);
|
||||||
|
Assert.assertEquals(5, transformer.getMinPts());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = MathIllegalArgumentException.class)
|
||||||
|
public void testNegativeEps() {
|
||||||
|
new DBSCANClusterer<DoublePoint>(-2.0, 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = MathIllegalArgumentException.class)
|
||||||
|
public void testNegativeMinPts() {
|
||||||
|
new DBSCANClusterer<DoublePoint>(2.0, -5);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = NullArgumentException.class)
|
||||||
|
public void testNullDataset() {
|
||||||
|
DBSCANClusterer<DoublePoint> clusterer = new DBSCANClusterer<DoublePoint>(2.0, 5);
|
||||||
|
clusterer.cluster(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,191 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.commons.math3.ml.clustering;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||||
|
import org.apache.commons.math3.ml.distance.EuclideanDistance;
|
||||||
|
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class KMeansPlusPlusClustererTest {
|
||||||
|
|
||||||
|
private RandomGenerator random;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUp() {
|
||||||
|
random = new JDKRandomGenerator();
|
||||||
|
random.setSeed(1746432956321l);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* JIRA: MATH-305
|
||||||
|
*
|
||||||
|
* Two points, one cluster, one iteration
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void testPerformClusterAnalysisDegenerate() {
|
||||||
|
KMeansPlusPlusClusterer<DoublePoint> transformer =
|
||||||
|
new KMeansPlusPlusClusterer<DoublePoint>(1, 1);
|
||||||
|
|
||||||
|
DoublePoint[] points = new DoublePoint[] {
|
||||||
|
new DoublePoint(new int[] { 1959, 325100 }),
|
||||||
|
new DoublePoint(new int[] { 1960, 373200 }), };
|
||||||
|
List<? extends Cluster<DoublePoint>> clusters = transformer.cluster(Arrays.asList(points));
|
||||||
|
Assert.assertEquals(1, clusters.size());
|
||||||
|
Assert.assertEquals(2, (clusters.get(0).getPoints().size()));
|
||||||
|
DoublePoint pt1 = new DoublePoint(new int[] { 1959, 325100 });
|
||||||
|
DoublePoint pt2 = new DoublePoint(new int[] { 1960, 373200 });
|
||||||
|
Assert.assertTrue(clusters.get(0).getPoints().contains(pt1));
|
||||||
|
Assert.assertTrue(clusters.get(0).getPoints().contains(pt2));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCertainSpace() {
|
||||||
|
KMeansPlusPlusClusterer.EmptyClusterStrategy[] strategies = {
|
||||||
|
KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE,
|
||||||
|
KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_POINTS_NUMBER,
|
||||||
|
KMeansPlusPlusClusterer.EmptyClusterStrategy.FARTHEST_POINT
|
||||||
|
};
|
||||||
|
for (KMeansPlusPlusClusterer.EmptyClusterStrategy strategy : strategies) {
|
||||||
|
int numberOfVariables = 27;
|
||||||
|
// initialise testvalues
|
||||||
|
int position1 = 1;
|
||||||
|
int position2 = position1 + numberOfVariables;
|
||||||
|
int position3 = position2 + numberOfVariables;
|
||||||
|
int position4 = position3 + numberOfVariables;
|
||||||
|
// testvalues will be multiplied
|
||||||
|
int multiplier = 1000000;
|
||||||
|
|
||||||
|
DoublePoint[] breakingPoints = new DoublePoint[numberOfVariables];
|
||||||
|
// define the space which will break the cluster algorithm
|
||||||
|
for (int i = 0; i < numberOfVariables; i++) {
|
||||||
|
int points[] = { position1, position2, position3, position4 };
|
||||||
|
// multiply the values
|
||||||
|
for (int j = 0; j < points.length; j++) {
|
||||||
|
points[j] = points[j] * multiplier;
|
||||||
|
}
|
||||||
|
DoublePoint DoublePoint = new DoublePoint(points);
|
||||||
|
breakingPoints[i] = DoublePoint;
|
||||||
|
position1 = position1 + numberOfVariables;
|
||||||
|
position2 = position2 + numberOfVariables;
|
||||||
|
position3 = position3 + numberOfVariables;
|
||||||
|
position4 = position4 + numberOfVariables;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int n = 2; n < 27; ++n) {
|
||||||
|
KMeansPlusPlusClusterer<DoublePoint> transformer =
|
||||||
|
new KMeansPlusPlusClusterer<DoublePoint>(n, 100, new EuclideanDistance(), random, strategy);
|
||||||
|
|
||||||
|
List<? extends Cluster<DoublePoint>> clusters =
|
||||||
|
transformer.cluster(Arrays.asList(breakingPoints));
|
||||||
|
|
||||||
|
Assert.assertEquals(n, clusters.size());
|
||||||
|
int sum = 0;
|
||||||
|
for (Cluster<DoublePoint> cluster : clusters) {
|
||||||
|
sum += cluster.getPoints().size();
|
||||||
|
}
|
||||||
|
Assert.assertEquals(numberOfVariables, sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A helper class for testSmallDistances(). This class is similar to DoublePoint, but
|
||||||
|
* it defines a different distanceFrom() method that tends to return distances less than 1.
|
||||||
|
*/
|
||||||
|
private class CloseDistance extends EuclideanDistance {
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double compute(double[] a, double[] b) {
|
||||||
|
return super.compute(a, b) * 0.001;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Test points that are very close together. See issue MATH-546.
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void testSmallDistances() {
|
||||||
|
// Create a bunch of CloseDoublePoints. Most are identical, but one is different by a
|
||||||
|
// small distance.
|
||||||
|
int[] repeatedArray = { 0 };
|
||||||
|
int[] uniqueArray = { 1 };
|
||||||
|
DoublePoint repeatedPoint = new DoublePoint(repeatedArray);
|
||||||
|
DoublePoint uniquePoint = new DoublePoint(uniqueArray);
|
||||||
|
|
||||||
|
Collection<DoublePoint> points = new ArrayList<DoublePoint>();
|
||||||
|
final int NUM_REPEATED_POINTS = 10 * 1000;
|
||||||
|
for (int i = 0; i < NUM_REPEATED_POINTS; ++i) {
|
||||||
|
points.add(repeatedPoint);
|
||||||
|
}
|
||||||
|
points.add(uniquePoint);
|
||||||
|
|
||||||
|
// Ask a KMeansPlusPlusClusterer to run zero iterations (i.e., to simply choose initial
|
||||||
|
// cluster centers).
|
||||||
|
final long RANDOM_SEED = 0;
|
||||||
|
final int NUM_CLUSTERS = 2;
|
||||||
|
final int NUM_ITERATIONS = 0;
|
||||||
|
random.setSeed(RANDOM_SEED);
|
||||||
|
|
||||||
|
KMeansPlusPlusClusterer<DoublePoint> clusterer =
|
||||||
|
new KMeansPlusPlusClusterer<DoublePoint>(NUM_CLUSTERS, NUM_ITERATIONS,
|
||||||
|
new CloseDistance(), random);
|
||||||
|
List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
|
||||||
|
|
||||||
|
// Check that one of the chosen centers is the unique point.
|
||||||
|
boolean uniquePointIsCenter = false;
|
||||||
|
for (CentroidCluster<DoublePoint> cluster : clusters) {
|
||||||
|
if (cluster.getCenter().equals(uniquePoint)) {
|
||||||
|
uniquePointIsCenter = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Assert.assertTrue(uniquePointIsCenter);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 2 variables cannot be clustered into 3 clusters. See issue MATH-436.
|
||||||
|
*/
|
||||||
|
@Test(expected=NumberIsTooSmallException.class)
|
||||||
|
public void testPerformClusterAnalysisToManyClusters() {
|
||||||
|
KMeansPlusPlusClusterer<DoublePoint> transformer =
|
||||||
|
new KMeansPlusPlusClusterer<DoublePoint>(3, 1, new EuclideanDistance(), random);
|
||||||
|
|
||||||
|
DoublePoint[] points = new DoublePoint[] {
|
||||||
|
new DoublePoint(new int[] {
|
||||||
|
1959, 325100
|
||||||
|
}), new DoublePoint(new int[] {
|
||||||
|
1960, 373200
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
transformer.cluster(Arrays.asList(points));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,98 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.commons.math3.ml.clustering;
|
||||||
|
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class MultiKMeansPlusPlusClustererTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void dimension2() {
|
||||||
|
MultiKMeansPlusPlusClusterer<DoublePoint> transformer =
|
||||||
|
new MultiKMeansPlusPlusClusterer<DoublePoint>(
|
||||||
|
new KMeansPlusPlusClusterer<DoublePoint>(3, 10), 5);
|
||||||
|
|
||||||
|
DoublePoint[] points = new DoublePoint[] {
|
||||||
|
|
||||||
|
// first expected cluster
|
||||||
|
new DoublePoint(new int[] { -15, 3 }),
|
||||||
|
new DoublePoint(new int[] { -15, 4 }),
|
||||||
|
new DoublePoint(new int[] { -15, 5 }),
|
||||||
|
new DoublePoint(new int[] { -14, 3 }),
|
||||||
|
new DoublePoint(new int[] { -14, 5 }),
|
||||||
|
new DoublePoint(new int[] { -13, 3 }),
|
||||||
|
new DoublePoint(new int[] { -13, 4 }),
|
||||||
|
new DoublePoint(new int[] { -13, 5 }),
|
||||||
|
|
||||||
|
// second expected cluster
|
||||||
|
new DoublePoint(new int[] { -1, 0 }),
|
||||||
|
new DoublePoint(new int[] { -1, -1 }),
|
||||||
|
new DoublePoint(new int[] { 0, -1 }),
|
||||||
|
new DoublePoint(new int[] { 1, -1 }),
|
||||||
|
new DoublePoint(new int[] { 1, -2 }),
|
||||||
|
|
||||||
|
// third expected cluster
|
||||||
|
new DoublePoint(new int[] { 13, 3 }),
|
||||||
|
new DoublePoint(new int[] { 13, 4 }),
|
||||||
|
new DoublePoint(new int[] { 14, 4 }),
|
||||||
|
new DoublePoint(new int[] { 14, 7 }),
|
||||||
|
new DoublePoint(new int[] { 16, 5 }),
|
||||||
|
new DoublePoint(new int[] { 16, 6 }),
|
||||||
|
new DoublePoint(new int[] { 17, 4 }),
|
||||||
|
new DoublePoint(new int[] { 17, 7 })
|
||||||
|
|
||||||
|
};
|
||||||
|
List<CentroidCluster<DoublePoint>> clusters = transformer.cluster(Arrays.asList(points));
|
||||||
|
|
||||||
|
Assert.assertEquals(3, clusters.size());
|
||||||
|
boolean cluster1Found = false;
|
||||||
|
boolean cluster2Found = false;
|
||||||
|
boolean cluster3Found = false;
|
||||||
|
double epsilon = 1e-6;
|
||||||
|
for (CentroidCluster<DoublePoint> cluster : clusters) {
|
||||||
|
Clusterable center = cluster.getCenter();
|
||||||
|
double[] point = center.getPoint();
|
||||||
|
if (point[0] < 0) {
|
||||||
|
cluster1Found = true;
|
||||||
|
Assert.assertEquals(8, cluster.getPoints().size());
|
||||||
|
Assert.assertEquals(-14, point[0], epsilon);
|
||||||
|
Assert.assertEquals( 4, point[1], epsilon);
|
||||||
|
} else if (point[1] < 0) {
|
||||||
|
cluster2Found = true;
|
||||||
|
Assert.assertEquals(5, cluster.getPoints().size());
|
||||||
|
Assert.assertEquals( 0, point[0], epsilon);
|
||||||
|
Assert.assertEquals(-1, point[1], epsilon);
|
||||||
|
} else {
|
||||||
|
cluster3Found = true;
|
||||||
|
Assert.assertEquals(8, cluster.getPoints().size());
|
||||||
|
Assert.assertEquals(15, point[0], epsilon);
|
||||||
|
Assert.assertEquals(5, point[1], epsilon);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Assert.assertTrue(cluster1Found);
|
||||||
|
Assert.assertTrue(cluster2Found);
|
||||||
|
Assert.assertTrue(cluster3Found);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue