MATH-1523: Abstract class replaced by an interface.

This commit is contained in:
Gilles 2020-03-10 02:41:06 +01:00
parent aafc49afd7
commit c770e66963
6 changed files with 80 additions and 101 deletions

View File

@ -54,6 +54,9 @@ If the output is not quite correct, check for invisible trailing spaces!
</release> </release>
<release version="4.0" date="XXXX-XX-XX" description=""> <release version="4.0" date="XXXX-XX-XX" description="">
<action dev="erans" type="update" issue="MATH-1523">
Abstract class "ClusterEvaluator" replaced by an interface.
</action>
<action dev="erans" type="fix" issue="MATH-1518"> <action dev="erans" type="fix" issue="MATH-1518">
Remove code duplication by moving method to class "Cluster". Remove code duplication by moving method to class "Cluster".
</action> </action>

View File

@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.ml.clustering;
import java.util.List;
public interface ClusterEvaluator<T extends Clusterable> {
/**
* @param cList List of clusters.
* @return the score attributed by the evaluator.
*/
double score(List<? extends Cluster<T>> cList);
/**
* @param a Score computed by this evaluator.
* @param b Score computed by this evaluator.
* @return true if the evaluator considers score {@code a} is
* considered better than score {@code b}.
*/
boolean isBetterScore(double a, double b);
/**
* Converts to a {@link ClusterRanking ranking function}
* (as required by clustering implementations).
*
* @param eval Evaluator function.
* @return a ranking function.
*/
static <T extends Clusterable> ClusterRanking ranking(ClusterEvaluator<T> eval) {
if (eval.isBetterScore(1, 2)) {
return cList -> 1 / eval.score(cList);
} else {
return cList -> eval.score(cList);
}
}
}

View File

@ -22,7 +22,6 @@ import java.util.List;
import org.apache.commons.math4.exception.ConvergenceException; import org.apache.commons.math4.exception.ConvergenceException;
import org.apache.commons.math4.exception.MathIllegalArgumentException; import org.apache.commons.math4.exception.MathIllegalArgumentException;
import org.apache.commons.math4.ml.clustering.evaluation.ClusterEvaluator;
import org.apache.commons.math4.ml.clustering.evaluation.SumOfClusterVariances; import org.apache.commons.math4.ml.clustering.evaluation.SumOfClusterVariances;
/** /**
@ -48,7 +47,9 @@ public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Cluster
*/ */
public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer, public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
final int numTrials) { final int numTrials) {
this(clusterer, numTrials, new SumOfClusterVariances<T>(clusterer.getDistanceMeasure())); this(clusterer,
numTrials,
ClusterEvaluator.ranking(new SumOfClusterVariances<T>(clusterer.getDistanceMeasure())));
} }
/** Build a clusterer. /** Build a clusterer.

View File

@ -1,89 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.ml.clustering.evaluation;
import java.util.List;
import org.apache.commons.math4.ml.clustering.CentroidCluster;
import org.apache.commons.math4.ml.clustering.Cluster;
import org.apache.commons.math4.ml.clustering.Clusterable;
import org.apache.commons.math4.ml.clustering.DoublePoint;
import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.math4.ml.distance.EuclideanDistance;
/**
* Base class for cluster evaluation methods.
*
* @param <T> type of the clustered points
* @since 3.3
*/
public abstract class ClusterEvaluator<T extends Clusterable> {
/** The distance measure to use when evaluating the cluster. */
private final DistanceMeasure measure;
/**
* Creates a new cluster evaluator with an {@link EuclideanDistance}
* as distance measure.
*/
public ClusterEvaluator() {
this(new EuclideanDistance());
}
/**
* Creates a new cluster evaluator with the given distance measure.
* @param measure the distance measure to use
*/
public ClusterEvaluator(final DistanceMeasure measure) {
this.measure = measure;
}
/**
* Computes the evaluation score for the given list of clusters.
* @param clusters the clusters to evaluate
* @return the computed score
*/
public abstract double score(List<? extends Cluster<T>> clusters);
/**
* Returns whether the first evaluation score is considered to be better
* than the second one by this evaluator.
* <p>
* Specific implementations shall override this method if the returned scores
* do not follow the same ordering, i.e. smaller score is better.
*
* @param score1 the first score
* @param score2 the second score
* @return {@code true} if the first score is considered to be better, {@code false} otherwise
*/
public boolean isBetterScore(double score1, double score2) {
return score1 < score2;
}
/**
* Calculates the distance between two {@link Clusterable} instances
* with the configured {@link DistanceMeasure}.
*
* @param p1 the first clusterable
* @param p2 the second clusterable
* @return the distance between the two clusterables
*/
protected double distance(final Clusterable p1, final Clusterable p2) {
return measure.compute(p1.getPoint(), p2.getPoint());
}
}

