From 2507015f4c9a024c65de12a0de263ebe51d3de2a Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Fri, 8 Apr 2016 11:02:48 +0200 Subject: [PATCH] LUCENE-7193 - add generic f1-measure metric to confusion matrix --- .../utils/ConfusionMatrixGenerator.java | 39 ++++++++++++++----- .../utils/ConfusionMatrixGeneratorTest.java | 30 +++++++++++--- 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java index c9ecc4b3e82..3dd8ba83d04 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java @@ -52,16 +52,17 @@ public class ConfusionMatrixGenerator { * get the {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} of a given {@link Classifier}, * generated on the given {@link LeafReader}, class and text fields. * - * @param reader the {@link LeafReader} containing the index used for creating the {@link Classifier} - * @param classifier the {@link Classifier} whose confusion matrix has to be generated - * @param classFieldName the name of the Lucene field used as the classifier's output - * @param textFieldName the nome the Lucene field used as the classifier's input - * @param the return type of the {@link ClassificationResult} returned by the given {@link Classifier} + * @param reader the {@link LeafReader} containing the index used for creating the {@link Classifier} + * @param classifier the {@link Classifier} whose confusion matrix has to be generated + * @param classFieldName the name of the Lucene field used as the classifier's output + * @param textFieldName the nome the Lucene field used as the classifier's input + * @param timeoutMilliseconds timeout to wait before stopping creating the confusion matrix + * @param the return type of the {@link ClassificationResult} returned by the given {@link Classifier} * @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} * @throws IOException if problems occurr while reading the index or using the classifier */ public static ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier classifier, String classFieldName, - String textFieldName) throws IOException { + String textFieldName, long timeoutMilliseconds) throws IOException { ExecutorService executorService = Executors.newFixedThreadPool(1, new NamedThreadFactory("confusion-matrix-gen-")); @@ -72,7 +73,13 @@ public class ConfusionMatrixGenerator { TopDocs topDocs = indexSearcher.search(new TermRangeQuery(classFieldName, null, null, true, true), Integer.MAX_VALUE); double time = 0d; + int counter = 0; for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + + if (timeoutMilliseconds > 0 && time >= timeoutMilliseconds) { + break; + } + Document doc = reader.document(scoreDoc.doc); String[] correctAnswers = doc.getValues(classFieldName); @@ -91,6 +98,7 @@ public class ConfusionMatrixGenerator { if (result != null) { T assignedClass = result.getAssignedClass(); if (assignedClass != null) { + counter++; String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString(); String correctAnswer; @@ -117,7 +125,7 @@ public class ConfusionMatrixGenerator { } } } catch (TimeoutException timeoutException) { - // add timeout + // add classification timeout time += 5000; } catch (ExecutionException | InterruptedException executionException) { throw new RuntimeException(executionException); @@ -126,7 +134,7 @@ public class ConfusionMatrixGenerator { } } } - return new ConfusionMatrix(counts, time / topDocs.totalHits, topDocs.totalHits); + return new ConfusionMatrix(counts, time / counter, counter); } finally { executorService.shutdown(); } @@ -167,7 +175,7 @@ public class ConfusionMatrixGenerator { public double getPrecision(String klass) { Map classifications = linearizedMatrix.get(klass); double tp = 0; - double fp = -1; + double fp = 0; if (classifications != null) { for (Map.Entry entry : classifications.entrySet()) { if (klass.equals(entry.getKey())) { @@ -180,7 +188,7 @@ public class ConfusionMatrixGenerator { } } } - return tp + fp > 0 ? tp / (tp + fp) : 0; + return tp > 0 ? tp / (tp + fp) : 0; } /** @@ -217,6 +225,17 @@ public class ConfusionMatrixGenerator { return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0; } + /** + * get the F-1 measure on this confusion matrix + * + * @return the F-1 measure + */ + public double getF1Measure() { + double recall = getRecall(); + double precision = getPrecision(); + return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0; + } + /** * Calculate accuracy on this confusion matrix using the formula: * {@literal accuracy = correctly-classified / (correctly-classified + wrongly-classified)} diff --git a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java index d1966ec7c81..63cce2a1c8e 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java @@ -59,7 +59,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase return null; } }; - ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName); + ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, + classifier, categoryFieldName, textFieldName, -1); assertNotNull(confusionMatrix); assertNotNull(confusionMatrix.getLinearizedMatrix()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); @@ -74,6 +75,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase double recall = confusionMatrix.getRecall(); assertTrue(recall >= 0d); assertTrue(recall <= 1d); + double f1Measure = confusionMatrix.getF1Measure(); + assertTrue(f1Measure >= 0d); + assertTrue(f1Measure <= 1d); } finally { if (reader != null) { reader.close(); @@ -88,7 +92,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase MockAnalyzer analyzer = new MockAnalyzer(random()); reader = getSampleIndex(analyzer); Classifier classifier = new SimpleNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName); - ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName); + ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, + classifier, categoryFieldName, textFieldName, -1); assertNotNull(confusionMatrix); assertNotNull(confusionMatrix.getLinearizedMatrix()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); @@ -102,6 +107,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase double recall = confusionMatrix.getRecall(); assertTrue(recall >= 0d); assertTrue(recall <= 1d); + double f1Measure = confusionMatrix.getF1Measure(); + assertTrue(f1Measure >= 0d); + assertTrue(f1Measure <= 1d); } finally { if (reader != null) { reader.close(); @@ -116,7 +124,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase MockAnalyzer analyzer = new MockAnalyzer(random()); reader = getSampleIndex(analyzer); Classifier classifier = new CachingNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName); - ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName); + ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, + classifier, categoryFieldName, textFieldName, -1); assertNotNull(confusionMatrix); assertNotNull(confusionMatrix.getLinearizedMatrix()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); @@ -130,6 +139,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase double recall = confusionMatrix.getRecall(); assertTrue(recall >= 0d); assertTrue(recall <= 1d); + double f1Measure = confusionMatrix.getF1Measure(); + assertTrue(f1Measure >= 0d); + assertTrue(f1Measure <= 1d); } finally { if (reader != null) { reader.close(); @@ -144,7 +156,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase MockAnalyzer analyzer = new MockAnalyzer(random()); reader = getSampleIndex(analyzer); Classifier classifier = new KNearestNeighborClassifier(reader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName); - ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName); + ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, + classifier, categoryFieldName, textFieldName, -1); assertNotNull(confusionMatrix); assertNotNull(confusionMatrix.getLinearizedMatrix()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); @@ -158,6 +171,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase double recall = confusionMatrix.getRecall(); assertTrue(recall >= 0d); assertTrue(recall <= 1d); + double f1Measure = confusionMatrix.getF1Measure(); + assertTrue(f1Measure >= 0d); + assertTrue(f1Measure <= 1d); } finally { if (reader != null) { reader.close(); @@ -172,7 +188,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase MockAnalyzer analyzer = new MockAnalyzer(random()); reader = getSampleIndex(analyzer); Classifier classifier = new BooleanPerceptronClassifier(reader, analyzer, null, 1, null, booleanFieldName, textFieldName); - ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, booleanFieldName, textFieldName); + ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, + classifier, booleanFieldName, textFieldName, -1); assertNotNull(confusionMatrix); assertNotNull(confusionMatrix.getLinearizedMatrix()); assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs()); @@ -186,6 +203,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase double recall = confusionMatrix.getRecall(); assertTrue(recall >= 0d); assertTrue(recall <= 1d); + double f1Measure = confusionMatrix.getF1Measure(); + assertTrue(f1Measure >= 0d); + assertTrue(f1Measure <= 1d); assertTrue(confusionMatrix.getPrecision("true") >= 0d); assertTrue(confusionMatrix.getPrecision("true") <= 1d); assertTrue(confusionMatrix.getPrecision("false") >= 0d);