LUCENE-7305 - use macro average in confusion matrix metrics, removed unused import in datasplitter

This commit is contained in:
Tommaso Teofili 2016-05-26 16:13:46 +02:00
parent 8808cf5373
commit dc50b79a14
2 changed files with 16 additions and 34 deletions

View File

@ -175,7 +175,7 @@ public class ConfusionMatrixGenerator {
public double getPrecision(String klass) { public double getPrecision(String klass) {
Map<String, Long> classifications = linearizedMatrix.get(klass); Map<String, Long> classifications = linearizedMatrix.get(klass);
double tp = 0; double tp = 0;
double fp = 0; double den = 0; // tp + fp
if (classifications != null) { if (classifications != null) {
for (Map.Entry<String, Long> entry : classifications.entrySet()) { for (Map.Entry<String, Long> entry : classifications.entrySet()) {
if (klass.equals(entry.getKey())) { if (klass.equals(entry.getKey())) {
@ -184,11 +184,11 @@ public class ConfusionMatrixGenerator {
} }
for (Map<String, Long> values : linearizedMatrix.values()) { for (Map<String, Long> values : linearizedMatrix.values()) {
if (values.containsKey(klass)) { if (values.containsKey(klass)) {
fp += values.get(klass); den += values.get(klass);
} }
} }
} }
return tp > 0 ? tp / (tp + fp) : 0; return tp > 0 ? tp / den : 0;
} }
/** /**
@ -246,7 +246,7 @@ public class ConfusionMatrixGenerator {
if (this.accuracy == -1) { if (this.accuracy == -1) {
double tp = 0d; double tp = 0d;
double tn = 0d; double tn = 0d;
double fp = 0d; double tfp = 0d; // tp + fp
double fn = 0d; double fn = 0d;
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) { for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
String klass = classification.getKey(); String klass = classification.getKey();
@ -259,63 +259,46 @@ public class ConfusionMatrixGenerator {
} }
for (Map<String, Long> values : linearizedMatrix.values()) { for (Map<String, Long> values : linearizedMatrix.values()) {
if (values.containsKey(klass)) { if (values.containsKey(klass)) {
fp += values.get(klass); tfp += values.get(klass);
} else { } else {
tn++; tn++;
} }
} }
} }
this.accuracy = (tp + tn) / (fp + fn + tp + tn); this.accuracy = (tp + tn) / (tfp + fn + tn);
} }
return this.accuracy; return this.accuracy;
} }
/** /**
* get the precision (see {@link #getPrecision(String)}) over all the classes. * get the macro averaged precision (see {@link #getPrecision(String)}) over all the classes.
* *
* @return the precision as computed from the whole confusion matrix * @return the macro averaged precision as computed from the confusion matrix
*/ */
public double getPrecision() { public double getPrecision() {
double tp = 0; double p = 0;
double fp = 0;
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) { for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
String klass = classification.getKey(); String klass = classification.getKey();
for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) { p += getPrecision(klass);
if (klass.equals(entry.getKey())) {
tp += entry.getValue();
}
}
for (Map<String, Long> values : linearizedMatrix.values()) {
if (values.containsKey(klass)) {
fp += values.get(klass);
}
}
} }
return tp > 0 ? tp / (tp + fp) : 0; return p / linearizedMatrix.size();
} }
/** /**
* get the recall (see {@link #getRecall(String)}) over all the classes * get the macro averaged recall (see {@link #getRecall(String)}) over all the classes
* *
* @return the recall as computed from the whole confusion matrix * @return the recall as computed from the confusion matrix
*/ */
public double getRecall() { public double getRecall() {
double tp = 0; double r = 0;
double fn = 0;
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) { for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
String klass = classification.getKey(); String klass = classification.getKey();
for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) { r += getRecall(klass);
if (klass.equals(entry.getKey())) {
tp += entry.getValue();
} else {
fn += entry.getValue();
}
}
} }
return tp + fn > 0 ? tp / (tp + fn) : 0; return r / linearizedMatrix.size();
} }
@Override @Override

View File

@ -30,7 +30,6 @@ import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.Terms;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;