View File

@ -21,7 +21,7 @@ import java.util.List;
import org.apache.commons.math4.ml.clustering.Cluster; import org.apache.commons.math4.ml.clustering.Cluster;
import org.apache.commons.math4.ml.clustering.Clusterable; import org.apache.commons.math4.ml.clustering.Clusterable;
import org.apache.commons.math4.ml.clustering.ClusterRanking; import org.apache.commons.math4.ml.clustering.ClusterEvaluator;
import org.apache.commons.math4.ml.distance.DistanceMeasure; import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.math4.stat.descriptive.moment.Variance; import org.apache.commons.math4.stat.descriptive.moment.Variance;
@ -36,14 +36,16 @@ import org.apache.commons.math4.stat.descriptive.moment.Variance;
* @param <T> the type of the clustered points * @param <T> the type of the clustered points
* @since 3.3 * @since 3.3
*/ */
public class SumOfClusterVariances<T extends Clusterable> extends ClusterEvaluator<T> public class SumOfClusterVariances<T extends Clusterable>
implements ClusterRanking<T> { implements ClusterEvaluator<T> {
/** The distance measure to use when evaluating the cluster. */
private final DistanceMeasure measure;
/** /**
* @param measure the distance measure to use * @param measure Distance measure.
*/ */
public SumOfClusterVariances(final DistanceMeasure measure) { public SumOfClusterVariances(final DistanceMeasure measure) {
super(measure); this.measure = measure;
} }
/** {@inheritDoc} */ /** {@inheritDoc} */
@ -60,8 +62,8 @@ public class SumOfClusterVariances<T extends Clusterable> extends ClusterEvaluat
for (final T point : cluster.getPoints()) { for (final T point : cluster.getPoints()) {
stat.increment(distance(point, center)); stat.increment(distance(point, center));
} }
varianceSum += stat.getResult();
varianceSum += stat.getResult();
} }
} }
return varianceSum; return varianceSum;
@ -69,7 +71,20 @@ public class SumOfClusterVariances<T extends Clusterable> extends ClusterEvaluat
/** {@inheritDoc} */ /** {@inheritDoc} */
@Override @Override
public double compute(List<? extends Cluster<T>> clusters) { public boolean isBetterScore(double a,
return 1d / score(clusters); double b) {
return a < b;
}
/**
* Calculates the distance between two {@link Clusterable} instances
* with the configured {@link DistanceMeasure}.
*
* @param p1 the first clusterable
* @param p2 the second clusterable
* @return the distance between the two clusterables
*/
private double distance(final Clusterable p1, final Clusterable p2) {
return measure.compute(p1.getPoint(), p2.getPoint());
} }
} }

View File

@ -26,8 +26,7 @@ import java.util.List;
import org.apache.commons.math4.ml.clustering.Cluster; import org.apache.commons.math4.ml.clustering.Cluster;
import org.apache.commons.math4.ml.clustering.DoublePoint; import org.apache.commons.math4.ml.clustering.DoublePoint;
import org.apache.commons.math4.ml.clustering.evaluation.ClusterEvaluator; import org.apache.commons.math4.ml.clustering.ClusterEvaluator;
import org.apache.commons.math4.ml.clustering.evaluation.SumOfClusterVariances;
import org.apache.commons.math4.ml.distance.EuclideanDistance; import org.apache.commons.math4.ml.distance.EuclideanDistance;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;