diff --git a/src/changes/changes.xml b/src/changes/changes.xml index 21d622770..a3e1ceb66 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -54,6 +54,9 @@ If the output is not quite correct, check for invisible trailing spaces! + + Abstract class "ClusterEvaluator" replaced by an interface. + Remove code duplication by moving method to class "Cluster". diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/ClusterEvaluator.java b/src/main/java/org/apache/commons/math4/ml/clustering/ClusterEvaluator.java new file mode 100644 index 000000000..242941483 --- /dev/null +++ b/src/main/java/org/apache/commons/math4/ml/clustering/ClusterEvaluator.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math4.ml.clustering; + +import java.util.List; + +public interface ClusterEvaluator { + /** + * @param cList List of clusters. + * @return the score attributed by the evaluator. + */ + double score(List> cList); + /** + * @param a Score computed by this evaluator. + * @param b Score computed by this evaluator. + * @return true if the evaluator considers score {@code a} is + * considered better than score {@code b}. + */ + boolean isBetterScore(double a, double b); + + /** + * Converts to a {@link ClusterRanking ranking function} + * (as required by clustering implementations). + * + * @param eval Evaluator function. + * @return a ranking function. + */ + static ClusterRanking ranking(ClusterEvaluator eval) { + if (eval.isBetterScore(1, 2)) { + return cList -> 1 / eval.score(cList); + } else { + return cList -> eval.score(cList); + } + } +} diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/MultiKMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math4/ml/clustering/MultiKMeansPlusPlusClusterer.java index 32b358c19..282050afc 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/MultiKMeansPlusPlusClusterer.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/MultiKMeansPlusPlusClusterer.java @@ -22,7 +22,6 @@ import java.util.List; import org.apache.commons.math4.exception.ConvergenceException; import org.apache.commons.math4.exception.MathIllegalArgumentException; -import org.apache.commons.math4.ml.clustering.evaluation.ClusterEvaluator; import org.apache.commons.math4.ml.clustering.evaluation.SumOfClusterVariances; /** @@ -48,7 +47,9 @@ public class MultiKMeansPlusPlusClusterer extends Cluster */ public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer clusterer, final int numTrials) { - this(clusterer, numTrials, new SumOfClusterVariances(clusterer.getDistanceMeasure())); + this(clusterer, + numTrials, + ClusterEvaluator.ranking(new SumOfClusterVariances(clusterer.getDistanceMeasure()))); } /** Build a clusterer. diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/ClusterEvaluator.java b/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/ClusterEvaluator.java deleted file mode 100644 index ed4205795..000000000 --- a/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/ClusterEvaluator.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.commons.math4.ml.clustering.evaluation; - -import java.util.List; - -import org.apache.commons.math4.ml.clustering.CentroidCluster; -import org.apache.commons.math4.ml.clustering.Cluster; -import org.apache.commons.math4.ml.clustering.Clusterable; -import org.apache.commons.math4.ml.clustering.DoublePoint; -import org.apache.commons.math4.ml.distance.DistanceMeasure; -import org.apache.commons.math4.ml.distance.EuclideanDistance; - -/** - * Base class for cluster evaluation methods. - * - * @param type of the clustered points - * @since 3.3 - */ -public abstract class ClusterEvaluator { - - /** The distance measure to use when evaluating the cluster. */ - private final DistanceMeasure measure; - - /** - * Creates a new cluster evaluator with an {@link EuclideanDistance} - * as distance measure. - */ - public ClusterEvaluator() { - this(new EuclideanDistance()); - } - - /** - * Creates a new cluster evaluator with the given distance measure. - * @param measure the distance measure to use - */ - public ClusterEvaluator(final DistanceMeasure measure) { - this.measure = measure; - } - - /** - * Computes the evaluation score for the given list of clusters. - * @param clusters the clusters to evaluate - * @return the computed score - */ - public abstract double score(List> clusters); - - /** - * Returns whether the first evaluation score is considered to be better - * than the second one by this evaluator. - *

- * Specific implementations shall override this method if the returned scores - * do not follow the same ordering, i.e. smaller score is better. - * - * @param score1 the first score - * @param score2 the second score - * @return {@code true} if the first score is considered to be better, {@code false} otherwise - */ - public boolean isBetterScore(double score1, double score2) { - return score1 < score2; - } - - /** - * Calculates the distance between two {@link Clusterable} instances - * with the configured {@link DistanceMeasure}. - * - * @param p1 the first clusterable - * @param p2 the second clusterable - * @return the distance between the two clusterables - */ - protected double distance(final Clusterable p1, final Clusterable p2) { - return measure.compute(p1.getPoint(), p2.getPoint()); - } -} diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/SumOfClusterVariances.java b/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/SumOfClusterVariances.java index 8eebe2fe8..2ff48b97b 100644 --- a/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/SumOfClusterVariances.java +++ b/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/SumOfClusterVariances.java @@ -21,7 +21,7 @@ import java.util.List; import org.apache.commons.math4.ml.clustering.Cluster; import org.apache.commons.math4.ml.clustering.Clusterable; -import org.apache.commons.math4.ml.clustering.ClusterRanking; +import org.apache.commons.math4.ml.clustering.ClusterEvaluator; import org.apache.commons.math4.ml.distance.DistanceMeasure; import org.apache.commons.math4.stat.descriptive.moment.Variance; @@ -36,14 +36,16 @@ import org.apache.commons.math4.stat.descriptive.moment.Variance; * @param the type of the clustered points * @since 3.3 */ -public class SumOfClusterVariances extends ClusterEvaluator - implements ClusterRanking { +public class SumOfClusterVariances + implements ClusterEvaluator { + /** The distance measure to use when evaluating the cluster. */ + private final DistanceMeasure measure; /** - * @param measure the distance measure to use + * @param measure Distance measure. */ public SumOfClusterVariances(final DistanceMeasure measure) { - super(measure); + this.measure = measure; } /** {@inheritDoc} */ @@ -60,8 +62,8 @@ public class SumOfClusterVariances extends ClusterEvaluat for (final T point : cluster.getPoints()) { stat.increment(distance(point, center)); } - varianceSum += stat.getResult(); + varianceSum += stat.getResult(); } } return varianceSum; @@ -69,7 +71,20 @@ public class SumOfClusterVariances extends ClusterEvaluat /** {@inheritDoc} */ @Override - public double compute(List> clusters) { - return 1d / score(clusters); + public boolean isBetterScore(double a, + double b) { + return a < b; + } + + /** + * Calculates the distance between two {@link Clusterable} instances + * with the configured {@link DistanceMeasure}. + * + * @param p1 the first clusterable + * @param p2 the second clusterable + * @return the distance between the two clusterables + */ + private double distance(final Clusterable p1, final Clusterable p2) { + return measure.compute(p1.getPoint(), p2.getPoint()); } } diff --git a/src/test/java/org/apache/commons/math4/ml/clustering/evaluation/SumOfClusterVariancesTest.java b/src/test/java/org/apache/commons/math4/ml/clustering/evaluation/SumOfClusterVariancesTest.java index 37c59a605..19095c6ee 100644 --- a/src/test/java/org/apache/commons/math4/ml/clustering/evaluation/SumOfClusterVariancesTest.java +++ b/src/test/java/org/apache/commons/math4/ml/clustering/evaluation/SumOfClusterVariancesTest.java @@ -26,8 +26,7 @@ import java.util.List; import org.apache.commons.math4.ml.clustering.Cluster; import org.apache.commons.math4.ml.clustering.DoublePoint; -import org.apache.commons.math4.ml.clustering.evaluation.ClusterEvaluator; -import org.apache.commons.math4.ml.clustering.evaluation.SumOfClusterVariances; +import org.apache.commons.math4.ml.clustering.ClusterEvaluator; import org.apache.commons.math4.ml.distance.EuclideanDistance; import org.junit.Before; import org.junit.Test;