From 0deadf0671f826a7eef40a9dab97679b9ee3f33e Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Mon, 26 Oct 2015 14:21:18 +0000 Subject: [PATCH] LUCENE-6854 - adjusted precision, recall, f1 measure metrics git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1710605 13f79535-47bb-0310-9956-ffa450edef68 --- .../KNearestNeighborClassifier.java | 3 +- .../SimpleNaiveBayesClassifier.java | 1 - .../utils/ConfusionMatrixGenerator.java | 36 ++++++++++--------- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java index 61707228b88..bb95726b45c 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java @@ -238,10 +238,11 @@ public class KNearestNeighborClassifier implements Classifier { @Override public String toString() { return "KNearestNeighborClassifier{" + - ", textFieldNames=" + Arrays.toString(textFieldNames) + + "textFieldNames=" + Arrays.toString(textFieldNames) + ", classFieldName='" + classFieldName + '\'' + ", k=" + k + ", query=" + query + + ", similarity=" + indexSearcher.getSimilarity(true) + '}'; } } \ No newline at end of file diff --git a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java index 6f7da1afb9f..73c90de5b1f 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java @@ -151,7 +151,6 @@ public class SimpleNaiveBayesClassifier implements Classifier { int docsWithClassSize = countDocsWithClass(); while ((next = classesEnum.next()) != null) { if (next.length > 0) { - // We are passing the term to IndexSearcher so we need to make sure it will not change over time Term term = new Term(this.classFieldName, next); double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize); assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal)); diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java index 78ad691ac1f..a6ccba9da27 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java @@ -168,17 +168,19 @@ public class ConfusionMatrixGenerator { Map classifications = linearizedMatrix.get(klass); double tp = 0; double fp = 0; - for (Map.Entry entry : classifications.entrySet()) { - if (klass.equals(entry.getKey())) { - tp += entry.getValue(); + if (classifications != null) { + for (Map.Entry entry : classifications.entrySet()) { + if (klass.equals(entry.getKey())) { + tp += entry.getValue(); + } + } + for (Map values : linearizedMatrix.values()) { + if (values.containsKey(klass)) { + fp += values.get(klass); + } } } - for (Map values : linearizedMatrix.values()) { - if (values.containsKey(klass)) { - fp += values.get(klass); - } - } - return tp / (tp + fp); + return tp + fp > 0 ? tp / (tp + fp) : 0; } /** @@ -191,14 +193,16 @@ public class ConfusionMatrixGenerator { Map classifications = linearizedMatrix.get(klass); double tp = 0; double fn = 0; - for (Map.Entry entry : classifications.entrySet()) { - if (klass.equals(entry.getKey())) { - tp += entry.getValue(); - } else { - fn += entry.getValue(); + if (classifications != null) { + for (Map.Entry entry : classifications.entrySet()) { + if (klass.equals(entry.getKey())) { + tp += entry.getValue(); + } else { + fn += entry.getValue(); + } } } - return tp / (tp + fn); + return tp + fn > 0 ? tp / (tp + fn) : 0; } /** @@ -210,7 +214,7 @@ public class ConfusionMatrixGenerator { public double getF1Measure(String klass) { double recall = getRecall(klass); double precision = getPrecision(klass); - return 2 * precision * recall / (precision + recall); + return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0; } /**