diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java index 17f5b21afc6..c9ecc4b3e82 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java @@ -106,11 +106,11 @@ public class ConfusionMatrixGenerator { if (aLong != null) { stringLongMap.put(classified, aLong + 1); } else { - stringLongMap.put(classified, 1l); + stringLongMap.put(classified, 1L); } } else { stringLongMap = new HashMap<>(); - stringLongMap.put(classified, 1l); + stringLongMap.put(classified, 1L); counts.put(correctAnswer, stringLongMap); } @@ -225,23 +225,29 @@ public class ConfusionMatrixGenerator { */ public double getAccuracy() { if (this.accuracy == -1) { - double cc = 0d; - double wc = 0d; - for (Map.Entry> entry : linearizedMatrix.entrySet()) { - String correctAnswer = entry.getKey(); - for (Map.Entry classifiedAnswers : entry.getValue().entrySet()) { - Long value = classifiedAnswers.getValue(); - if (value != null) { - if (correctAnswer.equals(classifiedAnswers.getKey())) { - cc += value; - } else { - wc += value; - } + double tp = 0d; + double tn = 0d; + double fp = 0d; + double fn = 0d; + for (Map.Entry> classification : linearizedMatrix.entrySet()) { + String klass = classification.getKey(); + for (Map.Entry entry : classification.getValue().entrySet()) { + if (klass.equals(entry.getKey())) { + tp += entry.getValue(); + } else { + fn += entry.getValue(); + } + } + for (Map values : linearizedMatrix.values()) { + if (values.containsKey(klass)) { + fp += values.get(klass); + } else { + tn++; } } } - this.accuracy = cc / (cc + wc); + this.accuracy = (tp + tn) / (fp + fn + tp + tn); } return this.accuracy; } @@ -253,7 +259,7 @@ public class ConfusionMatrixGenerator { */ public double getPrecision() { double tp = 0; - double fp = -linearizedMatrix.size(); + double fp = 0; for (Map.Entry> classification : linearizedMatrix.entrySet()) { String klass = classification.getKey(); for (Map.Entry entry : classification.getValue().entrySet()) { @@ -268,8 +274,7 @@ public class ConfusionMatrixGenerator { } } - return tp + fp > 0 ? tp / (tp + fp) : 0; - + return tp > 0 ? tp / (tp + fp) : 0; } /** diff --git a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java index e582b79d615..d1966ec7c81 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java @@ -65,12 +65,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); double avgClassificationTime = confusionMatrix.getAvgClassificationTime(); assertTrue(avgClassificationTime >= 0d ); - assertTrue(confusionMatrix.getAccuracy() >= 0d); - assertTrue(confusionMatrix.getAccuracy() <= 1d); - assertTrue(confusionMatrix.getPrecision() >= 0d); - assertTrue(confusionMatrix.getPrecision() <= 1d); - assertTrue(confusionMatrix.getRecall() >= 0d); - assertTrue(confusionMatrix.getRecall() <= 1d); + double accuracy = confusionMatrix.getAccuracy(); + assertTrue(accuracy >= 0d); + assertTrue(accuracy <= 1d); + double precision = confusionMatrix.getPrecision(); + assertTrue(precision >= 0d); + assertTrue(precision <= 1d); + double recall = confusionMatrix.getRecall(); + assertTrue(recall >= 0d); + assertTrue(recall <= 1d); } finally { if (reader != null) { reader.close(); @@ -90,12 +93,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase assertNotNull(confusionMatrix.getLinearizedMatrix()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d); - assertTrue(confusionMatrix.getAccuracy() >= 0d); - assertTrue(confusionMatrix.getAccuracy() <= 1d); - assertTrue(confusionMatrix.getPrecision() >= 0d); - assertTrue(confusionMatrix.getPrecision() <= 1d); - assertTrue(confusionMatrix.getRecall() >= 0d); - assertTrue(confusionMatrix.getRecall() <= 1d); + double accuracy = confusionMatrix.getAccuracy(); + assertTrue(accuracy >= 0d); + assertTrue(accuracy <= 1d); + double precision = confusionMatrix.getPrecision(); + assertTrue(precision >= 0d); + assertTrue(precision <= 1d); + double recall = confusionMatrix.getRecall(); + assertTrue(recall >= 0d); + assertTrue(recall <= 1d); } finally { if (reader != null) { reader.close(); @@ -115,12 +121,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase assertNotNull(confusionMatrix.getLinearizedMatrix()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d); - assertTrue(confusionMatrix.getAccuracy() >= 0d); - assertTrue(confusionMatrix.getAccuracy() <= 1d); - assertTrue(confusionMatrix.getPrecision() >= 0d); - assertTrue(confusionMatrix.getPrecision() <= 1d); - assertTrue(confusionMatrix.getRecall() >= 0d); - assertTrue(confusionMatrix.getRecall() <= 1d); + double accuracy = confusionMatrix.getAccuracy(); + assertTrue(accuracy >= 0d); + assertTrue(accuracy <= 1d); + double precision = confusionMatrix.getPrecision(); + assertTrue(precision >= 0d); + assertTrue(precision <= 1d); + double recall = confusionMatrix.getRecall(); + assertTrue(recall >= 0d); + assertTrue(recall <= 1d); } finally { if (reader != null) { reader.close(); @@ -140,12 +149,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase assertNotNull(confusionMatrix.getLinearizedMatrix()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d); - assertTrue(confusionMatrix.getAccuracy() >= 0d); - assertTrue(confusionMatrix.getAccuracy() <= 1d); - assertTrue(confusionMatrix.getPrecision() >= 0d); - assertTrue(confusionMatrix.getPrecision() <= 1d); - assertTrue(confusionMatrix.getRecall() >= 0d); - assertTrue(confusionMatrix.getRecall() <= 1d); + double accuracy = confusionMatrix.getAccuracy(); + assertTrue(accuracy >= 0d); + assertTrue(accuracy <= 1d); + double precision = confusionMatrix.getPrecision(); + assertTrue(precision >= 0d); + assertTrue(precision <= 1d); + double recall = confusionMatrix.getRecall(); + assertTrue(recall >= 0d); + assertTrue(recall <= 1d); } finally { if (reader != null) { reader.close(); @@ -165,12 +177,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase assertNotNull(confusionMatrix.getLinearizedMatrix()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d); - assertTrue(confusionMatrix.getAccuracy() >= 0d); - assertTrue(confusionMatrix.getAccuracy() <= 1d); - assertTrue(confusionMatrix.getPrecision() >= 0d); - assertTrue(confusionMatrix.getPrecision() <= 1d); - assertTrue(confusionMatrix.getRecall() >= 0d); - assertTrue(confusionMatrix.getRecall() <= 1d); + double accuracy = confusionMatrix.getAccuracy(); + assertTrue(accuracy >= 0d); + assertTrue(accuracy <= 1d); + double precision = confusionMatrix.getPrecision(); + assertTrue(precision >= 0d); + assertTrue(precision <= 1d); + double recall = confusionMatrix.getRecall(); + assertTrue(recall >= 0d); + assertTrue(recall <= 1d); assertTrue(confusionMatrix.getPrecision("true") >= 0d); assertTrue(confusionMatrix.getPrecision("true") <= 1d); assertTrue(confusionMatrix.getPrecision("false") >= 0d);