diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java index 3dd8ba83d04..65de8015f20 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java @@ -175,7 +175,7 @@ public class ConfusionMatrixGenerator { public double getPrecision(String klass) { Map classifications = linearizedMatrix.get(klass); double tp = 0; - double fp = 0; + double den = 0; // tp + fp if (classifications != null) { for (Map.Entry entry : classifications.entrySet()) { if (klass.equals(entry.getKey())) { @@ -184,11 +184,11 @@ public class ConfusionMatrixGenerator { } for (Map values : linearizedMatrix.values()) { if (values.containsKey(klass)) { - fp += values.get(klass); + den += values.get(klass); } } } - return tp > 0 ? tp / (tp + fp) : 0; + return tp > 0 ? tp / den : 0; } /** @@ -246,7 +246,7 @@ public class ConfusionMatrixGenerator { if (this.accuracy == -1) { double tp = 0d; double tn = 0d; - double fp = 0d; + double tfp = 0d; // tp + fp double fn = 0d; for (Map.Entry> classification : linearizedMatrix.entrySet()) { String klass = classification.getKey(); @@ -259,63 +259,46 @@ public class ConfusionMatrixGenerator { } for (Map values : linearizedMatrix.values()) { if (values.containsKey(klass)) { - fp += values.get(klass); + tfp += values.get(klass); } else { tn++; } } } - this.accuracy = (tp + tn) / (fp + fn + tp + tn); + this.accuracy = (tp + tn) / (tfp + fn + tn); } return this.accuracy; } /** - * get the precision (see {@link #getPrecision(String)}) over all the classes. + * get the macro averaged precision (see {@link #getPrecision(String)}) over all the classes. * - * @return the precision as computed from the whole confusion matrix + * @return the macro averaged precision as computed from the confusion matrix */ public double getPrecision() { - double tp = 0; - double fp = 0; + double p = 0; for (Map.Entry> classification : linearizedMatrix.entrySet()) { String klass = classification.getKey(); - for (Map.Entry entry : classification.getValue().entrySet()) { - if (klass.equals(entry.getKey())) { - tp += entry.getValue(); - } - } - for (Map values : linearizedMatrix.values()) { - if (values.containsKey(klass)) { - fp += values.get(klass); - } - } + p += getPrecision(klass); } - return tp > 0 ? tp / (tp + fp) : 0; + return p / linearizedMatrix.size(); } /** - * get the recall (see {@link #getRecall(String)}) over all the classes + * get the macro averaged recall (see {@link #getRecall(String)}) over all the classes * - * @return the recall as computed from the whole confusion matrix + * @return the recall as computed from the confusion matrix */ public double getRecall() { - double tp = 0; - double fn = 0; + double r = 0; for (Map.Entry> classification : linearizedMatrix.entrySet()) { String klass = classification.getKey(); - for (Map.Entry entry : classification.getValue().entrySet()) { - if (klass.equals(entry.getKey())) { - tp += entry.getValue(); - } else { - fn += entry.getValue(); - } - } + r += getRecall(klass); } - return tp + fn > 0 ? tp / (tp + fn) : 0; + return r / linearizedMatrix.size(); } @Override diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java index c1c8ad19ee6..fbc57b93eb2 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java @@ -30,7 +30,6 @@ import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.SortedDocValues; -import org.apache.lucene.index.Terms; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.ScoreDoc;