mirror of https://github.com/apache/lucene.git
LUCENE-7305 - use macro average in confusion matrix metrics, removed unused import in datasplitter
This commit is contained in:
parent
8808cf5373
commit
dc50b79a14
|
@ -175,7 +175,7 @@ public class ConfusionMatrixGenerator {
|
||||||
public double getPrecision(String klass) {
|
public double getPrecision(String klass) {
|
||||||
Map<String, Long> classifications = linearizedMatrix.get(klass);
|
Map<String, Long> classifications = linearizedMatrix.get(klass);
|
||||||
double tp = 0;
|
double tp = 0;
|
||||||
double fp = 0;
|
double den = 0; // tp + fp
|
||||||
if (classifications != null) {
|
if (classifications != null) {
|
||||||
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
|
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
|
||||||
if (klass.equals(entry.getKey())) {
|
if (klass.equals(entry.getKey())) {
|
||||||
|
@ -184,11 +184,11 @@ public class ConfusionMatrixGenerator {
|
||||||
}
|
}
|
||||||
for (Map<String, Long> values : linearizedMatrix.values()) {
|
for (Map<String, Long> values : linearizedMatrix.values()) {
|
||||||
if (values.containsKey(klass)) {
|
if (values.containsKey(klass)) {
|
||||||
fp += values.get(klass);
|
den += values.get(klass);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return tp > 0 ? tp / (tp + fp) : 0;
|
return tp > 0 ? tp / den : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -246,7 +246,7 @@ public class ConfusionMatrixGenerator {
|
||||||
if (this.accuracy == -1) {
|
if (this.accuracy == -1) {
|
||||||
double tp = 0d;
|
double tp = 0d;
|
||||||
double tn = 0d;
|
double tn = 0d;
|
||||||
double fp = 0d;
|
double tfp = 0d; // tp + fp
|
||||||
double fn = 0d;
|
double fn = 0d;
|
||||||
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
|
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
|
||||||
String klass = classification.getKey();
|
String klass = classification.getKey();
|
||||||
|
@ -259,63 +259,46 @@ public class ConfusionMatrixGenerator {
|
||||||
}
|
}
|
||||||
for (Map<String, Long> values : linearizedMatrix.values()) {
|
for (Map<String, Long> values : linearizedMatrix.values()) {
|
||||||
if (values.containsKey(klass)) {
|
if (values.containsKey(klass)) {
|
||||||
fp += values.get(klass);
|
tfp += values.get(klass);
|
||||||
} else {
|
} else {
|
||||||
tn++;
|
tn++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
this.accuracy = (tp + tn) / (fp + fn + tp + tn);
|
this.accuracy = (tp + tn) / (tfp + fn + tn);
|
||||||
}
|
}
|
||||||
return this.accuracy;
|
return this.accuracy;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get the precision (see {@link #getPrecision(String)}) over all the classes.
|
* get the macro averaged precision (see {@link #getPrecision(String)}) over all the classes.
|
||||||
*
|
*
|
||||||
* @return the precision as computed from the whole confusion matrix
|
* @return the macro averaged precision as computed from the confusion matrix
|
||||||
*/
|
*/
|
||||||
public double getPrecision() {
|
public double getPrecision() {
|
||||||
double tp = 0;
|
double p = 0;
|
||||||
double fp = 0;
|
|
||||||
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
|
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
|
||||||
String klass = classification.getKey();
|
String klass = classification.getKey();
|
||||||
for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
|
p += getPrecision(klass);
|
||||||
if (klass.equals(entry.getKey())) {
|
|
||||||
tp += entry.getValue();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (Map<String, Long> values : linearizedMatrix.values()) {
|
|
||||||
if (values.containsKey(klass)) {
|
|
||||||
fp += values.get(klass);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return tp > 0 ? tp / (tp + fp) : 0;
|
return p / linearizedMatrix.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get the recall (see {@link #getRecall(String)}) over all the classes
|
* get the macro averaged recall (see {@link #getRecall(String)}) over all the classes
|
||||||
*
|
*
|
||||||
* @return the recall as computed from the whole confusion matrix
|
* @return the recall as computed from the confusion matrix
|
||||||
*/
|
*/
|
||||||
public double getRecall() {
|
public double getRecall() {
|
||||||
double tp = 0;
|
double r = 0;
|
||||||
double fn = 0;
|
|
||||||
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
|
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
|
||||||
String klass = classification.getKey();
|
String klass = classification.getKey();
|
||||||
for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
|
r += getRecall(klass);
|
||||||
if (klass.equals(entry.getKey())) {
|
|
||||||
tp += entry.getValue();
|
|
||||||
} else {
|
|
||||||
fn += entry.getValue();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return tp + fn > 0 ? tp / (tp + fn) : 0;
|
return r / linearizedMatrix.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -30,7 +30,6 @@ import org.apache.lucene.index.IndexWriterConfig;
|
||||||
import org.apache.lucene.index.IndexableField;
|
import org.apache.lucene.index.IndexableField;
|
||||||
import org.apache.lucene.index.LeafReader;
|
import org.apache.lucene.index.LeafReader;
|
||||||
import org.apache.lucene.index.SortedDocValues;
|
import org.apache.lucene.index.SortedDocValues;
|
||||||
import org.apache.lucene.index.Terms;
|
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
|
|
Loading…
Reference in New Issue