LUCENE-7156 - fixed precision and accuracy calculations

This commit is contained in:
Tommaso Teofili 2016-03-31 14:45:11 +02:00
parent e1b45568b4
commit d08f327a7f
2 changed files with 68 additions and 48 deletions

View File

@@ -106,11 +106,11 @@ public class ConfusionMatrixGenerator {
if (aLong != null) { if (aLong != null) {
stringLongMap.put(classified, aLong + 1); stringLongMap.put(classified, aLong + 1);
} else { } else {
stringLongMap.put(classified, 1l); stringLongMap.put(classified, 1L);
} }
} else { } else {
stringLongMap = new HashMap<>(); stringLongMap = new HashMap<>();
stringLongMap.put(classified, 1l); stringLongMap.put(classified, 1L);
counts.put(correctAnswer, stringLongMap); counts.put(correctAnswer, stringLongMap);
} }
@@ -225,23 +225,29 @@ public class ConfusionMatrixGenerator {
*/ */
public double getAccuracy() { public double getAccuracy() {
if (this.accuracy == -1) { if (this.accuracy == -1) {
double cc = 0d; double tp = 0d;
double wc = 0d; double tn = 0d;
for (Map.Entry<String, Map<String, Long>> entry : linearizedMatrix.entrySet()) { double fp = 0d;
String correctAnswer = entry.getKey(); double fn = 0d;
for (Map.Entry<String, Long> classifiedAnswers : entry.getValue().entrySet()) { for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
Long value = classifiedAnswers.getValue(); String klass = classification.getKey();
if (value != null) { for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
if (correctAnswer.equals(classifiedAnswers.getKey())) { if (klass.equals(entry.getKey())) {
cc += value; tp += entry.getValue();
} else { } else {
wc += value; fn += entry.getValue();
} }
} }
for (Map<String, Long> values : linearizedMatrix.values()) {
if (values.containsKey(klass)) {
fp += values.get(klass);
} else {
tn++;
}
} }
} }
this.accuracy = cc / (cc + wc); this.accuracy = (tp + tn) / (fp + fn + tp + tn);
} }
return this.accuracy; return this.accuracy;
} }
@@ -253,7 +259,7 @@ public class ConfusionMatrixGenerator {
*/ */
public double getPrecision() { public double getPrecision() {
double tp = 0; double tp = 0;
double fp = -linearizedMatrix.size(); double fp = 0;
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) { for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
String klass = classification.getKey(); String klass = classification.getKey();
for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) { for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
@@ -268,8 +274,7 @@ public class ConfusionMatrixGenerator {
} }
} }
return tp + fp > 0 ? tp / (tp + fp) : 0; return tp > 0 ? tp / (tp + fp) : 0;
} }
/** /**

View File

@@ -65,12 +65,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
double avgClassificationTime = confusionMatrix.getAvgClassificationTime(); double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
assertTrue(avgClassificationTime >= 0d ); assertTrue(avgClassificationTime >= 0d );
assertTrue(confusionMatrix.getAccuracy() >= 0d); double accuracy = confusionMatrix.getAccuracy();
assertTrue(confusionMatrix.getAccuracy() <= 1d); assertTrue(accuracy >= 0d);
assertTrue(confusionMatrix.getPrecision() >= 0d); assertTrue(accuracy <= 1d);
assertTrue(confusionMatrix.getPrecision() <= 1d); double precision = confusionMatrix.getPrecision();
assertTrue(confusionMatrix.getRecall() >= 0d); assertTrue(precision >= 0d);
assertTrue(confusionMatrix.getRecall() <= 1d); assertTrue(precision <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
} finally { } finally {
if (reader != null) { if (reader != null) {
reader.close(); reader.close();
@@ -90,12 +93,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix()); assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d); assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
assertTrue(confusionMatrix.getAccuracy() >= 0d); double accuracy = confusionMatrix.getAccuracy();
assertTrue(confusionMatrix.getAccuracy() <= 1d); assertTrue(accuracy >= 0d);
assertTrue(confusionMatrix.getPrecision() >= 0d); assertTrue(accuracy <= 1d);
assertTrue(confusionMatrix.getPrecision() <= 1d); double precision = confusionMatrix.getPrecision();
assertTrue(confusionMatrix.getRecall() >= 0d); assertTrue(precision >= 0d);
assertTrue(confusionMatrix.getRecall() <= 1d); assertTrue(precision <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
} finally { } finally {
if (reader != null) { if (reader != null) {
reader.close(); reader.close();
@@ -115,12 +121,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix()); assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d); assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
assertTrue(confusionMatrix.getAccuracy() >= 0d); double accuracy = confusionMatrix.getAccuracy();
assertTrue(confusionMatrix.getAccuracy() <= 1d); assertTrue(accuracy >= 0d);
assertTrue(confusionMatrix.getPrecision() >= 0d); assertTrue(accuracy <= 1d);
assertTrue(confusionMatrix.getPrecision() <= 1d); double precision = confusionMatrix.getPrecision();
assertTrue(confusionMatrix.getRecall() >= 0d); assertTrue(precision >= 0d);
assertTrue(confusionMatrix.getRecall() <= 1d); assertTrue(precision <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
} finally { } finally {
if (reader != null) { if (reader != null) {
reader.close(); reader.close();
@@ -140,12 +149,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix()); assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d); assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
assertTrue(confusionMatrix.getAccuracy() >= 0d); double accuracy = confusionMatrix.getAccuracy();
assertTrue(confusionMatrix.getAccuracy() <= 1d); assertTrue(accuracy >= 0d);
assertTrue(confusionMatrix.getPrecision() >= 0d); assertTrue(accuracy <= 1d);
assertTrue(confusionMatrix.getPrecision() <= 1d); double precision = confusionMatrix.getPrecision();
assertTrue(confusionMatrix.getRecall() >= 0d); assertTrue(precision >= 0d);
assertTrue(confusionMatrix.getRecall() <= 1d); assertTrue(precision <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
} finally { } finally {
if (reader != null) { if (reader != null) {
reader.close(); reader.close();
@@ -165,12 +177,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix()); assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d); assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
assertTrue(confusionMatrix.getAccuracy() >= 0d); double accuracy = confusionMatrix.getAccuracy();
assertTrue(confusionMatrix.getAccuracy() <= 1d); assertTrue(accuracy >= 0d);
assertTrue(confusionMatrix.getPrecision() >= 0d); assertTrue(accuracy <= 1d);
assertTrue(confusionMatrix.getPrecision() <= 1d); double precision = confusionMatrix.getPrecision();
assertTrue(confusionMatrix.getRecall() >= 0d); assertTrue(precision >= 0d);
assertTrue(confusionMatrix.getRecall() <= 1d); assertTrue(precision <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
assertTrue(confusionMatrix.getPrecision("true") >= 0d); assertTrue(confusionMatrix.getPrecision("true") >= 0d);
assertTrue(confusionMatrix.getPrecision("true") <= 1d); assertTrue(confusionMatrix.getPrecision("true") <= 1d);
assertTrue(confusionMatrix.getPrecision("false") >= 0d); assertTrue(confusionMatrix.getPrecision("false") >= 0d);