diff --git a/src/changes/changes.xml b/src/changes/changes.xml index 58ccc0863..0ff43e7a4 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -51,6 +51,11 @@ If the output is not quite correct, check for invisible trailing spaces! + + Added new class "ClusterEvaluator" to evaluate the result of a clustering algorithm + and refactored existing evaluation code in "MultiKMeansPlusPlusClusterer" + into separate class "SumOfClusterVariances". + Added InsufficientDataException. @@ -96,7 +101,7 @@ If the output is not quite correct, check for invisible trailing spaces! Added logDensity methods to AbstractReal/IntegerDistribution with naive default implementations and improved implementations for some current distributions. - + Added ConfidenceInterval class and BinomialConfidenceInterval providing several estimators for confidence intervals for binomial probabilities. @@ -127,7 +132,7 @@ If the output is not quite correct, check for invisible trailing spaces! Fix a typo in the test class of "GeometricDistribution" and ensure that a meaningful tolerance value is used when comparing test results with expected values. - + Added exact binomial test implementation. diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java index fa970ac5f..654cb04bc 100644 --- a/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java +++ b/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java @@ -22,7 +22,8 @@ import java.util.List; import org.apache.commons.math3.exception.ConvergenceException; import org.apache.commons.math3.exception.MathIllegalArgumentException; -import org.apache.commons.math3.stat.descriptive.moment.Variance; +import org.apache.commons.math3.ml.clustering.evaluation.ClusterEvaluator; +import org.apache.commons.math3.ml.clustering.evaluation.SumOfClusterVariances; /** * A wrapper around a k-means++ clustering algorithm which performs multiple trials @@ -39,15 +40,31 @@ public class MultiKMeansPlusPlusClusterer extends Cluster /** The number of trial runs. */ private final int numTrials; + /** The cluster evaluator to use. */ + private final ClusterEvaluator evaluator; + /** Build a clusterer. * @param clusterer the k-means clusterer to use * @param numTrials number of trial runs */ public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer clusterer, final int numTrials) { + this(clusterer, numTrials, new SumOfClusterVariances(clusterer.getDistanceMeasure())); + } + + /** Build a clusterer. + * @param clusterer the k-means clusterer to use + * @param numTrials number of trial runs + * @param evaluator the cluster evaluator to use + * @since 3.3 + */ + public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer clusterer, + final int numTrials, + final ClusterEvaluator evaluator) { super(clusterer.getDistanceMeasure()); this.clusterer = clusterer; this.numTrials = numTrials; + this.evaluator = evaluator; } /** @@ -66,6 +83,15 @@ public class MultiKMeansPlusPlusClusterer extends Cluster return numTrials; } + /** + * Returns the {@link ClusterEvaluator} used to determine the "best" clustering. + * @return the used {@link ClusterEvaluator} + * @since 3.3 + */ + public ClusterEvaluator getClusterEvaluator() { + return evaluator; + } + /** * Runs the K-means++ clustering algorithm. * @@ -92,22 +118,9 @@ public class MultiKMeansPlusPlusClusterer extends Cluster List> clusters = clusterer.cluster(points); // compute the variance of the current list - double varianceSum = 0.0; - for (final CentroidCluster cluster : clusters) { - if (!cluster.getPoints().isEmpty()) { + final double varianceSum = evaluator.score(clusters); - // compute the distance variance of the current cluster - final Clusterable center = cluster.getCenter(); - final Variance stat = new Variance(); - for (final T point : cluster.getPoints()) { - stat.increment(distance(point, center)); - } - varianceSum += stat.getResult(); - - } - } - - if (varianceSum <= bestVarianceSum) { + if (evaluator.isBetterScore(varianceSum, bestVarianceSum)) { // this one is the best we have found so far, remember it best = clusters; bestVarianceSum = varianceSum; diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/ClusterEvaluator.java b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/ClusterEvaluator.java new file mode 100644 index 000000000..eb86286f9 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/ClusterEvaluator.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering.evaluation; + +import java.util.List; + +import org.apache.commons.math3.ml.clustering.CentroidCluster; +import org.apache.commons.math3.ml.clustering.Cluster; +import org.apache.commons.math3.ml.clustering.Clusterable; +import org.apache.commons.math3.ml.clustering.DoublePoint; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.distance.EuclideanDistance; + +/** + * Base class for cluster evaluation methods. + * + * @param type of the clustered points + * @version $Id$ + * @since 3.3 + */ +public abstract class ClusterEvaluator { + + /** The distance measure to use when evaluating the cluster. */ + private final DistanceMeasure measure; + + /** + * Creates a new cluster evaluator with an {@link EuclideanDistance} + * as distance measure. + */ + public ClusterEvaluator() { + this(new EuclideanDistance()); + } + + /** + * Creates a new cluster evaluator with the given distance measure. + * @param measure the distance measure to use + */ + public ClusterEvaluator(final DistanceMeasure measure) { + this.measure = measure; + } + + /** + * Computes the evaluation score for the given list of clusters. + * @param clusters the clusters to evaluate + * @return the computed score + */ + public abstract double score(List> clusters); + + /** + * Returns whether the first evaluation score is considered to be better + * than the second one by this evaluator. + *

+ * Specific implementations shall override this method if the returned scores + * do not follow the same ordering, i.e. smaller score is better. + * + * @param score1 the first score + * @param score2 the second score + * @return {@code true} if the first score is considered to be better, {@code false} otherwise + */ + public boolean isBetterScore(double score1, double score2) { + return score1 < score2; + } + + /** + * Calculates the distance between two {@link Clusterable} instances + * with the configured {@link DistanceMeasure}. + * + * @param p1 the first clusterable + * @param p2 the second clusterable + * @return the distance between the two clusterables + */ + protected double distance(final Clusterable p1, final Clusterable p2) { + return measure.compute(p1.getPoint(), p2.getPoint()); + } + + /** + * Computes the centroid for a cluster. + * + * @param cluster the cluster + * @return the computed centroid for the cluster, + * or {@code null} if the cluster does not contain any points + */ + protected Clusterable centroidOf(final Cluster cluster) { + final List points = cluster.getPoints(); + if (points.isEmpty()) { + return null; + } + + // in case the cluster is of type CentroidCluster, no need to compute the centroid + if (cluster instanceof CentroidCluster) { + return ((CentroidCluster) cluster).getCenter(); + } + + final int dimension = points.get(0).getPoint().length; + final double[] centroid = new double[dimension]; + for (final T p : points) { + final double[] point = p.getPoint(); + for (int i = 0; i < centroid.length; i++) { + centroid[i] += point[i]; + } + } + for (int i = 0; i < centroid.length; i++) { + centroid[i] /= points.size(); + } + return new DoublePoint(centroid); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariances.java b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariances.java new file mode 100644 index 000000000..4dc648e5f --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariances.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering.evaluation; + +import java.util.List; + +import org.apache.commons.math3.ml.clustering.Cluster; +import org.apache.commons.math3.ml.clustering.Clusterable; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.stat.descriptive.moment.Variance; + +/** + * Computes the sum of intra-cluster distance variances according to the formula: + *

+ * \( score = \sum\limits_{i=1}^n \sigma_i^2 \)
+ * 
+ * where n is the number of clusters and \( \sigma_i^2 \) is the variance of + * intra-cluster distances of cluster \( c_i \). + * + * @param the type of the clustered points + * @version $Id$ + * @since 3.3 + */ +public class SumOfClusterVariances extends ClusterEvaluator { + + /** + * + * @param measure the distance measure to use + */ + public SumOfClusterVariances(final DistanceMeasure measure) { + super(measure); + } + + @Override + public double score(final List> clusters) { + double varianceSum = 0.0; + for (final Cluster cluster : clusters) { + if (!cluster.getPoints().isEmpty()) { + + final Clusterable center = centroidOf(cluster); + + // compute the distance variance of the current cluster + final Variance stat = new Variance(); + for (final T point : cluster.getPoints()) { + stat.increment(distance(point, center)); + } + varianceSum += stat.getResult(); + + } + } + return varianceSum; + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/package-info.java b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/package-info.java new file mode 100644 index 000000000..700f56602 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Cluster evaluation methods. + */ +package org.apache.commons.math3.ml.clustering.evaluation; diff --git a/src/test/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariancesTest.java b/src/test/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariancesTest.java new file mode 100644 index 000000000..a92256d9a --- /dev/null +++ b/src/test/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariancesTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering.evaluation; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.math3.ml.clustering.Cluster; +import org.apache.commons.math3.ml.clustering.DoublePoint; +import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.junit.Before; +import org.junit.Test; + +public class SumOfClusterVariancesTest { + + private ClusterEvaluator evaluator; + + @Before + public void setUp() { + evaluator = new SumOfClusterVariances(new EuclideanDistance()); + } + + @Test + public void testScore() { + final DoublePoint[] points1 = new DoublePoint[] { + new DoublePoint(new double[] { 1 }), + new DoublePoint(new double[] { 2 }), + new DoublePoint(new double[] { 3 }) + }; + + final DoublePoint[] points2 = new DoublePoint[] { + new DoublePoint(new double[] { 1 }), + new DoublePoint(new double[] { 5 }), + new DoublePoint(new double[] { 10 }) + }; + + final List> clusters = new ArrayList>(); + + final Cluster cluster1 = new Cluster(); + for (DoublePoint p : points1) { + cluster1.addPoint(p); + } + clusters.add(cluster1); + + assertEquals(1.0/3.0, evaluator.score(clusters), 1e-6); + + final Cluster cluster2 = new Cluster(); + for (DoublePoint p : points2) { + cluster2.addPoint(p); + } + clusters.add(cluster2); + + assertEquals(6.148148148, evaluator.score(clusters), 1e-6); + } + + @Test + public void testOrdering() { + assertTrue(evaluator.isBetterScore(10, 20)); + assertFalse(evaluator.isBetterScore(20, 1)); + } +}