MATH-1523: Abstract class replaced by an interface.
This commit is contained in:
parent
aafc49afd7
commit
c770e66963
|
@ -54,6 +54,9 @@ If the output is not quite correct, check for invisible trailing spaces!
|
||||||
</release>
|
</release>
|
||||||
|
|
||||||
<release version="4.0" date="XXXX-XX-XX" description="">
|
<release version="4.0" date="XXXX-XX-XX" description="">
|
||||||
|
<action dev="erans" type="update" issue="MATH-1523">
|
||||||
|
Abstract class "ClusterEvaluator" replaced by an interface.
|
||||||
|
</action>
|
||||||
<action dev="erans" type="fix" issue="MATH-1518">
|
<action dev="erans" type="fix" issue="MATH-1518">
|
||||||
Remove code duplication by moving method to class "Cluster".
|
Remove code duplication by moving method to class "Cluster".
|
||||||
</action>
|
</action>
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.commons.math4.ml.clustering;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface ClusterEvaluator<T extends Clusterable> {
|
||||||
|
/**
|
||||||
|
* @param cList List of clusters.
|
||||||
|
* @return the score attributed by the evaluator.
|
||||||
|
*/
|
||||||
|
double score(List<? extends Cluster<T>> cList);
|
||||||
|
/**
|
||||||
|
* @param a Score computed by this evaluator.
|
||||||
|
* @param b Score computed by this evaluator.
|
||||||
|
* @return true if the evaluator considers score {@code a} is
|
||||||
|
* considered better than score {@code b}.
|
||||||
|
*/
|
||||||
|
boolean isBetterScore(double a, double b);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts to a {@link ClusterRanking ranking function}
|
||||||
|
* (as required by clustering implementations).
|
||||||
|
*
|
||||||
|
* @param eval Evaluator function.
|
||||||
|
* @return a ranking function.
|
||||||
|
*/
|
||||||
|
static <T extends Clusterable> ClusterRanking ranking(ClusterEvaluator<T> eval) {
|
||||||
|
if (eval.isBetterScore(1, 2)) {
|
||||||
|
return cList -> 1 / eval.score(cList);
|
||||||
|
} else {
|
||||||
|
return cList -> eval.score(cList);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -22,7 +22,6 @@ import java.util.List;
|
||||||
|
|
||||||
import org.apache.commons.math4.exception.ConvergenceException;
|
import org.apache.commons.math4.exception.ConvergenceException;
|
||||||
import org.apache.commons.math4.exception.MathIllegalArgumentException;
|
import org.apache.commons.math4.exception.MathIllegalArgumentException;
|
||||||
import org.apache.commons.math4.ml.clustering.evaluation.ClusterEvaluator;
|
|
||||||
import org.apache.commons.math4.ml.clustering.evaluation.SumOfClusterVariances;
|
import org.apache.commons.math4.ml.clustering.evaluation.SumOfClusterVariances;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -48,7 +47,9 @@ public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Cluster
|
||||||
*/
|
*/
|
||||||
public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
|
public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
|
||||||
final int numTrials) {
|
final int numTrials) {
|
||||||
this(clusterer, numTrials, new SumOfClusterVariances<T>(clusterer.getDistanceMeasure()));
|
this(clusterer,
|
||||||
|
numTrials,
|
||||||
|
ClusterEvaluator.ranking(new SumOfClusterVariances<T>(clusterer.getDistanceMeasure())));
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Build a clusterer.
|
/** Build a clusterer.
|
||||||
|
|
|
@ -1,89 +0,0 @@
|
||||||
/*
|
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
||||||
* contributor license agreements. See the NOTICE file distributed with
|
|
||||||
* this work for additional information regarding copyright ownership.
|
|
||||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
||||||
* (the "License"); you may not use this file except in compliance with
|
|
||||||
* the License. You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.apache.commons.math4.ml.clustering.evaluation;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import org.apache.commons.math4.ml.clustering.CentroidCluster;
|
|
||||||
import org.apache.commons.math4.ml.clustering.Cluster;
|
|
||||||
import org.apache.commons.math4.ml.clustering.Clusterable;
|
|
||||||
import org.apache.commons.math4.ml.clustering.DoublePoint;
|
|
||||||
import org.apache.commons.math4.ml.distance.DistanceMeasure;
|
|
||||||
import org.apache.commons.math4.ml.distance.EuclideanDistance;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Base class for cluster evaluation methods.
|
|
||||||
*
|
|
||||||
* @param <T> type of the clustered points
|
|
||||||
* @since 3.3
|
|
||||||
*/
|
|
||||||
public abstract class ClusterEvaluator<T extends Clusterable> {
|
|
||||||
|
|
||||||
/** The distance measure to use when evaluating the cluster. */
|
|
||||||
private final DistanceMeasure measure;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a new cluster evaluator with an {@link EuclideanDistance}
|
|
||||||
* as distance measure.
|
|
||||||
*/
|
|
||||||
public ClusterEvaluator() {
|
|
||||||
this(new EuclideanDistance());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a new cluster evaluator with the given distance measure.
|
|
||||||
* @param measure the distance measure to use
|
|
||||||
*/
|
|
||||||
public ClusterEvaluator(final DistanceMeasure measure) {
|
|
||||||
this.measure = measure;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the evaluation score for the given list of clusters.
|
|
||||||
* @param clusters the clusters to evaluate
|
|
||||||
* @return the computed score
|
|
||||||
*/
|
|
||||||
public abstract double score(List<? extends Cluster<T>> clusters);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns whether the first evaluation score is considered to be better
|
|
||||||
* than the second one by this evaluator.
|
|
||||||
* <p>
|
|
||||||
* Specific implementations shall override this method if the returned scores
|
|
||||||
* do not follow the same ordering, i.e. smaller score is better.
|
|
||||||
*
|
|
||||||
* @param score1 the first score
|
|
||||||
* @param score2 the second score
|
|
||||||
* @return {@code true} if the first score is considered to be better, {@code false} otherwise
|
|
||||||
*/
|
|
||||||
public boolean isBetterScore(double score1, double score2) {
|
|
||||||
return score1 < score2;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Calculates the distance between two {@link Clusterable} instances
|
|
||||||
* with the configured {@link DistanceMeasure}.
|
|
||||||
*
|
|
||||||
* @param p1 the first clusterable
|
|
||||||
* @param p2 the second clusterable
|
|
||||||
* @return the distance between the two clusterables
|
|
||||||
*/
|
|
||||||
protected double distance(final Clusterable p1, final Clusterable p2) {
|
|
||||||
return measure.compute(p1.getPoint(), p2.getPoint());
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -21,7 +21,7 @@ import java.util.List;
|
||||||
|
|
||||||
import org.apache.commons.math4.ml.clustering.Cluster;
|
import org.apache.commons.math4.ml.clustering.Cluster;
|
||||||
import org.apache.commons.math4.ml.clustering.Clusterable;
|
import org.apache.commons.math4.ml.clustering.Clusterable;
|
||||||
import org.apache.commons.math4.ml.clustering.ClusterRanking;
|
import org.apache.commons.math4.ml.clustering.ClusterEvaluator;
|
||||||
import org.apache.commons.math4.ml.distance.DistanceMeasure;
|
import org.apache.commons.math4.ml.distance.DistanceMeasure;
|
||||||
import org.apache.commons.math4.stat.descriptive.moment.Variance;
|
import org.apache.commons.math4.stat.descriptive.moment.Variance;
|
||||||
|
|
||||||
|
@ -36,14 +36,16 @@ import org.apache.commons.math4.stat.descriptive.moment.Variance;
|
||||||
* @param <T> the type of the clustered points
|
* @param <T> the type of the clustered points
|
||||||
* @since 3.3
|
* @since 3.3
|
||||||
*/
|
*/
|
||||||
public class SumOfClusterVariances<T extends Clusterable> extends ClusterEvaluator<T>
|
public class SumOfClusterVariances<T extends Clusterable>
|
||||||
implements ClusterRanking<T> {
|
implements ClusterEvaluator<T> {
|
||||||
|
/** The distance measure to use when evaluating the cluster. */
|
||||||
|
private final DistanceMeasure measure;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param measure the distance measure to use
|
* @param measure Distance measure.
|
||||||
*/
|
*/
|
||||||
public SumOfClusterVariances(final DistanceMeasure measure) {
|
public SumOfClusterVariances(final DistanceMeasure measure) {
|
||||||
super(measure);
|
this.measure = measure;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc} */
|
||||||
|
@ -60,8 +62,8 @@ public class SumOfClusterVariances<T extends Clusterable> extends ClusterEvaluat
|
||||||
for (final T point : cluster.getPoints()) {
|
for (final T point : cluster.getPoints()) {
|
||||||
stat.increment(distance(point, center));
|
stat.increment(distance(point, center));
|
||||||
}
|
}
|
||||||
varianceSum += stat.getResult();
|
|
||||||
|
|
||||||
|
varianceSum += stat.getResult();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return varianceSum;
|
return varianceSum;
|
||||||
|
@ -69,7 +71,20 @@ public class SumOfClusterVariances<T extends Clusterable> extends ClusterEvaluat
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
/** {@inheritDoc} */
|
||||||
@Override
|
@Override
|
||||||
public double compute(List<? extends Cluster<T>> clusters) {
|
public boolean isBetterScore(double a,
|
||||||
return 1d / score(clusters);
|
double b) {
|
||||||
|
return a < b;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculates the distance between two {@link Clusterable} instances
|
||||||
|
* with the configured {@link DistanceMeasure}.
|
||||||
|
*
|
||||||
|
* @param p1 the first clusterable
|
||||||
|
* @param p2 the second clusterable
|
||||||
|
* @return the distance between the two clusterables
|
||||||
|
*/
|
||||||
|
private double distance(final Clusterable p1, final Clusterable p2) {
|
||||||
|
return measure.compute(p1.getPoint(), p2.getPoint());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,8 +26,7 @@ import java.util.List;
|
||||||
|
|
||||||
import org.apache.commons.math4.ml.clustering.Cluster;
|
import org.apache.commons.math4.ml.clustering.Cluster;
|
||||||
import org.apache.commons.math4.ml.clustering.DoublePoint;
|
import org.apache.commons.math4.ml.clustering.DoublePoint;
|
||||||
import org.apache.commons.math4.ml.clustering.evaluation.ClusterEvaluator;
|
import org.apache.commons.math4.ml.clustering.ClusterEvaluator;
|
||||||
import org.apache.commons.math4.ml.clustering.evaluation.SumOfClusterVariances;
|
|
||||||
import org.apache.commons.math4.ml.distance.EuclideanDistance;
|
import org.apache.commons.math4.ml.distance.EuclideanDistance;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
Loading…
Reference in New Issue