mirror of https://github.com/apache/lucene.git
LUCENE-7156 - fixed precision and accuracy calculations
This commit is contained in:
parent
e1b45568b4
commit
d08f327a7f
|
@ -106,11 +106,11 @@ public class ConfusionMatrixGenerator {
|
|||
if (aLong != null) {
|
||||
stringLongMap.put(classified, aLong + 1);
|
||||
} else {
|
||||
stringLongMap.put(classified, 1l);
|
||||
stringLongMap.put(classified, 1L);
|
||||
}
|
||||
} else {
|
||||
stringLongMap = new HashMap<>();
|
||||
stringLongMap.put(classified, 1l);
|
||||
stringLongMap.put(classified, 1L);
|
||||
counts.put(correctAnswer, stringLongMap);
|
||||
}
|
||||
|
||||
|
@ -225,23 +225,29 @@ public class ConfusionMatrixGenerator {
|
|||
*/
|
||||
public double getAccuracy() {
|
||||
if (this.accuracy == -1) {
|
||||
double cc = 0d;
|
||||
double wc = 0d;
|
||||
for (Map.Entry<String, Map<String, Long>> entry : linearizedMatrix.entrySet()) {
|
||||
String correctAnswer = entry.getKey();
|
||||
for (Map.Entry<String, Long> classifiedAnswers : entry.getValue().entrySet()) {
|
||||
Long value = classifiedAnswers.getValue();
|
||||
if (value != null) {
|
||||
if (correctAnswer.equals(classifiedAnswers.getKey())) {
|
||||
cc += value;
|
||||
double tp = 0d;
|
||||
double tn = 0d;
|
||||
double fp = 0d;
|
||||
double fn = 0d;
|
||||
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
|
||||
String klass = classification.getKey();
|
||||
for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
|
||||
if (klass.equals(entry.getKey())) {
|
||||
tp += entry.getValue();
|
||||
} else {
|
||||
wc += value;
|
||||
fn += entry.getValue();
|
||||
}
|
||||
}
|
||||
for (Map<String, Long> values : linearizedMatrix.values()) {
|
||||
if (values.containsKey(klass)) {
|
||||
fp += values.get(klass);
|
||||
} else {
|
||||
tn++;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
this.accuracy = cc / (cc + wc);
|
||||
this.accuracy = (tp + tn) / (fp + fn + tp + tn);
|
||||
}
|
||||
return this.accuracy;
|
||||
}
|
||||
|
@ -253,7 +259,7 @@ public class ConfusionMatrixGenerator {
|
|||
*/
|
||||
public double getPrecision() {
|
||||
double tp = 0;
|
||||
double fp = -linearizedMatrix.size();
|
||||
double fp = 0;
|
||||
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
|
||||
String klass = classification.getKey();
|
||||
for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
|
||||
|
@ -268,8 +274,7 @@ public class ConfusionMatrixGenerator {
|
|||
}
|
||||
}
|
||||
|
||||
return tp + fp > 0 ? tp / (tp + fp) : 0;
|
||||
|
||||
return tp > 0 ? tp / (tp + fp) : 0;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -65,12 +65,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
|
||||
assertTrue(avgClassificationTime >= 0d );
|
||||
assertTrue(confusionMatrix.getAccuracy() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision() >= 0d);
|
||||
assertTrue(confusionMatrix.getPrecision() <= 1d);
|
||||
assertTrue(confusionMatrix.getRecall() >= 0d);
|
||||
assertTrue(confusionMatrix.getRecall() <= 1d);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy >= 0d);
|
||||
assertTrue(accuracy <= 1d);
|
||||
double precision = confusionMatrix.getPrecision();
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -90,12 +93,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision() >= 0d);
|
||||
assertTrue(confusionMatrix.getPrecision() <= 1d);
|
||||
assertTrue(confusionMatrix.getRecall() >= 0d);
|
||||
assertTrue(confusionMatrix.getRecall() <= 1d);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy >= 0d);
|
||||
assertTrue(accuracy <= 1d);
|
||||
double precision = confusionMatrix.getPrecision();
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -115,12 +121,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision() >= 0d);
|
||||
assertTrue(confusionMatrix.getPrecision() <= 1d);
|
||||
assertTrue(confusionMatrix.getRecall() >= 0d);
|
||||
assertTrue(confusionMatrix.getRecall() <= 1d);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy >= 0d);
|
||||
assertTrue(accuracy <= 1d);
|
||||
double precision = confusionMatrix.getPrecision();
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -140,12 +149,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision() >= 0d);
|
||||
assertTrue(confusionMatrix.getPrecision() <= 1d);
|
||||
assertTrue(confusionMatrix.getRecall() >= 0d);
|
||||
assertTrue(confusionMatrix.getRecall() <= 1d);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy >= 0d);
|
||||
assertTrue(accuracy <= 1d);
|
||||
double precision = confusionMatrix.getPrecision();
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -165,12 +177,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision() >= 0d);
|
||||
assertTrue(confusionMatrix.getPrecision() <= 1d);
|
||||
assertTrue(confusionMatrix.getRecall() >= 0d);
|
||||
assertTrue(confusionMatrix.getRecall() <= 1d);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy >= 0d);
|
||||
assertTrue(accuracy <= 1d);
|
||||
double precision = confusionMatrix.getPrecision();
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision("true") >= 0d);
|
||||
assertTrue(confusionMatrix.getPrecision("true") <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision("false") >= 0d);
|
||||
|
|
Loading…
Reference in New Issue