MATH-1522 Remove generic parameter in ClusterEvaluator and ClusterRanking

This commit is contained in:
CT 2020-03-10 12:28:41 +08:00 committed by Gilles Sadowski
parent af4962c3c6
commit af5ad16a06
7 changed files with 37 additions and 40 deletions

View File

@ -19,12 +19,13 @@ package org.apache.commons.math4.ml.clustering;
import java.util.List;
public interface ClusterEvaluator<T extends Clusterable> {
public interface ClusterEvaluator {
/**
* @param cList List of clusters.
* @return the score attributed by the evaluator.
*/
double score(List<? extends Cluster<T>> cList);
double score(List<? extends Cluster<? extends Clusterable>> cList);
/**
* @param a Score computed by this evaluator.
* @param b Score computed by this evaluator.
@ -40,11 +41,10 @@ public interface ClusterEvaluator<T extends Clusterable> {
* @param eval Evaluator function.
* @return a ranking function.
*/
static <T extends Clusterable> ClusterRanking ranking(ClusterEvaluator<T> eval) {
if (eval.isBetterScore(1, 2)) {
return cList -> 1 / eval.score(cList);
} else {
return cList -> eval.score(cList);
}
static <T extends Clusterable> ClusterRanking ranking(ClusterEvaluator eval) {
return clusters -> {
double score = eval.score(clusters);
return eval.isBetterScore(1, 2) ? score : 1 / score;
};
}
}

View File

@ -28,12 +28,12 @@ import java.util.List;
* </ul>
*/
@FunctionalInterface
public interface ClusterRanking<T extends Clusterable> {
public interface ClusterRanking {
/**
* Computes the rank (higher is better).
*
* @param clusters Clusters to be evaluated.
* @return the rank of the provided {@code clusters}.
*/
double compute(List<? extends Cluster<T>> clusters);
double compute(List<? extends Cluster<? extends Clusterable>> clusters);
}

View File

@ -39,7 +39,7 @@ public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Cluster
private final int numTrials;
/** The cluster evaluator to use. */
private final ClusterRanking<T> evaluator;
private final ClusterRanking evaluator;
/** Build a clusterer.
* @param clusterer the k-means clusterer to use
@ -49,7 +49,7 @@ public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Cluster
final int numTrials) {
this(clusterer,
numTrials,
ClusterEvaluator.ranking(new SumOfClusterVariances<T>(clusterer.getDistanceMeasure())));
ClusterEvaluator.ranking(new SumOfClusterVariances(clusterer.getDistanceMeasure())));
}
/** Build a clusterer.
@ -60,7 +60,7 @@ public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Cluster
*/
public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
final int numTrials,
final ClusterRanking<T> evaluator) {
final ClusterRanking evaluator) {
super(clusterer.getDistanceMeasure());
this.clusterer = clusterer;
this.numTrials = numTrials;

View File

@ -34,23 +34,22 @@ import java.util.List;
* The score is defined as ratio between the within-cluster dispersion and
* the between-cluster dispersion.
*
* @param <T> the type of the clustered points
* @see <a href="https://www.tandfonline.com/doi/abs/10.1080/03610927408827101">A dendrite method for cluster
* analysis</a>
*/
public class CalinskiHarabasz<T extends Clusterable> implements ClusterEvaluator<T> {
public class CalinskiHarabasz implements ClusterEvaluator {
/** {@inheritDoc} */
@Override
public double score(List<? extends Cluster<T>> clusters) {
public double score(List<? extends Cluster<? extends Clusterable>> clusters) {
final int dimension = dimensionOfClusters(clusters);
final double[] centroid = meanOfClusters(clusters, dimension);
double intraDistanceProduct = 0.0;
double extraDistanceProduct = 0.0;
for (Cluster<T> cluster : clusters) {
for (Cluster<? extends Clusterable> cluster : clusters) {
// Calculate the center of the cluster.
double[] clusterCentroid = mean(cluster.getPoints(), dimension);
for (T p : cluster.getPoints()) {
for (Clusterable p : cluster.getPoints()) {
// Increase the intra distance sum
intraDistanceProduct += covariance(clusterCentroid, p.getPoint());
}
@ -100,9 +99,9 @@ public class CalinskiHarabasz<T extends Clusterable> implements ClusterEvaluator
* @param dimension The dimension of each point
* @return The mean value.
*/
private double[] mean(final Collection<T> points, final int dimension) {
private double[] mean(final Collection<? extends Clusterable> points, final int dimension) {
final double[] centroid = new double[dimension];
for (final T p : points) {
for (final Clusterable p : points) {
final double[] point = p.getPoint();
for (int i = 0; i < centroid.length; i++) {
centroid[i] += point[i];
@ -121,11 +120,11 @@ public class CalinskiHarabasz<T extends Clusterable> implements ClusterEvaluator
* @param dimension The dimension of each point.
* @return The mean value.
*/
private double[] meanOfClusters(final Collection<? extends Cluster<T>> clusters, final int dimension) {
private double[] meanOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters, final int dimension) {
final double[] centroid = new double[dimension];
int allPointsCount = 0;
for (Cluster<T> cluster : clusters) {
for (T p : cluster.getPoints()) {
for (Cluster<? extends Clusterable> cluster : clusters) {
for (Clusterable p : cluster.getPoints()) {
double[] point = p.getPoint();
for (int i = 0; i < centroid.length; i++) {
centroid[i] += point[i];
@ -145,9 +144,9 @@ public class CalinskiHarabasz<T extends Clusterable> implements ClusterEvaluator
* @param clusters collection of cluster
* @return points count
*/
private int countAllPoints(final Collection<? extends Cluster<T>> clusters) {
private int countAllPoints(final Collection<? extends Cluster<? extends Clusterable>> clusters) {
int pointCount = 0;
for (Cluster<T> cluster : clusters) {
for (Cluster<? extends Clusterable> cluster : clusters) {
pointCount += cluster.getPoints().size();
}
return pointCount;
@ -159,10 +158,10 @@ public class CalinskiHarabasz<T extends Clusterable> implements ClusterEvaluator
* @param clusters collection of cluster
* @return The dimension of the first point in clusters
*/
private int dimensionOfClusters(final Collection<? extends Cluster<T>> clusters) {
private int dimensionOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters) {
// Iteration and find out the first point.
for (Cluster<T> cluster : clusters) {
for (T p : cluster.getPoints()) {
for (Cluster<? extends Clusterable> cluster : clusters) {
for (Clusterable p : cluster.getPoints()) {
return p.getPoint().length;
}
}

View File

@ -33,11 +33,9 @@ import org.apache.commons.math4.stat.descriptive.moment.Variance;
* where n is the number of clusters and \( \sigma_i^2 \) is the variance of
* intra-cluster distances of cluster \( c_i \).
*
* @param <T> the type of the clustered points
* @since 3.3
*/
public class SumOfClusterVariances<T extends Clusterable>
implements ClusterEvaluator<T> {
public class SumOfClusterVariances implements ClusterEvaluator {
/** The distance measure to use when evaluating the cluster. */
private final DistanceMeasure measure;
@ -48,18 +46,18 @@ public class SumOfClusterVariances<T extends Clusterable>
this.measure = measure;
}
/** {@inheritDoc} */
@Override
public double score(final List<? extends Cluster<T>> clusters) {
/** {@inheritDoc}
* @param clusters*/
public double score(List<? extends Cluster<? extends Clusterable>> clusters) {
double varianceSum = 0.0;
for (final Cluster<T> cluster : clusters) {
for (final Cluster<? extends Clusterable> cluster : clusters) {
if (!cluster.getPoints().isEmpty()) {
final Clusterable center = cluster.centroid();
// compute the distance variance of the current cluster
final Variance stat = new Variance();
for (final T point : cluster.getPoints()) {
for (final Clusterable point : cluster.getPoints()) {
stat.increment(distance(point, center));
}

View File

@ -33,12 +33,12 @@ import java.util.ArrayList;
import java.util.List;
public class CalinskiHarabaszTest {
private ClusterEvaluator<DoublePoint> evaluator;
private ClusterEvaluator evaluator;
private DistanceMeasure distanceMeasure;
@Before
public void setUp() {
evaluator = new CalinskiHarabasz<>();
evaluator = new CalinskiHarabasz();
distanceMeasure = new EuclideanDistance();
}

View File

@ -33,11 +33,11 @@ import org.junit.Test;
public class SumOfClusterVariancesTest {
private ClusterEvaluator<DoublePoint> evaluator;
private ClusterEvaluator evaluator;
@Before
public void setUp() {
evaluator = new SumOfClusterVariances<>(new EuclideanDistance());
evaluator = new SumOfClusterVariances(new EuclideanDistance());
}
@Test