LUCENE-7156 - fixed precision and accuracy calculations

This commit is contained in:
Tommaso Teofili 2016-03-31 14:45:11 +02:00
parent e1b45568b4
commit d08f327a7f
2 changed files with 68 additions and 48 deletions

View File

@ -106,11 +106,11 @@ public class ConfusionMatrixGenerator {
if (aLong != null) {
stringLongMap.put(classified, aLong + 1);
} else {
stringLongMap.put(classified, 1l);
stringLongMap.put(classified, 1L);
}
} else {
stringLongMap = new HashMap<>();
stringLongMap.put(classified, 1l);
stringLongMap.put(classified, 1L);
counts.put(correctAnswer, stringLongMap);
}
@ -225,23 +225,29 @@ public class ConfusionMatrixGenerator {
*/
public double getAccuracy() {
if (this.accuracy == -1) {
double cc = 0d;
double wc = 0d;
for (Map.Entry<String, Map<String, Long>> entry : linearizedMatrix.entrySet()) {
String correctAnswer = entry.getKey();
for (Map.Entry<String, Long> classifiedAnswers : entry.getValue().entrySet()) {
Long value = classifiedAnswers.getValue();
if (value != null) {
if (correctAnswer.equals(classifiedAnswers.getKey())) {
cc += value;
double tp = 0d;
double tn = 0d;
double fp = 0d;
double fn = 0d;
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
String klass = classification.getKey();
for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
if (klass.equals(entry.getKey())) {
tp += entry.getValue();
} else {
wc += value;
fn += entry.getValue();
}
}
for (Map<String, Long> values : linearizedMatrix.values()) {
if (values.containsKey(klass)) {
fp += values.get(klass);
} else {
tn++;
}
}
}
this.accuracy = cc / (cc + wc);
this.accuracy = (tp + tn) / (fp + fn + tp + tn);
}
return this.accuracy;
}
@ -253,7 +259,7 @@ public class ConfusionMatrixGenerator {
*/
public double getPrecision() {
double tp = 0;
double fp = -linearizedMatrix.size();
double fp = 0;
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
String klass = classification.getKey();
for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
@ -268,8 +274,7 @@ public class ConfusionMatrixGenerator {
}
}
return tp + fp > 0 ? tp / (tp + fp) : 0;
return tp > 0 ? tp / (tp + fp) : 0;
}
/**

View File

@ -65,12 +65,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
assertTrue(avgClassificationTime >= 0d );
assertTrue(confusionMatrix.getAccuracy() >= 0d);
assertTrue(confusionMatrix.getAccuracy() <= 1d);
assertTrue(confusionMatrix.getPrecision() >= 0d);
assertTrue(confusionMatrix.getPrecision() <= 1d);
assertTrue(confusionMatrix.getRecall() >= 0d);
assertTrue(confusionMatrix.getRecall() <= 1d);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy >= 0d);
assertTrue(accuracy <= 1d);
double precision = confusionMatrix.getPrecision();
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
} finally {
if (reader != null) {
reader.close();
@ -90,12 +93,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
assertTrue(confusionMatrix.getAccuracy() >= 0d);
assertTrue(confusionMatrix.getAccuracy() <= 1d);
assertTrue(confusionMatrix.getPrecision() >= 0d);
assertTrue(confusionMatrix.getPrecision() <= 1d);
assertTrue(confusionMatrix.getRecall() >= 0d);
assertTrue(confusionMatrix.getRecall() <= 1d);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy >= 0d);
assertTrue(accuracy <= 1d);
double precision = confusionMatrix.getPrecision();
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
} finally {
if (reader != null) {
reader.close();
@ -115,12 +121,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
assertTrue(confusionMatrix.getAccuracy() >= 0d);
assertTrue(confusionMatrix.getAccuracy() <= 1d);
assertTrue(confusionMatrix.getPrecision() >= 0d);
assertTrue(confusionMatrix.getPrecision() <= 1d);
assertTrue(confusionMatrix.getRecall() >= 0d);
assertTrue(confusionMatrix.getRecall() <= 1d);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy >= 0d);
assertTrue(accuracy <= 1d);
double precision = confusionMatrix.getPrecision();
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
} finally {
if (reader != null) {
reader.close();
@ -140,12 +149,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
assertTrue(confusionMatrix.getAccuracy() >= 0d);
assertTrue(confusionMatrix.getAccuracy() <= 1d);
assertTrue(confusionMatrix.getPrecision() >= 0d);
assertTrue(confusionMatrix.getPrecision() <= 1d);
assertTrue(confusionMatrix.getRecall() >= 0d);
assertTrue(confusionMatrix.getRecall() <= 1d);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy >= 0d);
assertTrue(accuracy <= 1d);
double precision = confusionMatrix.getPrecision();
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
} finally {
if (reader != null) {
reader.close();
@ -165,12 +177,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
assertTrue(confusionMatrix.getAccuracy() >= 0d);
assertTrue(confusionMatrix.getAccuracy() <= 1d);
assertTrue(confusionMatrix.getPrecision() >= 0d);
assertTrue(confusionMatrix.getPrecision() <= 1d);
assertTrue(confusionMatrix.getRecall() >= 0d);
assertTrue(confusionMatrix.getRecall() <= 1d);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy >= 0d);
assertTrue(accuracy <= 1d);
double precision = confusionMatrix.getPrecision();
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
assertTrue(confusionMatrix.getPrecision("true") >= 0d);
assertTrue(confusionMatrix.getPrecision("true") <= 1d);
assertTrue(confusionMatrix.getPrecision("false") >= 0d);