diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java b/src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java
new file mode 100644
index 000000000..56a57adc2
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java
@@ -0,0 +1,38 @@
+package org.apache.commons.math3.ml.clustering;
+
+/**
+ * A Cluster used by centroid-based clustering algorithms.
+ *
+ * Defines additionally a cluster center which may not necessarily be a member
+ * of the original data set.
+ *
+ * @param the type of points that can be clustered
+ * @version $Id $
+ * @since 3.2
+ */
+public class CentroidCluster extends Cluster {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = -3075288519071812288L;
+
+ /** Center of the cluster. */
+ private final Clusterable center;
+
+ /**
+ * Build a cluster centered at a specified point.
+ * @param center the point which is to be the center of this cluster
+ */
+ public CentroidCluster(final Clusterable center) {
+ super();
+ this.center = center;
+ }
+
+ /**
+ * Get the point chosen to be the center of this cluster.
+ * @return chosen cluster center
+ */
+ public Clusterable getCenter() {
+ return center;
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java b/src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java
new file mode 100644
index 000000000..8523a4603
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.clustering;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Cluster holding a set of {@link Clusterable} points.
+ * @param the type of points that can be clustered
+ * @version $Id$
+ * @since 3.2
+ */
+public class Cluster implements Serializable {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = -3442297081515880464L;
+
+ /** The points contained in this cluster. */
+ private final List points;
+
+ /**
+ * Build a cluster centered at a specified point.
+ */
+ public Cluster() {
+ points = new ArrayList();
+ }
+
+ /**
+ * Add a point to this cluster.
+ * @param point point to add
+ */
+ public void addPoint(final T point) {
+ points.add(point);
+ }
+
+ /**
+ * Get the points contained in the cluster.
+ * @return points contained in the cluster
+ */
+ public List getPoints() {
+ return points;
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java
new file mode 100644
index 000000000..c4883b8f0
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.clustering;
+
+/**
+ * Interface for n-dimensional points that can be clustered together.
+ * @version $Id$
+ * @since 3.2
+ */
+public interface Clusterable {
+
+ /**
+ * Gets the n-dimensional point.
+ *
+ * @return the point array
+ */
+ double[] getPoint();
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java
new file mode 100644
index 000000000..83e572c0a
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java
@@ -0,0 +1,65 @@
+package org.apache.commons.math3.ml.clustering;
+
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.math3.exception.ConvergenceException;
+import org.apache.commons.math3.exception.MathIllegalArgumentException;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+
+/**
+ * Base class for clustering algorithms.
+ *
+ * @param the type of points that can be clustered
+ * @version $Id $
+ * @since 3.2
+ */
+public abstract class Clusterer {
+
+ /** The distance measure to use. */
+ private DistanceMeasure measure;
+
+ /**
+ * Build a new clusterer with the given {@link DistanceMeasure}.
+ *
+ * @param measure the distance measure to use
+ */
+ protected Clusterer(final DistanceMeasure measure) {
+ this.measure = measure;
+ }
+
+ /**
+ * Perform a cluster analysis on the given set of {@link Clusterable} instances.
+ *
+ * @param points the set of {@link Clusterable} instances
+ * @return a {@link List} of clusters
+ * @throws MathIllegalArgumentException if points are null or the number of
+ * data points is not compatible with this clusterer
+ * @throws ConvergenceException if the algorithm has not yet converged after
+ * the maximum number of iterations has been exceeded
+ */
+ public abstract List extends Cluster> cluster(Collection points)
+ throws MathIllegalArgumentException, ConvergenceException;
+
+ /**
+ * Returns the {@link DistanceMeasure} instance used by this clusterer.
+ *
+ * @return the distance measure
+ */
+ public DistanceMeasure getDistanceMeasure() {
+ return measure;
+ }
+
+ /**
+ * Calculates the distance between two {@link Clusterable} instances
+ * with the configured {@link DistanceMeasure}.
+ *
+ * @param p1 the first clusterable
+ * @param p2 the second clusterable
+ * @return the distance between the two clusterables
+ */
+ protected double distance(final Clusterable p1, final Clusterable p2) {
+ return measure.compute(p1.getPoint(), p2.getPoint());
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java
new file mode 100644
index 000000000..80ac5af3b
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.ml.clustering;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.commons.math3.exception.NotPositiveException;
+import org.apache.commons.math3.exception.NullArgumentException;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.commons.math3.util.MathUtils;
+
+/**
+ * DBSCAN (density-based spatial clustering of applications with noise) algorithm.
+ *
+ * The DBSCAN algorithm forms clusters based on the idea of density connectivity, i.e.
+ * a point p is density connected to another point q, if there exists a chain of
+ * points pi, with i = 1 .. n and p1 = p and pn = q,
+ * such that each pair <pi, pi+1> is directly density-reachable.
+ * A point q is directly density-reachable from point p if it is in the ε-neighborhood
+ * of this point.
+ *
+ * Any point that is not density-reachable from a formed cluster is treated as noise, and
+ * will thus not be present in the result.
+ *
+ * The algorithm requires two parameters:
+ *
+ * - eps: the distance that defines the ε-neighborhood of a point
+ *
- minPoints: the minimum number of density-connected points required to form a cluster
+ *
+ *
+ * @param type of the points to cluster
+ * @see DBSCAN (wikipedia)
+ * @see
+ * A Density-Based Algorithm for Discovering Clusters in Large Spatial Databases with Noise
+ * @version $Id$
+ * @since 3.2
+ */
+public class DBSCANClusterer extends Clusterer {
+
+ /** Maximum radius of the neighborhood to be considered. */
+ private final double eps;
+
+ /** Minimum number of points needed for a cluster. */
+ private final int minPts;
+
+ /** Status of a point during the clustering process. */
+ private enum PointStatus {
+ /** The point has is considered to be noise. */
+ NOISE,
+ /** The point is already part of a cluster. */
+ PART_OF_CLUSTER
+ }
+
+ /**
+ * Creates a new instance of a DBSCANClusterer.
+ *
+ * @param eps maximum radius of the neighborhood to be considered
+ * @param minPts minimum number of points needed for a cluster
+ * @throws NotPositiveException if {@code eps < 0.0} or {@code minPts < 0}
+ */
+ public DBSCANClusterer(final double eps, final int minPts)
+ throws NotPositiveException {
+ this(eps, minPts, new EuclideanDistance());
+ }
+
+ /**
+ * Creates a new instance of a DBSCANClusterer.
+ *
+ * @param eps maximum radius of the neighborhood to be considered
+ * @param minPts minimum number of points needed for a cluster
+ * @param measure the distance measure to use
+ * @throws NotPositiveException if {@code eps < 0.0} or {@code minPts < 0}
+ */
+ public DBSCANClusterer(final double eps, final int minPts, final DistanceMeasure measure)
+ throws NotPositiveException {
+ super(measure);
+
+ if (eps < 0.0d) {
+ throw new NotPositiveException(eps);
+ }
+ if (minPts < 0) {
+ throw new NotPositiveException(minPts);
+ }
+ this.eps = eps;
+ this.minPts = minPts;
+ }
+
+ /**
+ * Returns the maximum radius of the neighborhood to be considered.
+ * @return maximum radius of the neighborhood
+ */
+ public double getEps() {
+ return eps;
+ }
+
+ /**
+ * Returns the minimum number of points needed for a cluster.
+ * @return minimum number of points needed for a cluster
+ */
+ public int getMinPts() {
+ return minPts;
+ }
+
+ /**
+ * Performs DBSCAN cluster analysis.
+ *
+ * @param points the points to cluster
+ * @return the list of clusters
+ * @throws NullArgumentException if the data points are null
+ */
+ public List> cluster(final Collection points) throws NullArgumentException {
+
+ // sanity checks
+ MathUtils.checkNotNull(points);
+
+ final List> clusters = new ArrayList>();
+ final Map visited = new HashMap();
+
+ for (final T point : points) {
+ if (visited.get(point) != null) {
+ continue;
+ }
+ final List neighbors = getNeighbors(point, points);
+ if (neighbors.size() >= minPts) {
+ // DBSCAN does not care about center points
+ final Cluster cluster = new Cluster();
+ clusters.add(expandCluster(cluster, point, neighbors, points, visited));
+ } else {
+ visited.put(point, PointStatus.NOISE);
+ }
+ }
+
+ return clusters;
+ }
+
+ /**
+ * Expands the cluster to include density-reachable items.
+ *
+ * @param cluster Cluster to expand
+ * @param point Point to add to cluster
+ * @param neighbors List of neighbors
+ * @param points the data set
+ * @param visited the set of already visited points
+ * @return the expanded cluster
+ */
+ private Cluster expandCluster(final Cluster cluster,
+ final T point,
+ final List neighbors,
+ final Collection points,
+ final Map visited) {
+ cluster.addPoint(point);
+ visited.put(point, PointStatus.PART_OF_CLUSTER);
+
+ List seeds = new ArrayList(neighbors);
+ int index = 0;
+ while (index < seeds.size()) {
+ final T current = seeds.get(index);
+ PointStatus pStatus = visited.get(current);
+ // only check non-visited points
+ if (pStatus == null) {
+ final List currentNeighbors = getNeighbors(current, points);
+ if (currentNeighbors.size() >= minPts) {
+ seeds = merge(seeds, currentNeighbors);
+ }
+ }
+
+ if (pStatus != PointStatus.PART_OF_CLUSTER) {
+ visited.put(current, PointStatus.PART_OF_CLUSTER);
+ cluster.addPoint(current);
+ }
+
+ index++;
+ }
+ return cluster;
+ }
+
+ /**
+ * Returns a list of density-reachable neighbors of a {@code point}.
+ *
+ * @param point the point to look for
+ * @param points possible neighbors
+ * @return the List of neighbors
+ */
+ private List getNeighbors(final T point, final Collection points) {
+ final List neighbors = new ArrayList();
+ for (final T neighbor : points) {
+ if (point != neighbor && distance(neighbor, point) <= eps) {
+ neighbors.add(neighbor);
+ }
+ }
+ return neighbors;
+ }
+
+ /**
+ * Merges two lists together.
+ *
+ * @param one first list
+ * @param two second list
+ * @return merged lists
+ */
+ private List merge(final List one, final List two) {
+ final Set oneSet = new HashSet(one);
+ for (T item : two) {
+ if (!oneSet.contains(item)) {
+ one.add(item);
+ }
+ }
+ return one;
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java b/src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java
new file mode 100644
index 000000000..3177bfada
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.clustering;
+
+import java.io.Serializable;
+import java.util.Arrays;
+
+/**
+ * A simple implementation of {@link Clusterable} for points with double coordinates.
+ * @version $Id$
+ * @since 3.2
+ */
+public class DoublePoint implements Clusterable, Serializable {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = 3946024775784901369L;
+
+ /** Point coordinates. */
+ private final double[] point;
+
+ /**
+ * Build an instance wrapping an double array.
+ *
+ * The wrapped array is referenced, it is not copied.
+ *
+ * @param point the n-dimensional point in double space
+ */
+ public DoublePoint(final double[] point) {
+ this.point = point;
+ }
+
+ /**
+ * Build an instance wrapping an integer array.
+ *
+ * The wrapped array is copied to an internal double array.
+ *
+ * @param point the n-dimensional point in integer space
+ */
+ public DoublePoint(final int[] point) {
+ this.point = new double[point.length];
+ for ( int i = 0; i < point.length; i++) {
+ this.point[i] = point[i];
+ }
+ }
+
+ /** {@inheritDoc} */
+ public double[] getPoint() {
+ return point;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public boolean equals(final Object other) {
+ if (!(other instanceof DoublePoint)) {
+ return false;
+ }
+ return Arrays.equals(point, ((DoublePoint) other).point);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public int hashCode() {
+ return Arrays.hashCode(point);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public String toString() {
+ return Arrays.toString(point);
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java
new file mode 100644
index 000000000..771dd1234
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java
@@ -0,0 +1,561 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.clustering;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.commons.math3.exception.ConvergenceException;
+import org.apache.commons.math3.exception.MathIllegalArgumentException;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.exception.util.LocalizedFormats;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.commons.math3.random.JDKRandomGenerator;
+import org.apache.commons.math3.random.RandomGenerator;
+import org.apache.commons.math3.stat.descriptive.moment.Variance;
+import org.apache.commons.math3.util.MathUtils;
+
+/**
+ * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
+ * @param type of the points to cluster
+ * @see K-means++ (wikipedia)
+ * @version $Id$
+ * @since 3.2
+ */
+public class KMeansPlusPlusClusterer extends Clusterer {
+
+ /** Strategies to use for replacing an empty cluster. */
+ public static enum EmptyClusterStrategy {
+
+ /** Split the cluster with largest distance variance. */
+ LARGEST_VARIANCE,
+
+ /** Split the cluster with largest number of points. */
+ LARGEST_POINTS_NUMBER,
+
+ /** Create a cluster around the point farthest from its centroid. */
+ FARTHEST_POINT,
+
+ /** Generate an error. */
+ ERROR
+
+ }
+
+ /** The number of clusters. */
+ private final int k;
+
+ /** The maximum number of iterations. */
+ private final int maxIterations;
+
+ /** Random generator for choosing initial centers. */
+ private final RandomGenerator random;
+
+ /** Selected strategy for empty clusters. */
+ private final EmptyClusterStrategy emptyStrategy;
+
+ /** Build a clusterer.
+ *
+ * The default strategy for handling empty clusters that may appear during
+ * algorithm iterations is to split the cluster with largest distance variance.
+ *
+ * @param k the number of clusters to split the data into
+ */
+ public KMeansPlusPlusClusterer(final int k) {
+ this(k, -1);
+ }
+
+ /** Build a clusterer.
+ *
+ * The default strategy for handling empty clusters that may appear during
+ * algorithm iterations is to split the cluster with largest distance variance.
+ *
+ * @param k the number of clusters to split the data into
+ * @param maxIterations the maximum number of iterations to run the algorithm for.
+ * If negative, no maximum will be used.
+ */
+ public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
+ this(k, maxIterations, new EuclideanDistance());
+ }
+
+ /** Build a clusterer.
+ *
+ * The default strategy for handling empty clusters that may appear during
+ * algorithm iterations is to split the cluster with largest distance variance.
+ *
+ * @param k the number of clusters to split the data into
+ * @param maxIterations the maximum number of iterations to run the algorithm for.
+ * If negative, no maximum will be used.
+ * @param measure the distance measure to use
+ */
+ public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
+ this(k, maxIterations, measure, new JDKRandomGenerator());
+ }
+
+ /** Build a clusterer.
+ *
+ * The default strategy for handling empty clusters that may appear during
+ * algorithm iterations is to split the cluster with largest distance variance.
+ *
+ * @param k the number of clusters to split the data into
+ * @param maxIterations the maximum number of iterations to run the algorithm for.
+ * If negative, no maximum will be used.
+ * @param measure the distance measure to use
+ * @param random random generator to use for choosing initial centers
+ */
+ public KMeansPlusPlusClusterer(final int k, final int maxIterations,
+ final DistanceMeasure measure,
+ final RandomGenerator random) {
+ this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
+ }
+
+ /** Build a clusterer.
+ *
+ * @param k the number of clusters to split the data into
+ * @param maxIterations the maximum number of iterations to run the algorithm for.
+ * If negative, no maximum will be used.
+ * @param measure the distance measure to use
+ * @param random random generator to use for choosing initial centers
+ * @param emptyStrategy strategy to use for handling empty clusters that
+ * may appear during algorithm iterations
+ */
+ public KMeansPlusPlusClusterer(final int k, final int maxIterations,
+ final DistanceMeasure measure,
+ final RandomGenerator random,
+ final EmptyClusterStrategy emptyStrategy) {
+ super(measure);
+ this.k = k;
+ this.maxIterations = maxIterations;
+ this.random = random;
+ this.emptyStrategy = emptyStrategy;
+ }
+
+ /**
+ * Return the number of clusters this instance will use.
+ * @return the number of clusters
+ */
+ public int getK() {
+ return k;
+ }
+
+ /**
+ * Returns the maximum number of iterations this instance will use.
+ * @return the maximum number of iterations, or -1 if no maximum is set
+ */
+ public int getMaxIterations() {
+ return maxIterations;
+ }
+
+ /**
+ * Returns the random generator this instance will use.
+ * @return the random generator
+ */
+ public RandomGenerator getRandomGenerator() {
+ return random;
+ }
+
+ /**
+ * Returns the {@link EmptyClusterStrategy} used by this instance.
+ * @return the {@link EmptyClusterStrategy}
+ */
+ public EmptyClusterStrategy getEmptyClusterStrategy() {
+ return emptyStrategy;
+ }
+
+ /**
+ * Runs the K-means++ clustering algorithm.
+ *
+ * @param points the points to cluster
+ * @return a list of clusters containing the points
+ * @throws MathIllegalArgumentException if the data points are null or the number
+ * of clusters is larger than the number of data points
+ * @throws ConvergenceException if an empty cluster is encountered and the
+ * {@link #emptyStrategy} is set to {@code ERROR}
+ */
+ public List> cluster(final Collection points)
+ throws MathIllegalArgumentException, ConvergenceException {
+
+ // sanity checks
+ MathUtils.checkNotNull(points);
+
+ // number of clusters has to be smaller or equal the number of data points
+ if (points.size() < k) {
+ throw new NumberIsTooSmallException(points.size(), k, false);
+ }
+
+ // create the initial clusters
+ List> clusters = chooseInitialCenters(points);
+
+ // create an array containing the latest assignment of a point to a cluster
+ // no need to initialize the array, as it will be filled with the first assignment
+ int[] assignments = new int[points.size()];
+ assignPointsToClusters(clusters, points, assignments);
+
+ // iterate through updating the centers until we're done
+ final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
+ for (int count = 0; count < max; count++) {
+ boolean emptyCluster = false;
+ List> newClusters = new ArrayList>();
+ for (final CentroidCluster cluster : clusters) {
+ final Clusterable newCenter;
+ if (cluster.getPoints().isEmpty()) {
+ switch (emptyStrategy) {
+ case LARGEST_VARIANCE :
+ newCenter = getPointFromLargestVarianceCluster(clusters);
+ break;
+ case LARGEST_POINTS_NUMBER :
+ newCenter = getPointFromLargestNumberCluster(clusters);
+ break;
+ case FARTHEST_POINT :
+ newCenter = getFarthestPoint(clusters);
+ break;
+ default :
+ throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
+ }
+ emptyCluster = true;
+ } else {
+ newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
+ }
+ newClusters.add(new CentroidCluster(newCenter));
+ }
+ int changes = assignPointsToClusters(newClusters, points, assignments);
+ clusters = newClusters;
+
+ // if there were no more changes in the point-to-cluster assignment
+ // and there are no empty clusters left, return the current clusters
+ if (changes == 0 && !emptyCluster) {
+ return clusters;
+ }
+ }
+ return clusters;
+ }
+
+ /**
+ * Adds the given points to the closest {@link Cluster}.
+ *
+ * @param clusters the {@link Cluster}s to add the points to
+ * @param points the points to add to the given {@link Cluster}s
+ * @param assignments points assignments to clusters
+ * @return the number of points assigned to different clusters as the iteration before
+ */
+ private int assignPointsToClusters(final List> clusters,
+ final Collection points,
+ final int[] assignments) {
+ int assignedDifferently = 0;
+ int pointIndex = 0;
+ for (final T p : points) {
+ int clusterIndex = getNearestCluster(clusters, p);
+ if (clusterIndex != assignments[pointIndex]) {
+ assignedDifferently++;
+ }
+
+ CentroidCluster cluster = clusters.get(clusterIndex);
+ cluster.addPoint(p);
+ assignments[pointIndex++] = clusterIndex;
+ }
+
+ return assignedDifferently;
+ }
+
+ /**
+ * Use K-means++ to choose the initial centers.
+ *
+ * @param points the points to choose the initial centers from
+ * @return the initial centers
+ */
+ private List> chooseInitialCenters(final Collection points) {
+
+ // Convert to list for indexed access. Make it unmodifiable, since removal of items
+ // would screw up the logic of this method.
+ final List pointList = Collections.unmodifiableList(new ArrayList (points));
+
+ // The number of points in the list.
+ final int numPoints = pointList.size();
+
+ // Set the corresponding element in this array to indicate when
+ // elements of pointList are no longer available.
+ final boolean[] taken = new boolean[numPoints];
+
+ // The resulting list of initial centers.
+ final List> resultSet = new ArrayList>();
+
+ // Choose one center uniformly at random from among the data points.
+ final int firstPointIndex = random.nextInt(numPoints);
+
+ final T firstPoint = pointList.get(firstPointIndex);
+
+ resultSet.add(new CentroidCluster(firstPoint));
+
+ // Must mark it as taken
+ taken[firstPointIndex] = true;
+
+ // To keep track of the minimum distance squared of elements of
+ // pointList to elements of resultSet.
+ final double[] minDistSquared = new double[numPoints];
+
+ // Initialize the elements. Since the only point in resultSet is firstPoint,
+ // this is very easy.
+ for (int i = 0; i < numPoints; i++) {
+ if (i != firstPointIndex) { // That point isn't considered
+ double d = distance(firstPoint, pointList.get(i));
+ minDistSquared[i] = d*d;
+ }
+ }
+
+ while (resultSet.size() < k) {
+
+ // Sum up the squared distances for the points in pointList not
+ // already taken.
+ double distSqSum = 0.0;
+
+ for (int i = 0; i < numPoints; i++) {
+ if (!taken[i]) {
+ distSqSum += minDistSquared[i];
+ }
+ }
+
+ // Add one new data point as a center. Each point x is chosen with
+ // probability proportional to D(x)2
+ final double r = random.nextDouble() * distSqSum;
+
+ // The index of the next point to be added to the resultSet.
+ int nextPointIndex = -1;
+
+ // Sum through the squared min distances again, stopping when
+ // sum >= r.
+ double sum = 0.0;
+ for (int i = 0; i < numPoints; i++) {
+ if (!taken[i]) {
+ sum += minDistSquared[i];
+ if (sum >= r) {
+ nextPointIndex = i;
+ break;
+ }
+ }
+ }
+
+ // If it's not set to >= 0, the point wasn't found in the previous
+ // for loop, probably because distances are extremely small. Just pick
+ // the last available point.
+ if (nextPointIndex == -1) {
+ for (int i = numPoints - 1; i >= 0; i--) {
+ if (!taken[i]) {
+ nextPointIndex = i;
+ break;
+ }
+ }
+ }
+
+ // We found one.
+ if (nextPointIndex >= 0) {
+
+ final T p = pointList.get(nextPointIndex);
+
+ resultSet.add(new CentroidCluster (p));
+
+ // Mark it as taken.
+ taken[nextPointIndex] = true;
+
+ if (resultSet.size() < k) {
+ // Now update elements of minDistSquared. We only have to compute
+ // the distance to the new center to do this.
+ for (int j = 0; j < numPoints; j++) {
+ // Only have to worry about the points still not taken.
+ if (!taken[j]) {
+ double d = distance(p, pointList.get(j));
+ double d2 = d * d;
+ if (d2 < minDistSquared[j]) {
+ minDistSquared[j] = d2;
+ }
+ }
+ }
+ }
+
+ } else {
+ // None found --
+ // Break from the while loop to prevent
+ // an infinite loop.
+ break;
+ }
+ }
+
+ return resultSet;
+ }
+
+ /**
+ * Get a random point from the {@link Cluster} with the largest distance variance.
+ *
+ * @param clusters the {@link Cluster}s to search
+ * @return a random point from the selected cluster
+ * @throws ConvergenceException if clusters are all empty
+ */
+ private T getPointFromLargestVarianceCluster(final Collection> clusters)
+ throws ConvergenceException {
+
+ double maxVariance = Double.NEGATIVE_INFINITY;
+ Cluster selected = null;
+ for (final CentroidCluster cluster : clusters) {
+ if (!cluster.getPoints().isEmpty()) {
+
+ // compute the distance variance of the current cluster
+ final Clusterable center = cluster.getCenter();
+ final Variance stat = new Variance();
+ for (final T point : cluster.getPoints()) {
+ stat.increment(distance(point, center));
+ }
+ final double variance = stat.getResult();
+
+ // select the cluster with the largest variance
+ if (variance > maxVariance) {
+ maxVariance = variance;
+ selected = cluster;
+ }
+
+ }
+ }
+
+ // did we find at least one non-empty cluster ?
+ if (selected == null) {
+ throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
+ }
+
+ // extract a random point from the cluster
+ final List selectedPoints = selected.getPoints();
+ return selectedPoints.remove(random.nextInt(selectedPoints.size()));
+
+ }
+
+ /**
+ * Get a random point from the {@link Cluster} with the largest number of points
+ *
+ * @param clusters the {@link Cluster}s to search
+ * @return a random point from the selected cluster
+ * @throws ConvergenceException if clusters are all empty
+ */
+ private T getPointFromLargestNumberCluster(final Collection extends Cluster> clusters)
+ throws ConvergenceException {
+
+ int maxNumber = 0;
+ Cluster selected = null;
+ for (final Cluster cluster : clusters) {
+
+ // get the number of points of the current cluster
+ final int number = cluster.getPoints().size();
+
+ // select the cluster with the largest number of points
+ if (number > maxNumber) {
+ maxNumber = number;
+ selected = cluster;
+ }
+
+ }
+
+ // did we find at least one non-empty cluster ?
+ if (selected == null) {
+ throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
+ }
+
+ // extract a random point from the cluster
+ final List selectedPoints = selected.getPoints();
+ return selectedPoints.remove(random.nextInt(selectedPoints.size()));
+
+ }
+
+ /**
+ * Get the point farthest to its cluster center
+ *
+ * @param clusters the {@link Cluster}s to search
+ * @return point farthest to its cluster center
+ * @throws ConvergenceException if clusters are all empty
+ */
+ private T getFarthestPoint(final Collection> clusters) throws ConvergenceException {
+
+ double maxDistance = Double.NEGATIVE_INFINITY;
+ Cluster selectedCluster = null;
+ int selectedPoint = -1;
+ for (final CentroidCluster cluster : clusters) {
+
+ // get the farthest point
+ final Clusterable center = cluster.getCenter();
+ final List points = cluster.getPoints();
+ for (int i = 0; i < points.size(); ++i) {
+ final double distance = distance(points.get(i), center);
+ if (distance > maxDistance) {
+ maxDistance = distance;
+ selectedCluster = cluster;
+ selectedPoint = i;
+ }
+ }
+
+ }
+
+ // did we find at least one non-empty cluster ?
+ if (selectedCluster == null) {
+ throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
+ }
+
+ return selectedCluster.getPoints().remove(selectedPoint);
+
+ }
+
+ /**
+ * Returns the nearest {@link Cluster} to the given point
+ *
+ * @param clusters the {@link Cluster}s to search
+ * @param point the point to find the nearest {@link Cluster} for
+ * @return the index of the nearest {@link Cluster} to the given point
+ */
+ private int getNearestCluster(final Collection> clusters, final T point) {
+ double minDistance = Double.MAX_VALUE;
+ int clusterIndex = 0;
+ int minCluster = 0;
+ for (final CentroidCluster c : clusters) {
+ final double distance = distance(point, c.getCenter());
+ if (distance < minDistance) {
+ minDistance = distance;
+ minCluster = clusterIndex;
+ }
+ clusterIndex++;
+ }
+ return minCluster;
+ }
+
+ /**
+ * Computes the centroid for a set of points.
+ *
+ * @param points the set of points
+ * @param dimension the point dimension
+ * @return the computed centroid for the set of points
+ */
+ private Clusterable centroidOf(final Collection points, final int dimension) {
+ final double[] centroid = new double[dimension];
+ for (final T p : points) {
+ final double[] point = p.getPoint();
+ for (int i = 0; i < centroid.length; i++) {
+ centroid[i] += point[i];
+ }
+ }
+ for (int i = 0; i < centroid.length; i++) {
+ centroid[i] /= points.size();
+ }
+ return new DoublePoint(centroid);
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java
new file mode 100644
index 000000000..3b8b001fb
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.clustering;
+
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.math3.exception.ConvergenceException;
+import org.apache.commons.math3.exception.MathIllegalArgumentException;
+import org.apache.commons.math3.stat.descriptive.moment.Variance;
+
+/**
+ * A wrapper around a k-means++ clustering algorithm which performs multiple trials
+ * and returns the best solution.
+ * @param type of the points to cluster
+ * @version $Id$
+ * @since 3.2
+ */
+public class MultiKMeansPlusPlusClusterer extends Clusterer {
+
+ /** The underlying k-means clusterer. */
+ private final KMeansPlusPlusClusterer clusterer;
+
+ /** The number of trial runs. */
+ private final int numTrials;
+
+ /** Build a clusterer.
+ * @param clusterer the k-means clusterer to use
+ * @param numTrials number of trial runs
+ */
+ public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer clusterer,
+ final int numTrials) {
+ super(clusterer.getDistanceMeasure());
+ this.clusterer = clusterer;
+ this.numTrials = numTrials;
+ }
+
+ /**
+ * Returns the embedded k-means clusterer used by this instance.
+ * @return the embedded clusterer
+ */
+ public KMeansPlusPlusClusterer getClusterer() {
+ return clusterer;
+ }
+
+ /**
+ * Returns the number of trials this instance will do.
+ * @return the number of trials
+ */
+ public int getNumTrials() {
+ return numTrials;
+ }
+
+ /**
+ * Runs the K-means++ clustering algorithm.
+ *
+ * @param points the points to cluster
+ * @return a list of clusters containing the points
+ * @throws MathIllegalArgumentException if the data points are null or the number
+ * of clusters is larger than the number of data points
+ * @throws ConvergenceException if an empty cluster is encountered and the
+ * {@link #emptyStrategy} is set to {@code ERROR}
+ */
+ public List> cluster(final Collection points)
+ throws MathIllegalArgumentException, ConvergenceException {
+
+ // at first, we have not found any clusters list yet
+ List> best = null;
+ double bestVarianceSum = Double.POSITIVE_INFINITY;
+
+ // do several clustering trials
+ for (int i = 0; i < numTrials; ++i) {
+
+ // compute a clusters list
+ List> clusters = clusterer.cluster(points);
+
+ // compute the variance of the current list
+ double varianceSum = 0.0;
+ for (final CentroidCluster cluster : clusters) {
+ if (!cluster.getPoints().isEmpty()) {
+
+ // compute the distance variance of the current cluster
+ final Clusterable center = cluster.getCenter();
+ final Variance stat = new Variance();
+ for (final T point : cluster.getPoints()) {
+ stat.increment(distance(point, center));
+ }
+ varianceSum += stat.getResult();
+
+ }
+ }
+
+ if (varianceSum <= bestVarianceSum) {
+ // this one is the best we have found so far, remember it
+ best = clusters;
+ bestVarianceSum = varianceSum;
+ }
+
+ }
+
+ // return the best clusters list found
+ return best;
+
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/package-info.java b/src/main/java/org/apache/commons/math3/ml/clustering/package-info.java
new file mode 100644
index 000000000..02f1d208f
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * Clustering algorithms.
+ */
+package org.apache.commons.math3.ml.clustering;
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java
new file mode 100644
index 000000000..0ee8fa4a8
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java
@@ -0,0 +1,27 @@
+package org.apache.commons.math3.ml.distance;
+
+import org.apache.commons.math3.util.FastMath;
+
+/**
+ * Calculates the Canberra distance between two points.
+ *
+ * @version $Id $
+ * @since 3.2
+ */
+public class CanberraDistance implements DistanceMeasure {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = -6972277381587032228L;
+
+ /** {@inheritDoc} */
+ public double compute(double[] a, double[] b) {
+ double sum = 0;
+ for (int i = 0; i < a.length; i++) {
+ final double num = FastMath.abs(a[i] - b[i]);
+ final double denom = FastMath.abs(a[i]) + FastMath.abs(b[i]);
+ sum += num == 0.0 && denom == 0.0 ? 0.0 : num / denom;
+ }
+ return sum;
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java
new file mode 100644
index 000000000..22d52a135
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java
@@ -0,0 +1,21 @@
+package org.apache.commons.math3.ml.distance;
+
+import org.apache.commons.math3.util.MathArrays;
+
+/**
+ * Calculates the L∞ (max of abs) distance between two points.
+ *
+ * @version $Id $
+ * @since 3.2
+ */
+public class ChebyshevDistance implements DistanceMeasure {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = -4694868171115238296L;
+
+ /** {@inheritDoc} */
+ public double compute(double[] a, double[] b) {
+ return MathArrays.distanceInf(a, b);
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java b/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java
new file mode 100644
index 000000000..2895084db
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java
@@ -0,0 +1,23 @@
+package org.apache.commons.math3.ml.distance;
+
+import java.io.Serializable;
+
+/**
+ * Interface for distance measures of n-dimensional vectors.
+ *
+ * @version $Id $
+ * @since 3.2
+ */
+public interface DistanceMeasure extends Serializable {
+
+ /**
+ * Compute the distance between two n-dimensional vectors.
+ *
+ * The two vectors are required to have the same dimension.
+ *
+ * @param a the first vector
+ * @param b the second vector
+ * @return the distance between the two vectors
+ */
+ double compute(double[] a, double[] b);
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java
new file mode 100644
index 000000000..fda91e0fa
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java
@@ -0,0 +1,21 @@
+package org.apache.commons.math3.ml.distance;
+
+import org.apache.commons.math3.util.MathArrays;
+
+/**
+ * Calculates the L2 (Euclidean) distance between two points.
+ *
+ * @version $Id $
+ * @since 3.2
+ */
+public class EuclideanDistance implements DistanceMeasure {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = 1717556319784040040L;
+
+ /** {@inheritDoc} */
+ public double compute(double[] a, double[] b) {
+ return MathArrays.distance(a, b);
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java
new file mode 100644
index 000000000..552bd2cc1
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java
@@ -0,0 +1,21 @@
+package org.apache.commons.math3.ml.distance;
+
+import org.apache.commons.math3.util.MathArrays;
+
+/**
+ * Calculates the L1 (sum of abs) distance between two points.
+ *
+ * @version $Id $
+ * @since 3.2
+ */
+public class ManhattanDistance implements DistanceMeasure {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = -9108154600539125566L;
+
+ /** {@inheritDoc} */
+ public double compute(double[] a, double[] b) {
+ return MathArrays.distance1(a, b);
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/package-info.java b/src/main/java/org/apache/commons/math3/ml/distance/package-info.java
new file mode 100644
index 000000000..f6d124a29
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/distance/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * Common distance measures.
+ */
+package org.apache.commons.math3.ml.distance;
diff --git a/src/main/java/org/apache/commons/math3/ml/package-info.java b/src/main/java/org/apache/commons/math3/ml/package-info.java
new file mode 100644
index 000000000..80ae917d4
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * Base package for machine learning algorithms.
+ */
+package org.apache.commons.math3.ml;
diff --git a/src/test/java/org/apache/commons/math3/ml/clustering/DBSCANClustererTest.java b/src/test/java/org/apache/commons/math3/ml/clustering/DBSCANClustererTest.java
new file mode 100644
index 000000000..497459f9f
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/clustering/DBSCANClustererTest.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.ml.clustering;
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.commons.math3.exception.MathIllegalArgumentException;
+import org.apache.commons.math3.exception.NullArgumentException;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class DBSCANClustererTest {
+
+ @Test
+ public void testCluster() {
+ // Test data generated using: http://people.cs.nctu.edu.tw/~rsliang/dbscan/testdatagen.html
+ final DoublePoint[] points = new DoublePoint[] {
+ new DoublePoint(new double[] { 83.08303244924173, 58.83387754182331 }),
+ new DoublePoint(new double[] { 45.05445510940626, 23.469642649637535 }),
+ new DoublePoint(new double[] { 14.96417921432294, 69.0264096390456 }),
+ new DoublePoint(new double[] { 73.53189604333602, 34.896145021310076 }),
+ new DoublePoint(new double[] { 73.28498173551634, 33.96860806993209 }),
+ new DoublePoint(new double[] { 73.45828098873608, 33.92584423092194 }),
+ new DoublePoint(new double[] { 73.9657889183145, 35.73191006924026 }),
+ new DoublePoint(new double[] { 74.0074097183533, 36.81735596177168 }),
+ new DoublePoint(new double[] { 73.41247541410848, 34.27314856695011 }),
+ new DoublePoint(new double[] { 73.9156256353017, 36.83206791547127 }),
+ new DoublePoint(new double[] { 74.81499205809087, 37.15682749846019 }),
+ new DoublePoint(new double[] { 74.03144880081527, 37.57399178552441 }),
+ new DoublePoint(new double[] { 74.51870941207744, 38.674258946906775 }),
+ new DoublePoint(new double[] { 74.50754595105536, 35.58903978415765 }),
+ new DoublePoint(new double[] { 74.51322752749547, 36.030572259100154 }),
+ new DoublePoint(new double[] { 59.27900996617973, 46.41091720294207 }),
+ new DoublePoint(new double[] { 59.73744793841615, 46.20015558367595 }),
+ new DoublePoint(new double[] { 58.81134076672606, 45.71150126331486 }),
+ new DoublePoint(new double[] { 58.52225539437495, 47.416083617601544 }),
+ new DoublePoint(new double[] { 58.218626647023484, 47.36228902172297 }),
+ new DoublePoint(new double[] { 60.27139669447206, 46.606106348801404 }),
+ new DoublePoint(new double[] { 60.894962462363765, 46.976924697402865 }),
+ new DoublePoint(new double[] { 62.29048673878424, 47.66970563563518 }),
+ new DoublePoint(new double[] { 61.03857608977705, 46.212924720020965 }),
+ new DoublePoint(new double[] { 60.16916214139201, 45.18193661351688 }),
+ new DoublePoint(new double[] { 59.90036905976012, 47.555364347063005 }),
+ new DoublePoint(new double[] { 62.33003634144552, 47.83941489877179 }),
+ new DoublePoint(new double[] { 57.86035536718555, 47.31117930193432 }),
+ new DoublePoint(new double[] { 58.13715479685925, 48.985960494028404 }),
+ new DoublePoint(new double[] { 56.131923963548616, 46.8508904252667 }),
+ new DoublePoint(new double[] { 55.976329887053, 47.46384037658572 }),
+ new DoublePoint(new double[] { 56.23245975235477, 47.940035191131756 }),
+ new DoublePoint(new double[] { 58.51687048212625, 46.622885352699086 }),
+ new DoublePoint(new double[] { 57.85411081905477, 45.95394361577928 }),
+ new DoublePoint(new double[] { 56.445776311447844, 45.162093662656844 }),
+ new DoublePoint(new double[] { 57.36691949656233, 47.50097194337286 }),
+ new DoublePoint(new double[] { 58.243626387557015, 46.114052729681134 }),
+ new DoublePoint(new double[] { 56.27224595635198, 44.799080066150054 }),
+ new DoublePoint(new double[] { 57.606924816500396, 46.94291057763621 }),
+ new DoublePoint(new double[] { 30.18714230041951, 13.877149710431695 }),
+ new DoublePoint(new double[] { 30.449448810657486, 13.490778346545994 }),
+ new DoublePoint(new double[] { 30.295018390286714, 13.264889000216499 }),
+ new DoublePoint(new double[] { 30.160201832884923, 11.89278262341395 }),
+ new DoublePoint(new double[] { 31.341509791789576, 15.282655921997502 }),
+ new DoublePoint(new double[] { 31.68601630325429, 14.756873246748 }),
+ new DoublePoint(new double[] { 29.325963742565364, 12.097849250072613 }),
+ new DoublePoint(new double[] { 29.54820742388256, 13.613295356975868 }),
+ new DoublePoint(new double[] { 28.79359608888626, 10.36352064087987 }),
+ new DoublePoint(new double[] { 31.01284597092308, 12.788479208014905 }),
+ new DoublePoint(new double[] { 27.58509216737002, 11.47570110601373 }),
+ new DoublePoint(new double[] { 28.593799561727792, 10.780998203903437 }),
+ new DoublePoint(new double[] { 31.356105766724795, 15.080316198524088 }),
+ new DoublePoint(new double[] { 31.25948503636755, 13.674329151166603 }),
+ new DoublePoint(new double[] { 32.31590076372959, 14.95261758659035 }),
+ new DoublePoint(new double[] { 30.460413702763617, 15.88402809202671 }),
+ new DoublePoint(new double[] { 32.56178203062154, 14.586076852632686 }),
+ new DoublePoint(new double[] { 32.76138648530468, 16.239837325178087 }),
+ new DoublePoint(new double[] { 30.1829453331884, 14.709592407103628 }),
+ new DoublePoint(new double[] { 29.55088173528202, 15.0651247180067 }),
+ new DoublePoint(new double[] { 29.004155302187428, 14.089665298582986 }),
+ new DoublePoint(new double[] { 29.339624439831823, 13.29096065578051 }),
+ new DoublePoint(new double[] { 30.997460327576846, 14.551914158277214 }),
+ new DoublePoint(new double[] { 30.66784126125276, 16.269703107886016 })
+ };
+
+ final DBSCANClusterer transformer =
+ new DBSCANClusterer(2.0, 5);
+ final List> clusters = transformer.cluster(Arrays.asList(points));
+
+ final List clusterOne =
+ Arrays.asList(points[3], points[4], points[5], points[6], points[7], points[8], points[9], points[10],
+ points[11], points[12], points[13], points[14]);
+ final List clusterTwo =
+ Arrays.asList(points[15], points[16], points[17], points[18], points[19], points[20], points[21],
+ points[22], points[23], points[24], points[25], points[26], points[27], points[28],
+ points[29], points[30], points[31], points[32], points[33], points[34], points[35],
+ points[36], points[37], points[38]);
+ final List clusterThree =
+ Arrays.asList(points[39], points[40], points[41], points[42], points[43], points[44], points[45],
+ points[46], points[47], points[48], points[49], points[50], points[51], points[52],
+ points[53], points[54], points[55], points[56], points[57], points[58], points[59],
+ points[60], points[61], points[62]);
+
+ boolean cluster1Found = false;
+ boolean cluster2Found = false;
+ boolean cluster3Found = false;
+ Assert.assertEquals(3, clusters.size());
+ for (final Cluster cluster : clusters) {
+ if (cluster.getPoints().containsAll(clusterOne)) {
+ cluster1Found = true;
+ }
+ if (cluster.getPoints().containsAll(clusterTwo)) {
+ cluster2Found = true;
+ }
+ if (cluster.getPoints().containsAll(clusterThree)) {
+ cluster3Found = true;
+ }
+ }
+ Assert.assertTrue(cluster1Found);
+ Assert.assertTrue(cluster2Found);
+ Assert.assertTrue(cluster3Found);
+ }
+
+ @Test
+ public void testSingleLink() {
+ final DoublePoint[] points = {
+ new DoublePoint(new int[] {10, 10}), // A
+ new DoublePoint(new int[] {12, 9}),
+ new DoublePoint(new int[] {10, 8}),
+ new DoublePoint(new int[] {8, 8}),
+ new DoublePoint(new int[] {8, 6}),
+ new DoublePoint(new int[] {7, 7}),
+ new DoublePoint(new int[] {5, 6}), // B
+ new DoublePoint(new int[] {14, 8}), // C
+ new DoublePoint(new int[] {7, 15}), // N - Noise, should not be present
+ new DoublePoint(new int[] {17, 8}), // D - single-link connected to C should not be present
+
+ };
+
+ final DBSCANClusterer clusterer = new DBSCANClusterer(3, 3);
+ List> clusters = clusterer.cluster(Arrays.asList(points));
+
+ Assert.assertEquals(1, clusters.size());
+
+ final List clusterOne =
+ Arrays.asList(points[0], points[1], points[2], points[3], points[4], points[5], points[6], points[7]);
+ Assert.assertTrue(clusters.get(0).getPoints().containsAll(clusterOne));
+ }
+
+ @Test
+ public void testGetEps() {
+ final DBSCANClusterer transformer = new DBSCANClusterer(2.0, 5);
+ Assert.assertEquals(2.0, transformer.getEps(), 0.0);
+ }
+
+ @Test
+ public void testGetMinPts() {
+ final DBSCANClusterer transformer = new DBSCANClusterer(2.0, 5);
+ Assert.assertEquals(5, transformer.getMinPts());
+ }
+
+ @Test(expected = MathIllegalArgumentException.class)
+ public void testNegativeEps() {
+ new DBSCANClusterer(-2.0, 5);
+ }
+
+ @Test(expected = MathIllegalArgumentException.class)
+ public void testNegativeMinPts() {
+ new DBSCANClusterer(2.0, -5);
+ }
+
+ @Test(expected = NullArgumentException.class)
+ public void testNullDataset() {
+ DBSCANClusterer clusterer = new DBSCANClusterer(2.0, 5);
+ clusterer.cluster(null);
+ }
+
+}
diff --git a/src/test/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClustererTest.java b/src/test/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClustererTest.java
new file mode 100644
index 000000000..dc56996c7
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClustererTest.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.clustering;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.commons.math3.random.JDKRandomGenerator;
+import org.apache.commons.math3.random.RandomGenerator;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class KMeansPlusPlusClustererTest {
+
+ private RandomGenerator random;
+
+ @Before
+ public void setUp() {
+ random = new JDKRandomGenerator();
+ random.setSeed(1746432956321l);
+ }
+
+ /**
+ * JIRA: MATH-305
+ *
+ * Two points, one cluster, one iteration
+ */
+ @Test
+ public void testPerformClusterAnalysisDegenerate() {
+ KMeansPlusPlusClusterer transformer =
+ new KMeansPlusPlusClusterer(1, 1);
+
+ DoublePoint[] points = new DoublePoint[] {
+ new DoublePoint(new int[] { 1959, 325100 }),
+ new DoublePoint(new int[] { 1960, 373200 }), };
+ List extends Cluster> clusters = transformer.cluster(Arrays.asList(points));
+ Assert.assertEquals(1, clusters.size());
+ Assert.assertEquals(2, (clusters.get(0).getPoints().size()));
+ DoublePoint pt1 = new DoublePoint(new int[] { 1959, 325100 });
+ DoublePoint pt2 = new DoublePoint(new int[] { 1960, 373200 });
+ Assert.assertTrue(clusters.get(0).getPoints().contains(pt1));
+ Assert.assertTrue(clusters.get(0).getPoints().contains(pt2));
+
+ }
+
+ @Test
+ public void testCertainSpace() {
+ KMeansPlusPlusClusterer.EmptyClusterStrategy[] strategies = {
+ KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE,
+ KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_POINTS_NUMBER,
+ KMeansPlusPlusClusterer.EmptyClusterStrategy.FARTHEST_POINT
+ };
+ for (KMeansPlusPlusClusterer.EmptyClusterStrategy strategy : strategies) {
+ int numberOfVariables = 27;
+ // initialise testvalues
+ int position1 = 1;
+ int position2 = position1 + numberOfVariables;
+ int position3 = position2 + numberOfVariables;
+ int position4 = position3 + numberOfVariables;
+ // testvalues will be multiplied
+ int multiplier = 1000000;
+
+ DoublePoint[] breakingPoints = new DoublePoint[numberOfVariables];
+ // define the space which will break the cluster algorithm
+ for (int i = 0; i < numberOfVariables; i++) {
+ int points[] = { position1, position2, position3, position4 };
+ // multiply the values
+ for (int j = 0; j < points.length; j++) {
+ points[j] = points[j] * multiplier;
+ }
+ DoublePoint DoublePoint = new DoublePoint(points);
+ breakingPoints[i] = DoublePoint;
+ position1 = position1 + numberOfVariables;
+ position2 = position2 + numberOfVariables;
+ position3 = position3 + numberOfVariables;
+ position4 = position4 + numberOfVariables;
+ }
+
+ for (int n = 2; n < 27; ++n) {
+ KMeansPlusPlusClusterer transformer =
+ new KMeansPlusPlusClusterer(n, 100, new EuclideanDistance(), random, strategy);
+
+ List extends Cluster> clusters =
+ transformer.cluster(Arrays.asList(breakingPoints));
+
+ Assert.assertEquals(n, clusters.size());
+ int sum = 0;
+ for (Cluster cluster : clusters) {
+ sum += cluster.getPoints().size();
+ }
+ Assert.assertEquals(numberOfVariables, sum);
+ }
+ }
+
+ }
+
+ /**
+ * A helper class for testSmallDistances(). This class is similar to DoublePoint, but
+ * it defines a different distanceFrom() method that tends to return distances less than 1.
+ */
+ private class CloseDistance extends EuclideanDistance {
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public double compute(double[] a, double[] b) {
+ return super.compute(a, b) * 0.001;
+ }
+ }
+
+ /**
+ * Test points that are very close together. See issue MATH-546.
+ */
+ @Test
+ public void testSmallDistances() {
+ // Create a bunch of CloseDoublePoints. Most are identical, but one is different by a
+ // small distance.
+ int[] repeatedArray = { 0 };
+ int[] uniqueArray = { 1 };
+ DoublePoint repeatedPoint = new DoublePoint(repeatedArray);
+ DoublePoint uniquePoint = new DoublePoint(uniqueArray);
+
+ Collection points = new ArrayList();
+ final int NUM_REPEATED_POINTS = 10 * 1000;
+ for (int i = 0; i < NUM_REPEATED_POINTS; ++i) {
+ points.add(repeatedPoint);
+ }
+ points.add(uniquePoint);
+
+ // Ask a KMeansPlusPlusClusterer to run zero iterations (i.e., to simply choose initial
+ // cluster centers).
+ final long RANDOM_SEED = 0;
+ final int NUM_CLUSTERS = 2;
+ final int NUM_ITERATIONS = 0;
+ random.setSeed(RANDOM_SEED);
+
+ KMeansPlusPlusClusterer clusterer =
+ new KMeansPlusPlusClusterer(NUM_CLUSTERS, NUM_ITERATIONS,
+ new CloseDistance(), random);
+ List> clusters = clusterer.cluster(points);
+
+ // Check that one of the chosen centers is the unique point.
+ boolean uniquePointIsCenter = false;
+ for (CentroidCluster cluster : clusters) {
+ if (cluster.getCenter().equals(uniquePoint)) {
+ uniquePointIsCenter = true;
+ }
+ }
+ Assert.assertTrue(uniquePointIsCenter);
+ }
+
+ /**
+ * 2 variables cannot be clustered into 3 clusters. See issue MATH-436.
+ */
+ @Test(expected=NumberIsTooSmallException.class)
+ public void testPerformClusterAnalysisToManyClusters() {
+ KMeansPlusPlusClusterer transformer =
+ new KMeansPlusPlusClusterer(3, 1, new EuclideanDistance(), random);
+
+ DoublePoint[] points = new DoublePoint[] {
+ new DoublePoint(new int[] {
+ 1959, 325100
+ }), new DoublePoint(new int[] {
+ 1960, 373200
+ })
+ };
+
+ transformer.cluster(Arrays.asList(points));
+
+ }
+
+}
diff --git a/src/test/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClustererTest.java b/src/test/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClustererTest.java
new file mode 100644
index 000000000..23d178ac8
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClustererTest.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.clustering;
+
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class MultiKMeansPlusPlusClustererTest {
+
+ @Test
+ public void dimension2() {
+ MultiKMeansPlusPlusClusterer transformer =
+ new MultiKMeansPlusPlusClusterer(
+ new KMeansPlusPlusClusterer(3, 10), 5);
+
+ DoublePoint[] points = new DoublePoint[] {
+
+ // first expected cluster
+ new DoublePoint(new int[] { -15, 3 }),
+ new DoublePoint(new int[] { -15, 4 }),
+ new DoublePoint(new int[] { -15, 5 }),
+ new DoublePoint(new int[] { -14, 3 }),
+ new DoublePoint(new int[] { -14, 5 }),
+ new DoublePoint(new int[] { -13, 3 }),
+ new DoublePoint(new int[] { -13, 4 }),
+ new DoublePoint(new int[] { -13, 5 }),
+
+ // second expected cluster
+ new DoublePoint(new int[] { -1, 0 }),
+ new DoublePoint(new int[] { -1, -1 }),
+ new DoublePoint(new int[] { 0, -1 }),
+ new DoublePoint(new int[] { 1, -1 }),
+ new DoublePoint(new int[] { 1, -2 }),
+
+ // third expected cluster
+ new DoublePoint(new int[] { 13, 3 }),
+ new DoublePoint(new int[] { 13, 4 }),
+ new DoublePoint(new int[] { 14, 4 }),
+ new DoublePoint(new int[] { 14, 7 }),
+ new DoublePoint(new int[] { 16, 5 }),
+ new DoublePoint(new int[] { 16, 6 }),
+ new DoublePoint(new int[] { 17, 4 }),
+ new DoublePoint(new int[] { 17, 7 })
+
+ };
+ List> clusters = transformer.cluster(Arrays.asList(points));
+
+ Assert.assertEquals(3, clusters.size());
+ boolean cluster1Found = false;
+ boolean cluster2Found = false;
+ boolean cluster3Found = false;
+ double epsilon = 1e-6;
+ for (CentroidCluster cluster : clusters) {
+ Clusterable center = cluster.getCenter();
+ double[] point = center.getPoint();
+ if (point[0] < 0) {
+ cluster1Found = true;
+ Assert.assertEquals(8, cluster.getPoints().size());
+ Assert.assertEquals(-14, point[0], epsilon);
+ Assert.assertEquals( 4, point[1], epsilon);
+ } else if (point[1] < 0) {
+ cluster2Found = true;
+ Assert.assertEquals(5, cluster.getPoints().size());
+ Assert.assertEquals( 0, point[0], epsilon);
+ Assert.assertEquals(-1, point[1], epsilon);
+ } else {
+ cluster3Found = true;
+ Assert.assertEquals(8, cluster.getPoints().size());
+ Assert.assertEquals(15, point[0], epsilon);
+ Assert.assertEquals(5, point[1], epsilon);
+ }
+ }
+ Assert.assertTrue(cluster1Found);
+ Assert.assertTrue(cluster2Found);
+ Assert.assertTrue(cluster3Found);
+
+ }
+
+}