From 363e116e34ea657ae0664bde3beb02dda773d9ca Mon Sep 17 00:00:00 2001 From: Thomas Neidhart Date: Wed, 27 Mar 2013 21:48:10 +0000 Subject: [PATCH] [MATH-917] Refactored clustering package to include more distance measures. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1461862 13f79535-47bb-0310-9956-ffa450edef68 --- .../math3/ml/clustering/CentroidCluster.java | 38 ++ .../commons/math3/ml/clustering/Cluster.java | 61 ++ .../math3/ml/clustering/Clusterable.java | 33 ++ .../math3/ml/clustering/Clusterer.java | 65 ++ .../math3/ml/clustering/DBSCANClusterer.java | 231 ++++++++ .../math3/ml/clustering/DoublePoint.java | 87 +++ .../clustering/KMeansPlusPlusClusterer.java | 561 ++++++++++++++++++ .../MultiKMeansPlusPlusClusterer.java | 121 ++++ .../math3/ml/clustering/package-info.java | 20 + .../math3/ml/distance/CanberraDistance.java | 27 + .../math3/ml/distance/ChebyshevDistance.java | 21 + .../math3/ml/distance/DistanceMeasure.java | 23 + .../math3/ml/distance/EuclideanDistance.java | 21 + .../math3/ml/distance/ManhattanDistance.java | 21 + .../math3/ml/distance/package-info.java | 20 + .../apache/commons/math3/ml/package-info.java | 20 + .../ml/clustering/DBSCANClustererTest.java | 190 ++++++ .../KMeansPlusPlusClustererTest.java | 191 ++++++ .../MultiKMeansPlusPlusClustererTest.java | 98 +++ 19 files changed, 1849 insertions(+) create mode 100644 src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java create mode 100644 src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java create mode 100644 src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java create mode 100644 src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java create mode 100644 src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java create mode 100644 src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java create mode 100644 src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java create mode 100644 src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java create mode 100644 src/main/java/org/apache/commons/math3/ml/clustering/package-info.java create mode 100644 src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java create mode 100644 src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java create mode 100644 src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java create mode 100644 src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java create mode 100644 src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java create mode 100644 src/main/java/org/apache/commons/math3/ml/distance/package-info.java create mode 100644 src/main/java/org/apache/commons/math3/ml/package-info.java create mode 100644 src/test/java/org/apache/commons/math3/ml/clustering/DBSCANClustererTest.java create mode 100644 src/test/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClustererTest.java create mode 100644 src/test/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClustererTest.java diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java b/src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java new file mode 100644 index 000000000..56a57adc2 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java @@ -0,0 +1,38 @@ +package org.apache.commons.math3.ml.clustering; + +/** + * A Cluster used by centroid-based clustering algorithms. + *

+ * Defines additionally a cluster center which may not necessarily be a member + * of the original data set. + * + * @param the type of points that can be clustered + * @version $Id $ + * @since 3.2 + */ +public class CentroidCluster extends Cluster { + + /** Serializable version identifier. */ + private static final long serialVersionUID = -3075288519071812288L; + + /** Center of the cluster. */ + private final Clusterable center; + + /** + * Build a cluster centered at a specified point. + * @param center the point which is to be the center of this cluster + */ + public CentroidCluster(final Clusterable center) { + super(); + this.center = center; + } + + /** + * Get the point chosen to be the center of this cluster. + * @return chosen cluster center + */ + public Clusterable getCenter() { + return center; + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java b/src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java new file mode 100644 index 000000000..8523a4603 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +/** + * Cluster holding a set of {@link Clusterable} points. + * @param the type of points that can be clustered + * @version $Id$ + * @since 3.2 + */ +public class Cluster implements Serializable { + + /** Serializable version identifier. */ + private static final long serialVersionUID = -3442297081515880464L; + + /** The points contained in this cluster. */ + private final List points; + + /** + * Build a cluster centered at a specified point. + */ + public Cluster() { + points = new ArrayList(); + } + + /** + * Add a point to this cluster. + * @param point point to add + */ + public void addPoint(final T point) { + points.add(point); + } + + /** + * Get the points contained in the cluster. + * @return points contained in the cluster + */ + public List getPoints() { + return points; + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java new file mode 100644 index 000000000..c4883b8f0 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +/** + * Interface for n-dimensional points that can be clustered together. + * @version $Id$ + * @since 3.2 + */ +public interface Clusterable { + + /** + * Gets the n-dimensional point. + * + * @return the point array + */ + double[] getPoint(); +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java new file mode 100644 index 000000000..83e572c0a --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java @@ -0,0 +1,65 @@ +package org.apache.commons.math3.ml.clustering; + +import java.util.Collection; +import java.util.List; + +import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.ml.distance.DistanceMeasure; + +/** + * Base class for clustering algorithms. + * + * @param the type of points that can be clustered + * @version $Id $ + * @since 3.2 + */ +public abstract class Clusterer { + + /** The distance measure to use. */ + private DistanceMeasure measure; + + /** + * Build a new clusterer with the given {@link DistanceMeasure}. + * + * @param measure the distance measure to use + */ + protected Clusterer(final DistanceMeasure measure) { + this.measure = measure; + } + + /** + * Perform a cluster analysis on the given set of {@link Clusterable} instances. + * + * @param points the set of {@link Clusterable} instances + * @return a {@link List} of clusters + * @throws MathIllegalArgumentException if points are null or the number of + * data points is not compatible with this clusterer + * @throws ConvergenceException if the algorithm has not yet converged after + * the maximum number of iterations has been exceeded + */ + public abstract List> cluster(Collection points) + throws MathIllegalArgumentException, ConvergenceException; + + /** + * Returns the {@link DistanceMeasure} instance used by this clusterer. + * + * @return the distance measure + */ + public DistanceMeasure getDistanceMeasure() { + return measure; + } + + /** + * Calculates the distance between two {@link Clusterable} instances + * with the configured {@link DistanceMeasure}. + * + * @param p1 the first clusterable + * @param p2 the second clusterable + * @return the distance between the two clusterables + */ + protected double distance(final Clusterable p1, final Clusterable p2) { + return measure.compute(p1.getPoint(), p2.getPoint()); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java new file mode 100644 index 000000000..80ac5af3b --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.clustering; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.apache.commons.math3.util.MathUtils; + +/** + * DBSCAN (density-based spatial clustering of applications with noise) algorithm. + *

+ * The DBSCAN algorithm forms clusters based on the idea of density connectivity, i.e. + * a point p is density connected to another point q, if there exists a chain of + * points pi, with i = 1 .. n and p1 = p and pn = q, + * such that each pair <pi, pi+1> is directly density-reachable. + * A point q is directly density-reachable from point p if it is in the ε-neighborhood + * of this point. + *

+ * Any point that is not density-reachable from a formed cluster is treated as noise, and + * will thus not be present in the result. + *

+ * The algorithm requires two parameters: + *

+ * + * @param type of the points to cluster + * @see DBSCAN (wikipedia) + * @see + * A Density-Based Algorithm for Discovering Clusters in Large Spatial Databases with Noise + * @version $Id$ + * @since 3.2 + */ +public class DBSCANClusterer extends Clusterer { + + /** Maximum radius of the neighborhood to be considered. */ + private final double eps; + + /** Minimum number of points needed for a cluster. */ + private final int minPts; + + /** Status of a point during the clustering process. */ + private enum PointStatus { + /** The point has is considered to be noise. */ + NOISE, + /** The point is already part of a cluster. */ + PART_OF_CLUSTER + } + + /** + * Creates a new instance of a DBSCANClusterer. + * + * @param eps maximum radius of the neighborhood to be considered + * @param minPts minimum number of points needed for a cluster + * @throws NotPositiveException if {@code eps < 0.0} or {@code minPts < 0} + */ + public DBSCANClusterer(final double eps, final int minPts) + throws NotPositiveException { + this(eps, minPts, new EuclideanDistance()); + } + + /** + * Creates a new instance of a DBSCANClusterer. + * + * @param eps maximum radius of the neighborhood to be considered + * @param minPts minimum number of points needed for a cluster + * @param measure the distance measure to use + * @throws NotPositiveException if {@code eps < 0.0} or {@code minPts < 0} + */ + public DBSCANClusterer(final double eps, final int minPts, final DistanceMeasure measure) + throws NotPositiveException { + super(measure); + + if (eps < 0.0d) { + throw new NotPositiveException(eps); + } + if (minPts < 0) { + throw new NotPositiveException(minPts); + } + this.eps = eps; + this.minPts = minPts; + } + + /** + * Returns the maximum radius of the neighborhood to be considered. + * @return maximum radius of the neighborhood + */ + public double getEps() { + return eps; + } + + /** + * Returns the minimum number of points needed for a cluster. + * @return minimum number of points needed for a cluster + */ + public int getMinPts() { + return minPts; + } + + /** + * Performs DBSCAN cluster analysis. + * + * @param points the points to cluster + * @return the list of clusters + * @throws NullArgumentException if the data points are null + */ + public List> cluster(final Collection points) throws NullArgumentException { + + // sanity checks + MathUtils.checkNotNull(points); + + final List> clusters = new ArrayList>(); + final Map visited = new HashMap(); + + for (final T point : points) { + if (visited.get(point) != null) { + continue; + } + final List neighbors = getNeighbors(point, points); + if (neighbors.size() >= minPts) { + // DBSCAN does not care about center points + final Cluster cluster = new Cluster(); + clusters.add(expandCluster(cluster, point, neighbors, points, visited)); + } else { + visited.put(point, PointStatus.NOISE); + } + } + + return clusters; + } + + /** + * Expands the cluster to include density-reachable items. + * + * @param cluster Cluster to expand + * @param point Point to add to cluster + * @param neighbors List of neighbors + * @param points the data set + * @param visited the set of already visited points + * @return the expanded cluster + */ + private Cluster expandCluster(final Cluster cluster, + final T point, + final List neighbors, + final Collection points, + final Map visited) { + cluster.addPoint(point); + visited.put(point, PointStatus.PART_OF_CLUSTER); + + List seeds = new ArrayList(neighbors); + int index = 0; + while (index < seeds.size()) { + final T current = seeds.get(index); + PointStatus pStatus = visited.get(current); + // only check non-visited points + if (pStatus == null) { + final List currentNeighbors = getNeighbors(current, points); + if (currentNeighbors.size() >= minPts) { + seeds = merge(seeds, currentNeighbors); + } + } + + if (pStatus != PointStatus.PART_OF_CLUSTER) { + visited.put(current, PointStatus.PART_OF_CLUSTER); + cluster.addPoint(current); + } + + index++; + } + return cluster; + } + + /** + * Returns a list of density-reachable neighbors of a {@code point}. + * + * @param point the point to look for + * @param points possible neighbors + * @return the List of neighbors + */ + private List getNeighbors(final T point, final Collection points) { + final List neighbors = new ArrayList(); + for (final T neighbor : points) { + if (point != neighbor && distance(neighbor, point) <= eps) { + neighbors.add(neighbor); + } + } + return neighbors; + } + + /** + * Merges two lists together. + * + * @param one first list + * @param two second list + * @return merged lists + */ + private List merge(final List one, final List two) { + final Set oneSet = new HashSet(one); + for (T item : two) { + if (!oneSet.contains(item)) { + one.add(item); + } + } + return one; + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java b/src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java new file mode 100644 index 000000000..3177bfada --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +import java.io.Serializable; +import java.util.Arrays; + +/** + * A simple implementation of {@link Clusterable} for points with double coordinates. + * @version $Id$ + * @since 3.2 + */ +public class DoublePoint implements Clusterable, Serializable { + + /** Serializable version identifier. */ + private static final long serialVersionUID = 3946024775784901369L; + + /** Point coordinates. */ + private final double[] point; + + /** + * Build an instance wrapping an double array. + *

+ * The wrapped array is referenced, it is not copied. + * + * @param point the n-dimensional point in double space + */ + public DoublePoint(final double[] point) { + this.point = point; + } + + /** + * Build an instance wrapping an integer array. + *

+ * The wrapped array is copied to an internal double array. + * + * @param point the n-dimensional point in integer space + */ + public DoublePoint(final int[] point) { + this.point = new double[point.length]; + for ( int i = 0; i < point.length; i++) { + this.point[i] = point[i]; + } + } + + /** {@inheritDoc} */ + public double[] getPoint() { + return point; + } + + /** {@inheritDoc} */ + @Override + public boolean equals(final Object other) { + if (!(other instanceof DoublePoint)) { + return false; + } + return Arrays.equals(point, ((DoublePoint) other).point); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Arrays.hashCode(point); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return Arrays.toString(point); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java new file mode 100644 index 000000000..771dd1234 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java @@ -0,0 +1,561 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.apache.commons.math3.random.JDKRandomGenerator; +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.stat.descriptive.moment.Variance; +import org.apache.commons.math3.util.MathUtils; + +/** + * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm. + * @param type of the points to cluster + * @see K-means++ (wikipedia) + * @version $Id$ + * @since 3.2 + */ +public class KMeansPlusPlusClusterer extends Clusterer { + + /** Strategies to use for replacing an empty cluster. */ + public static enum EmptyClusterStrategy { + + /** Split the cluster with largest distance variance. */ + LARGEST_VARIANCE, + + /** Split the cluster with largest number of points. */ + LARGEST_POINTS_NUMBER, + + /** Create a cluster around the point farthest from its centroid. */ + FARTHEST_POINT, + + /** Generate an error. */ + ERROR + + } + + /** The number of clusters. */ + private final int k; + + /** The maximum number of iterations. */ + private final int maxIterations; + + /** Random generator for choosing initial centers. */ + private final RandomGenerator random; + + /** Selected strategy for empty clusters. */ + private final EmptyClusterStrategy emptyStrategy; + + /** Build a clusterer. + *

+ * The default strategy for handling empty clusters that may appear during + * algorithm iterations is to split the cluster with largest distance variance. + * + * @param k the number of clusters to split the data into + */ + public KMeansPlusPlusClusterer(final int k) { + this(k, -1); + } + + /** Build a clusterer. + *

+ * The default strategy for handling empty clusters that may appear during + * algorithm iterations is to split the cluster with largest distance variance. + * + * @param k the number of clusters to split the data into + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + */ + public KMeansPlusPlusClusterer(final int k, final int maxIterations) { + this(k, maxIterations, new EuclideanDistance()); + } + + /** Build a clusterer. + *

+ * The default strategy for handling empty clusters that may appear during + * algorithm iterations is to split the cluster with largest distance variance. + * + * @param k the number of clusters to split the data into + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + * @param measure the distance measure to use + */ + public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) { + this(k, maxIterations, measure, new JDKRandomGenerator()); + } + + /** Build a clusterer. + *

+ * The default strategy for handling empty clusters that may appear during + * algorithm iterations is to split the cluster with largest distance variance. + * + * @param k the number of clusters to split the data into + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + * @param measure the distance measure to use + * @param random random generator to use for choosing initial centers + */ + public KMeansPlusPlusClusterer(final int k, final int maxIterations, + final DistanceMeasure measure, + final RandomGenerator random) { + this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE); + } + + /** Build a clusterer. + * + * @param k the number of clusters to split the data into + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + * @param measure the distance measure to use + * @param random random generator to use for choosing initial centers + * @param emptyStrategy strategy to use for handling empty clusters that + * may appear during algorithm iterations + */ + public KMeansPlusPlusClusterer(final int k, final int maxIterations, + final DistanceMeasure measure, + final RandomGenerator random, + final EmptyClusterStrategy emptyStrategy) { + super(measure); + this.k = k; + this.maxIterations = maxIterations; + this.random = random; + this.emptyStrategy = emptyStrategy; + } + + /** + * Return the number of clusters this instance will use. + * @return the number of clusters + */ + public int getK() { + return k; + } + + /** + * Returns the maximum number of iterations this instance will use. + * @return the maximum number of iterations, or -1 if no maximum is set + */ + public int getMaxIterations() { + return maxIterations; + } + + /** + * Returns the random generator this instance will use. + * @return the random generator + */ + public RandomGenerator getRandomGenerator() { + return random; + } + + /** + * Returns the {@link EmptyClusterStrategy} used by this instance. + * @return the {@link EmptyClusterStrategy} + */ + public EmptyClusterStrategy getEmptyClusterStrategy() { + return emptyStrategy; + } + + /** + * Runs the K-means++ clustering algorithm. + * + * @param points the points to cluster + * @return a list of clusters containing the points + * @throws MathIllegalArgumentException if the data points are null or the number + * of clusters is larger than the number of data points + * @throws ConvergenceException if an empty cluster is encountered and the + * {@link #emptyStrategy} is set to {@code ERROR} + */ + public List> cluster(final Collection points) + throws MathIllegalArgumentException, ConvergenceException { + + // sanity checks + MathUtils.checkNotNull(points); + + // number of clusters has to be smaller or equal the number of data points + if (points.size() < k) { + throw new NumberIsTooSmallException(points.size(), k, false); + } + + // create the initial clusters + List> clusters = chooseInitialCenters(points); + + // create an array containing the latest assignment of a point to a cluster + // no need to initialize the array, as it will be filled with the first assignment + int[] assignments = new int[points.size()]; + assignPointsToClusters(clusters, points, assignments); + + // iterate through updating the centers until we're done + final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations; + for (int count = 0; count < max; count++) { + boolean emptyCluster = false; + List> newClusters = new ArrayList>(); + for (final CentroidCluster cluster : clusters) { + final Clusterable newCenter; + if (cluster.getPoints().isEmpty()) { + switch (emptyStrategy) { + case LARGEST_VARIANCE : + newCenter = getPointFromLargestVarianceCluster(clusters); + break; + case LARGEST_POINTS_NUMBER : + newCenter = getPointFromLargestNumberCluster(clusters); + break; + case FARTHEST_POINT : + newCenter = getFarthestPoint(clusters); + break; + default : + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + emptyCluster = true; + } else { + newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length); + } + newClusters.add(new CentroidCluster(newCenter)); + } + int changes = assignPointsToClusters(newClusters, points, assignments); + clusters = newClusters; + + // if there were no more changes in the point-to-cluster assignment + // and there are no empty clusters left, return the current clusters + if (changes == 0 && !emptyCluster) { + return clusters; + } + } + return clusters; + } + + /** + * Adds the given points to the closest {@link Cluster}. + * + * @param clusters the {@link Cluster}s to add the points to + * @param points the points to add to the given {@link Cluster}s + * @param assignments points assignments to clusters + * @return the number of points assigned to different clusters as the iteration before + */ + private int assignPointsToClusters(final List> clusters, + final Collection points, + final int[] assignments) { + int assignedDifferently = 0; + int pointIndex = 0; + for (final T p : points) { + int clusterIndex = getNearestCluster(clusters, p); + if (clusterIndex != assignments[pointIndex]) { + assignedDifferently++; + } + + CentroidCluster cluster = clusters.get(clusterIndex); + cluster.addPoint(p); + assignments[pointIndex++] = clusterIndex; + } + + return assignedDifferently; + } + + /** + * Use K-means++ to choose the initial centers. + * + * @param points the points to choose the initial centers from + * @return the initial centers + */ + private List> chooseInitialCenters(final Collection points) { + + // Convert to list for indexed access. Make it unmodifiable, since removal of items + // would screw up the logic of this method. + final List pointList = Collections.unmodifiableList(new ArrayList (points)); + + // The number of points in the list. + final int numPoints = pointList.size(); + + // Set the corresponding element in this array to indicate when + // elements of pointList are no longer available. + final boolean[] taken = new boolean[numPoints]; + + // The resulting list of initial centers. + final List> resultSet = new ArrayList>(); + + // Choose one center uniformly at random from among the data points. + final int firstPointIndex = random.nextInt(numPoints); + + final T firstPoint = pointList.get(firstPointIndex); + + resultSet.add(new CentroidCluster(firstPoint)); + + // Must mark it as taken + taken[firstPointIndex] = true; + + // To keep track of the minimum distance squared of elements of + // pointList to elements of resultSet. + final double[] minDistSquared = new double[numPoints]; + + // Initialize the elements. Since the only point in resultSet is firstPoint, + // this is very easy. + for (int i = 0; i < numPoints; i++) { + if (i != firstPointIndex) { // That point isn't considered + double d = distance(firstPoint, pointList.get(i)); + minDistSquared[i] = d*d; + } + } + + while (resultSet.size() < k) { + + // Sum up the squared distances for the points in pointList not + // already taken. + double distSqSum = 0.0; + + for (int i = 0; i < numPoints; i++) { + if (!taken[i]) { + distSqSum += minDistSquared[i]; + } + } + + // Add one new data point as a center. Each point x is chosen with + // probability proportional to D(x)2 + final double r = random.nextDouble() * distSqSum; + + // The index of the next point to be added to the resultSet. + int nextPointIndex = -1; + + // Sum through the squared min distances again, stopping when + // sum >= r. + double sum = 0.0; + for (int i = 0; i < numPoints; i++) { + if (!taken[i]) { + sum += minDistSquared[i]; + if (sum >= r) { + nextPointIndex = i; + break; + } + } + } + + // If it's not set to >= 0, the point wasn't found in the previous + // for loop, probably because distances are extremely small. Just pick + // the last available point. + if (nextPointIndex == -1) { + for (int i = numPoints - 1; i >= 0; i--) { + if (!taken[i]) { + nextPointIndex = i; + break; + } + } + } + + // We found one. + if (nextPointIndex >= 0) { + + final T p = pointList.get(nextPointIndex); + + resultSet.add(new CentroidCluster (p)); + + // Mark it as taken. + taken[nextPointIndex] = true; + + if (resultSet.size() < k) { + // Now update elements of minDistSquared. We only have to compute + // the distance to the new center to do this. + for (int j = 0; j < numPoints; j++) { + // Only have to worry about the points still not taken. + if (!taken[j]) { + double d = distance(p, pointList.get(j)); + double d2 = d * d; + if (d2 < minDistSquared[j]) { + minDistSquared[j] = d2; + } + } + } + } + + } else { + // None found -- + // Break from the while loop to prevent + // an infinite loop. + break; + } + } + + return resultSet; + } + + /** + * Get a random point from the {@link Cluster} with the largest distance variance. + * + * @param clusters the {@link Cluster}s to search + * @return a random point from the selected cluster + * @throws ConvergenceException if clusters are all empty + */ + private T getPointFromLargestVarianceCluster(final Collection> clusters) + throws ConvergenceException { + + double maxVariance = Double.NEGATIVE_INFINITY; + Cluster selected = null; + for (final CentroidCluster cluster : clusters) { + if (!cluster.getPoints().isEmpty()) { + + // compute the distance variance of the current cluster + final Clusterable center = cluster.getCenter(); + final Variance stat = new Variance(); + for (final T point : cluster.getPoints()) { + stat.increment(distance(point, center)); + } + final double variance = stat.getResult(); + + // select the cluster with the largest variance + if (variance > maxVariance) { + maxVariance = variance; + selected = cluster; + } + + } + } + + // did we find at least one non-empty cluster ? + if (selected == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + // extract a random point from the cluster + final List selectedPoints = selected.getPoints(); + return selectedPoints.remove(random.nextInt(selectedPoints.size())); + + } + + /** + * Get a random point from the {@link Cluster} with the largest number of points + * + * @param clusters the {@link Cluster}s to search + * @return a random point from the selected cluster + * @throws ConvergenceException if clusters are all empty + */ + private T getPointFromLargestNumberCluster(final Collection> clusters) + throws ConvergenceException { + + int maxNumber = 0; + Cluster selected = null; + for (final Cluster cluster : clusters) { + + // get the number of points of the current cluster + final int number = cluster.getPoints().size(); + + // select the cluster with the largest number of points + if (number > maxNumber) { + maxNumber = number; + selected = cluster; + } + + } + + // did we find at least one non-empty cluster ? + if (selected == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + // extract a random point from the cluster + final List selectedPoints = selected.getPoints(); + return selectedPoints.remove(random.nextInt(selectedPoints.size())); + + } + + /** + * Get the point farthest to its cluster center + * + * @param clusters the {@link Cluster}s to search + * @return point farthest to its cluster center + * @throws ConvergenceException if clusters are all empty + */ + private T getFarthestPoint(final Collection> clusters) throws ConvergenceException { + + double maxDistance = Double.NEGATIVE_INFINITY; + Cluster selectedCluster = null; + int selectedPoint = -1; + for (final CentroidCluster cluster : clusters) { + + // get the farthest point + final Clusterable center = cluster.getCenter(); + final List points = cluster.getPoints(); + for (int i = 0; i < points.size(); ++i) { + final double distance = distance(points.get(i), center); + if (distance > maxDistance) { + maxDistance = distance; + selectedCluster = cluster; + selectedPoint = i; + } + } + + } + + // did we find at least one non-empty cluster ? + if (selectedCluster == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + return selectedCluster.getPoints().remove(selectedPoint); + + } + + /** + * Returns the nearest {@link Cluster} to the given point + * + * @param clusters the {@link Cluster}s to search + * @param point the point to find the nearest {@link Cluster} for + * @return the index of the nearest {@link Cluster} to the given point + */ + private int getNearestCluster(final Collection> clusters, final T point) { + double minDistance = Double.MAX_VALUE; + int clusterIndex = 0; + int minCluster = 0; + for (final CentroidCluster c : clusters) { + final double distance = distance(point, c.getCenter()); + if (distance < minDistance) { + minDistance = distance; + minCluster = clusterIndex; + } + clusterIndex++; + } + return minCluster; + } + + /** + * Computes the centroid for a set of points. + * + * @param points the set of points + * @param dimension the point dimension + * @return the computed centroid for the set of points + */ + private Clusterable centroidOf(final Collection points, final int dimension) { + final double[] centroid = new double[dimension]; + for (final T p : points) { + final double[] point = p.getPoint(); + for (int i = 0; i < centroid.length; i++) { + centroid[i] += point[i]; + } + } + for (int i = 0; i < centroid.length; i++) { + centroid[i] /= points.size(); + } + return new DoublePoint(centroid); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java new file mode 100644 index 000000000..3b8b001fb --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +import java.util.Collection; +import java.util.List; + +import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.stat.descriptive.moment.Variance; + +/** + * A wrapper around a k-means++ clustering algorithm which performs multiple trials + * and returns the best solution. + * @param type of the points to cluster + * @version $Id$ + * @since 3.2 + */ +public class MultiKMeansPlusPlusClusterer extends Clusterer { + + /** The underlying k-means clusterer. */ + private final KMeansPlusPlusClusterer clusterer; + + /** The number of trial runs. */ + private final int numTrials; + + /** Build a clusterer. + * @param clusterer the k-means clusterer to use + * @param numTrials number of trial runs + */ + public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer clusterer, + final int numTrials) { + super(clusterer.getDistanceMeasure()); + this.clusterer = clusterer; + this.numTrials = numTrials; + } + + /** + * Returns the embedded k-means clusterer used by this instance. + * @return the embedded clusterer + */ + public KMeansPlusPlusClusterer getClusterer() { + return clusterer; + } + + /** + * Returns the number of trials this instance will do. + * @return the number of trials + */ + public int getNumTrials() { + return numTrials; + } + + /** + * Runs the K-means++ clustering algorithm. + * + * @param points the points to cluster + * @return a list of clusters containing the points + * @throws MathIllegalArgumentException if the data points are null or the number + * of clusters is larger than the number of data points + * @throws ConvergenceException if an empty cluster is encountered and the + * {@link #emptyStrategy} is set to {@code ERROR} + */ + public List> cluster(final Collection points) + throws MathIllegalArgumentException, ConvergenceException { + + // at first, we have not found any clusters list yet + List> best = null; + double bestVarianceSum = Double.POSITIVE_INFINITY; + + // do several clustering trials + for (int i = 0; i < numTrials; ++i) { + + // compute a clusters list + List> clusters = clusterer.cluster(points); + + // compute the variance of the current list + double varianceSum = 0.0; + for (final CentroidCluster cluster : clusters) { + if (!cluster.getPoints().isEmpty()) { + + // compute the distance variance of the current cluster + final Clusterable center = cluster.getCenter(); + final Variance stat = new Variance(); + for (final T point : cluster.getPoints()) { + stat.increment(distance(point, center)); + } + varianceSum += stat.getResult(); + + } + } + + if (varianceSum <= bestVarianceSum) { + // this one is the best we have found so far, remember it + best = clusters; + bestVarianceSum = varianceSum; + } + + } + + // return the best clusters list found + return best; + + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/package-info.java b/src/main/java/org/apache/commons/math3/ml/clustering/package-info.java new file mode 100644 index 000000000..02f1d208f --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Clustering algorithms. + */ +package org.apache.commons.math3.ml.clustering; diff --git a/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java new file mode 100644 index 000000000..0ee8fa4a8 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java @@ -0,0 +1,27 @@ +package org.apache.commons.math3.ml.distance; + +import org.apache.commons.math3.util.FastMath; + +/** + * Calculates the Canberra distance between two points. + * + * @version $Id $ + * @since 3.2 + */ +public class CanberraDistance implements DistanceMeasure { + + /** Serializable version identifier. */ + private static final long serialVersionUID = -6972277381587032228L; + + /** {@inheritDoc} */ + public double compute(double[] a, double[] b) { + double sum = 0; + for (int i = 0; i < a.length; i++) { + final double num = FastMath.abs(a[i] - b[i]); + final double denom = FastMath.abs(a[i]) + FastMath.abs(b[i]); + sum += num == 0.0 && denom == 0.0 ? 0.0 : num / denom; + } + return sum; + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java new file mode 100644 index 000000000..22d52a135 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java @@ -0,0 +1,21 @@ +package org.apache.commons.math3.ml.distance; + +import org.apache.commons.math3.util.MathArrays; + +/** + * Calculates the L (max of abs) distance between two points. + * + * @version $Id $ + * @since 3.2 + */ +public class ChebyshevDistance implements DistanceMeasure { + + /** Serializable version identifier. */ + private static final long serialVersionUID = -4694868171115238296L; + + /** {@inheritDoc} */ + public double compute(double[] a, double[] b) { + return MathArrays.distanceInf(a, b); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java b/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java new file mode 100644 index 000000000..2895084db --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java @@ -0,0 +1,23 @@ +package org.apache.commons.math3.ml.distance; + +import java.io.Serializable; + +/** + * Interface for distance measures of n-dimensional vectors. + * + * @version $Id $ + * @since 3.2 + */ +public interface DistanceMeasure extends Serializable { + + /** + * Compute the distance between two n-dimensional vectors. + *

+ * The two vectors are required to have the same dimension. + * + * @param a the first vector + * @param b the second vector + * @return the distance between the two vectors + */ + double compute(double[] a, double[] b); +} diff --git a/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java new file mode 100644 index 000000000..fda91e0fa --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java @@ -0,0 +1,21 @@ +package org.apache.commons.math3.ml.distance; + +import org.apache.commons.math3.util.MathArrays; + +/** + * Calculates the L2 (Euclidean) distance between two points. + * + * @version $Id $ + * @since 3.2 + */ +public class EuclideanDistance implements DistanceMeasure { + + /** Serializable version identifier. */ + private static final long serialVersionUID = 1717556319784040040L; + + /** {@inheritDoc} */ + public double compute(double[] a, double[] b) { + return MathArrays.distance(a, b); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java new file mode 100644 index 000000000..552bd2cc1 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java @@ -0,0 +1,21 @@ +package org.apache.commons.math3.ml.distance; + +import org.apache.commons.math3.util.MathArrays; + +/** + * Calculates the L1 (sum of abs) distance between two points. + * + * @version $Id $ + * @since 3.2 + */ +public class ManhattanDistance implements DistanceMeasure { + + /** Serializable version identifier. */ + private static final long serialVersionUID = -9108154600539125566L; + + /** {@inheritDoc} */ + public double compute(double[] a, double[] b) { + return MathArrays.distance1(a, b); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/distance/package-info.java b/src/main/java/org/apache/commons/math3/ml/distance/package-info.java new file mode 100644 index 000000000..f6d124a29 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/distance/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Common distance measures. + */ +package org.apache.commons.math3.ml.distance; diff --git a/src/main/java/org/apache/commons/math3/ml/package-info.java b/src/main/java/org/apache/commons/math3/ml/package-info.java new file mode 100644 index 000000000..80ae917d4 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Base package for machine learning algorithms. + */ +package org.apache.commons.math3.ml; diff --git a/src/test/java/org/apache/commons/math3/ml/clustering/DBSCANClustererTest.java b/src/test/java/org/apache/commons/math3/ml/clustering/DBSCANClustererTest.java new file mode 100644 index 000000000..497459f9f --- /dev/null +++ b/src/test/java/org/apache/commons/math3/ml/clustering/DBSCANClustererTest.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.clustering; + +import java.util.Arrays; +import java.util.List; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.junit.Assert; +import org.junit.Test; + +public class DBSCANClustererTest { + + @Test + public void testCluster() { + // Test data generated using: http://people.cs.nctu.edu.tw/~rsliang/dbscan/testdatagen.html + final DoublePoint[] points = new DoublePoint[] { + new DoublePoint(new double[] { 83.08303244924173, 58.83387754182331 }), + new DoublePoint(new double[] { 45.05445510940626, 23.469642649637535 }), + new DoublePoint(new double[] { 14.96417921432294, 69.0264096390456 }), + new DoublePoint(new double[] { 73.53189604333602, 34.896145021310076 }), + new DoublePoint(new double[] { 73.28498173551634, 33.96860806993209 }), + new DoublePoint(new double[] { 73.45828098873608, 33.92584423092194 }), + new DoublePoint(new double[] { 73.9657889183145, 35.73191006924026 }), + new DoublePoint(new double[] { 74.0074097183533, 36.81735596177168 }), + new DoublePoint(new double[] { 73.41247541410848, 34.27314856695011 }), + new DoublePoint(new double[] { 73.9156256353017, 36.83206791547127 }), + new DoublePoint(new double[] { 74.81499205809087, 37.15682749846019 }), + new DoublePoint(new double[] { 74.03144880081527, 37.57399178552441 }), + new DoublePoint(new double[] { 74.51870941207744, 38.674258946906775 }), + new DoublePoint(new double[] { 74.50754595105536, 35.58903978415765 }), + new DoublePoint(new double[] { 74.51322752749547, 36.030572259100154 }), + new DoublePoint(new double[] { 59.27900996617973, 46.41091720294207 }), + new DoublePoint(new double[] { 59.73744793841615, 46.20015558367595 }), + new DoublePoint(new double[] { 58.81134076672606, 45.71150126331486 }), + new DoublePoint(new double[] { 58.52225539437495, 47.416083617601544 }), + new DoublePoint(new double[] { 58.218626647023484, 47.36228902172297 }), + new DoublePoint(new double[] { 60.27139669447206, 46.606106348801404 }), + new DoublePoint(new double[] { 60.894962462363765, 46.976924697402865 }), + new DoublePoint(new double[] { 62.29048673878424, 47.66970563563518 }), + new DoublePoint(new double[] { 61.03857608977705, 46.212924720020965 }), + new DoublePoint(new double[] { 60.16916214139201, 45.18193661351688 }), + new DoublePoint(new double[] { 59.90036905976012, 47.555364347063005 }), + new DoublePoint(new double[] { 62.33003634144552, 47.83941489877179 }), + new DoublePoint(new double[] { 57.86035536718555, 47.31117930193432 }), + new DoublePoint(new double[] { 58.13715479685925, 48.985960494028404 }), + new DoublePoint(new double[] { 56.131923963548616, 46.8508904252667 }), + new DoublePoint(new double[] { 55.976329887053, 47.46384037658572 }), + new DoublePoint(new double[] { 56.23245975235477, 47.940035191131756 }), + new DoublePoint(new double[] { 58.51687048212625, 46.622885352699086 }), + new DoublePoint(new double[] { 57.85411081905477, 45.95394361577928 }), + new DoublePoint(new double[] { 56.445776311447844, 45.162093662656844 }), + new DoublePoint(new double[] { 57.36691949656233, 47.50097194337286 }), + new DoublePoint(new double[] { 58.243626387557015, 46.114052729681134 }), + new DoublePoint(new double[] { 56.27224595635198, 44.799080066150054 }), + new DoublePoint(new double[] { 57.606924816500396, 46.94291057763621 }), + new DoublePoint(new double[] { 30.18714230041951, 13.877149710431695 }), + new DoublePoint(new double[] { 30.449448810657486, 13.490778346545994 }), + new DoublePoint(new double[] { 30.295018390286714, 13.264889000216499 }), + new DoublePoint(new double[] { 30.160201832884923, 11.89278262341395 }), + new DoublePoint(new double[] { 31.341509791789576, 15.282655921997502 }), + new DoublePoint(new double[] { 31.68601630325429, 14.756873246748 }), + new DoublePoint(new double[] { 29.325963742565364, 12.097849250072613 }), + new DoublePoint(new double[] { 29.54820742388256, 13.613295356975868 }), + new DoublePoint(new double[] { 28.79359608888626, 10.36352064087987 }), + new DoublePoint(new double[] { 31.01284597092308, 12.788479208014905 }), + new DoublePoint(new double[] { 27.58509216737002, 11.47570110601373 }), + new DoublePoint(new double[] { 28.593799561727792, 10.780998203903437 }), + new DoublePoint(new double[] { 31.356105766724795, 15.080316198524088 }), + new DoublePoint(new double[] { 31.25948503636755, 13.674329151166603 }), + new DoublePoint(new double[] { 32.31590076372959, 14.95261758659035 }), + new DoublePoint(new double[] { 30.460413702763617, 15.88402809202671 }), + new DoublePoint(new double[] { 32.56178203062154, 14.586076852632686 }), + new DoublePoint(new double[] { 32.76138648530468, 16.239837325178087 }), + new DoublePoint(new double[] { 30.1829453331884, 14.709592407103628 }), + new DoublePoint(new double[] { 29.55088173528202, 15.0651247180067 }), + new DoublePoint(new double[] { 29.004155302187428, 14.089665298582986 }), + new DoublePoint(new double[] { 29.339624439831823, 13.29096065578051 }), + new DoublePoint(new double[] { 30.997460327576846, 14.551914158277214 }), + new DoublePoint(new double[] { 30.66784126125276, 16.269703107886016 }) + }; + + final DBSCANClusterer transformer = + new DBSCANClusterer(2.0, 5); + final List> clusters = transformer.cluster(Arrays.asList(points)); + + final List clusterOne = + Arrays.asList(points[3], points[4], points[5], points[6], points[7], points[8], points[9], points[10], + points[11], points[12], points[13], points[14]); + final List clusterTwo = + Arrays.asList(points[15], points[16], points[17], points[18], points[19], points[20], points[21], + points[22], points[23], points[24], points[25], points[26], points[27], points[28], + points[29], points[30], points[31], points[32], points[33], points[34], points[35], + points[36], points[37], points[38]); + final List clusterThree = + Arrays.asList(points[39], points[40], points[41], points[42], points[43], points[44], points[45], + points[46], points[47], points[48], points[49], points[50], points[51], points[52], + points[53], points[54], points[55], points[56], points[57], points[58], points[59], + points[60], points[61], points[62]); + + boolean cluster1Found = false; + boolean cluster2Found = false; + boolean cluster3Found = false; + Assert.assertEquals(3, clusters.size()); + for (final Cluster cluster : clusters) { + if (cluster.getPoints().containsAll(clusterOne)) { + cluster1Found = true; + } + if (cluster.getPoints().containsAll(clusterTwo)) { + cluster2Found = true; + } + if (cluster.getPoints().containsAll(clusterThree)) { + cluster3Found = true; + } + } + Assert.assertTrue(cluster1Found); + Assert.assertTrue(cluster2Found); + Assert.assertTrue(cluster3Found); + } + + @Test + public void testSingleLink() { + final DoublePoint[] points = { + new DoublePoint(new int[] {10, 10}), // A + new DoublePoint(new int[] {12, 9}), + new DoublePoint(new int[] {10, 8}), + new DoublePoint(new int[] {8, 8}), + new DoublePoint(new int[] {8, 6}), + new DoublePoint(new int[] {7, 7}), + new DoublePoint(new int[] {5, 6}), // B + new DoublePoint(new int[] {14, 8}), // C + new DoublePoint(new int[] {7, 15}), // N - Noise, should not be present + new DoublePoint(new int[] {17, 8}), // D - single-link connected to C should not be present + + }; + + final DBSCANClusterer clusterer = new DBSCANClusterer(3, 3); + List> clusters = clusterer.cluster(Arrays.asList(points)); + + Assert.assertEquals(1, clusters.size()); + + final List clusterOne = + Arrays.asList(points[0], points[1], points[2], points[3], points[4], points[5], points[6], points[7]); + Assert.assertTrue(clusters.get(0).getPoints().containsAll(clusterOne)); + } + + @Test + public void testGetEps() { + final DBSCANClusterer transformer = new DBSCANClusterer(2.0, 5); + Assert.assertEquals(2.0, transformer.getEps(), 0.0); + } + + @Test + public void testGetMinPts() { + final DBSCANClusterer transformer = new DBSCANClusterer(2.0, 5); + Assert.assertEquals(5, transformer.getMinPts()); + } + + @Test(expected = MathIllegalArgumentException.class) + public void testNegativeEps() { + new DBSCANClusterer(-2.0, 5); + } + + @Test(expected = MathIllegalArgumentException.class) + public void testNegativeMinPts() { + new DBSCANClusterer(2.0, -5); + } + + @Test(expected = NullArgumentException.class) + public void testNullDataset() { + DBSCANClusterer clusterer = new DBSCANClusterer(2.0, 5); + clusterer.cluster(null); + } + +} diff --git a/src/test/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClustererTest.java b/src/test/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClustererTest.java new file mode 100644 index 000000000..dc56996c7 --- /dev/null +++ b/src/test/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClustererTest.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.apache.commons.math3.random.JDKRandomGenerator; +import org.apache.commons.math3.random.RandomGenerator; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class KMeansPlusPlusClustererTest { + + private RandomGenerator random; + + @Before + public void setUp() { + random = new JDKRandomGenerator(); + random.setSeed(1746432956321l); + } + + /** + * JIRA: MATH-305 + * + * Two points, one cluster, one iteration + */ + @Test + public void testPerformClusterAnalysisDegenerate() { + KMeansPlusPlusClusterer transformer = + new KMeansPlusPlusClusterer(1, 1); + + DoublePoint[] points = new DoublePoint[] { + new DoublePoint(new int[] { 1959, 325100 }), + new DoublePoint(new int[] { 1960, 373200 }), }; + List> clusters = transformer.cluster(Arrays.asList(points)); + Assert.assertEquals(1, clusters.size()); + Assert.assertEquals(2, (clusters.get(0).getPoints().size())); + DoublePoint pt1 = new DoublePoint(new int[] { 1959, 325100 }); + DoublePoint pt2 = new DoublePoint(new int[] { 1960, 373200 }); + Assert.assertTrue(clusters.get(0).getPoints().contains(pt1)); + Assert.assertTrue(clusters.get(0).getPoints().contains(pt2)); + + } + + @Test + public void testCertainSpace() { + KMeansPlusPlusClusterer.EmptyClusterStrategy[] strategies = { + KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE, + KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_POINTS_NUMBER, + KMeansPlusPlusClusterer.EmptyClusterStrategy.FARTHEST_POINT + }; + for (KMeansPlusPlusClusterer.EmptyClusterStrategy strategy : strategies) { + int numberOfVariables = 27; + // initialise testvalues + int position1 = 1; + int position2 = position1 + numberOfVariables; + int position3 = position2 + numberOfVariables; + int position4 = position3 + numberOfVariables; + // testvalues will be multiplied + int multiplier = 1000000; + + DoublePoint[] breakingPoints = new DoublePoint[numberOfVariables]; + // define the space which will break the cluster algorithm + for (int i = 0; i < numberOfVariables; i++) { + int points[] = { position1, position2, position3, position4 }; + // multiply the values + for (int j = 0; j < points.length; j++) { + points[j] = points[j] * multiplier; + } + DoublePoint DoublePoint = new DoublePoint(points); + breakingPoints[i] = DoublePoint; + position1 = position1 + numberOfVariables; + position2 = position2 + numberOfVariables; + position3 = position3 + numberOfVariables; + position4 = position4 + numberOfVariables; + } + + for (int n = 2; n < 27; ++n) { + KMeansPlusPlusClusterer transformer = + new KMeansPlusPlusClusterer(n, 100, new EuclideanDistance(), random, strategy); + + List> clusters = + transformer.cluster(Arrays.asList(breakingPoints)); + + Assert.assertEquals(n, clusters.size()); + int sum = 0; + for (Cluster cluster : clusters) { + sum += cluster.getPoints().size(); + } + Assert.assertEquals(numberOfVariables, sum); + } + } + + } + + /** + * A helper class for testSmallDistances(). This class is similar to DoublePoint, but + * it defines a different distanceFrom() method that tends to return distances less than 1. + */ + private class CloseDistance extends EuclideanDistance { + private static final long serialVersionUID = 1L; + + @Override + public double compute(double[] a, double[] b) { + return super.compute(a, b) * 0.001; + } + } + + /** + * Test points that are very close together. See issue MATH-546. + */ + @Test + public void testSmallDistances() { + // Create a bunch of CloseDoublePoints. Most are identical, but one is different by a + // small distance. + int[] repeatedArray = { 0 }; + int[] uniqueArray = { 1 }; + DoublePoint repeatedPoint = new DoublePoint(repeatedArray); + DoublePoint uniquePoint = new DoublePoint(uniqueArray); + + Collection points = new ArrayList(); + final int NUM_REPEATED_POINTS = 10 * 1000; + for (int i = 0; i < NUM_REPEATED_POINTS; ++i) { + points.add(repeatedPoint); + } + points.add(uniquePoint); + + // Ask a KMeansPlusPlusClusterer to run zero iterations (i.e., to simply choose initial + // cluster centers). + final long RANDOM_SEED = 0; + final int NUM_CLUSTERS = 2; + final int NUM_ITERATIONS = 0; + random.setSeed(RANDOM_SEED); + + KMeansPlusPlusClusterer clusterer = + new KMeansPlusPlusClusterer(NUM_CLUSTERS, NUM_ITERATIONS, + new CloseDistance(), random); + List> clusters = clusterer.cluster(points); + + // Check that one of the chosen centers is the unique point. + boolean uniquePointIsCenter = false; + for (CentroidCluster cluster : clusters) { + if (cluster.getCenter().equals(uniquePoint)) { + uniquePointIsCenter = true; + } + } + Assert.assertTrue(uniquePointIsCenter); + } + + /** + * 2 variables cannot be clustered into 3 clusters. See issue MATH-436. + */ + @Test(expected=NumberIsTooSmallException.class) + public void testPerformClusterAnalysisToManyClusters() { + KMeansPlusPlusClusterer transformer = + new KMeansPlusPlusClusterer(3, 1, new EuclideanDistance(), random); + + DoublePoint[] points = new DoublePoint[] { + new DoublePoint(new int[] { + 1959, 325100 + }), new DoublePoint(new int[] { + 1960, 373200 + }) + }; + + transformer.cluster(Arrays.asList(points)); + + } + +} diff --git a/src/test/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClustererTest.java b/src/test/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClustererTest.java new file mode 100644 index 000000000..23d178ac8 --- /dev/null +++ b/src/test/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClustererTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + + +import java.util.Arrays; +import java.util.List; + +import org.junit.Assert; +import org.junit.Test; + +public class MultiKMeansPlusPlusClustererTest { + + @Test + public void dimension2() { + MultiKMeansPlusPlusClusterer transformer = + new MultiKMeansPlusPlusClusterer( + new KMeansPlusPlusClusterer(3, 10), 5); + + DoublePoint[] points = new DoublePoint[] { + + // first expected cluster + new DoublePoint(new int[] { -15, 3 }), + new DoublePoint(new int[] { -15, 4 }), + new DoublePoint(new int[] { -15, 5 }), + new DoublePoint(new int[] { -14, 3 }), + new DoublePoint(new int[] { -14, 5 }), + new DoublePoint(new int[] { -13, 3 }), + new DoublePoint(new int[] { -13, 4 }), + new DoublePoint(new int[] { -13, 5 }), + + // second expected cluster + new DoublePoint(new int[] { -1, 0 }), + new DoublePoint(new int[] { -1, -1 }), + new DoublePoint(new int[] { 0, -1 }), + new DoublePoint(new int[] { 1, -1 }), + new DoublePoint(new int[] { 1, -2 }), + + // third expected cluster + new DoublePoint(new int[] { 13, 3 }), + new DoublePoint(new int[] { 13, 4 }), + new DoublePoint(new int[] { 14, 4 }), + new DoublePoint(new int[] { 14, 7 }), + new DoublePoint(new int[] { 16, 5 }), + new DoublePoint(new int[] { 16, 6 }), + new DoublePoint(new int[] { 17, 4 }), + new DoublePoint(new int[] { 17, 7 }) + + }; + List> clusters = transformer.cluster(Arrays.asList(points)); + + Assert.assertEquals(3, clusters.size()); + boolean cluster1Found = false; + boolean cluster2Found = false; + boolean cluster3Found = false; + double epsilon = 1e-6; + for (CentroidCluster cluster : clusters) { + Clusterable center = cluster.getCenter(); + double[] point = center.getPoint(); + if (point[0] < 0) { + cluster1Found = true; + Assert.assertEquals(8, cluster.getPoints().size()); + Assert.assertEquals(-14, point[0], epsilon); + Assert.assertEquals( 4, point[1], epsilon); + } else if (point[1] < 0) { + cluster2Found = true; + Assert.assertEquals(5, cluster.getPoints().size()); + Assert.assertEquals( 0, point[0], epsilon); + Assert.assertEquals(-1, point[1], epsilon); + } else { + cluster3Found = true; + Assert.assertEquals(8, cluster.getPoints().size()); + Assert.assertEquals(15, point[0], epsilon); + Assert.assertEquals(5, point[1], epsilon); + } + } + Assert.assertTrue(cluster1Found); + Assert.assertTrue(cluster2Found); + Assert.assertTrue(cluster3Found); + + } + +}