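LUCENE-6854 - adjusted precision, recall, and F1 measure metrics to return 0 instead of dividing by zero, and to tolerate classes that have no row in the confusion matrix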

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1710605 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2015-10-26 14:21:18 +00:00
parent 8e929f3468
commit 0deadf0671
3 changed files with 22 additions and 18 deletions

View File

@ -238,10 +238,11 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
@Override @Override
public String toString() { public String toString() {
return "KNearestNeighborClassifier{" + return "KNearestNeighborClassifier{" +
", textFieldNames=" + Arrays.toString(textFieldNames) + "textFieldNames=" + Arrays.toString(textFieldNames) +
", classFieldName='" + classFieldName + '\'' + ", classFieldName='" + classFieldName + '\'' +
", k=" + k + ", k=" + k +
", query=" + query + ", query=" + query +
", similarity=" + indexSearcher.getSimilarity(true) +
'}'; '}';
} }
} }

View File

@ -151,7 +151,6 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
int docsWithClassSize = countDocsWithClass(); int docsWithClassSize = countDocsWithClass();
while ((next = classesEnum.next()) != null) { while ((next = classesEnum.next()) != null) {
if (next.length > 0) { if (next.length > 0) {
// We are passing the term to IndexSearcher so we need to make sure it will not change over time
Term term = new Term(this.classFieldName, next); Term term = new Term(this.classFieldName, next);
double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize); double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal)); assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));

View File

@ -168,17 +168,19 @@ public class ConfusionMatrixGenerator {
Map<String, Long> classifications = linearizedMatrix.get(klass); Map<String, Long> classifications = linearizedMatrix.get(klass);
double tp = 0; double tp = 0;
double fp = 0; double fp = 0;
for (Map.Entry<String, Long> entry : classifications.entrySet()) { if (classifications != null) {
if (klass.equals(entry.getKey())) { for (Map.Entry<String, Long> entry : classifications.entrySet()) {
tp += entry.getValue(); if (klass.equals(entry.getKey())) {
tp += entry.getValue();
}
}
for (Map<String, Long> values : linearizedMatrix.values()) {
if (values.containsKey(klass)) {
fp += values.get(klass);
}
} }
} }
for (Map<String, Long> values : linearizedMatrix.values()) { return tp + fp > 0 ? tp / (tp + fp) : 0;
if (values.containsKey(klass)) {
fp += values.get(klass);
}
}
return tp / (tp + fp);
} }
/** /**
@ -191,14 +193,16 @@ public class ConfusionMatrixGenerator {
Map<String, Long> classifications = linearizedMatrix.get(klass); Map<String, Long> classifications = linearizedMatrix.get(klass);
double tp = 0; double tp = 0;
double fn = 0; double fn = 0;
for (Map.Entry<String, Long> entry : classifications.entrySet()) { if (classifications != null) {
if (klass.equals(entry.getKey())) { for (Map.Entry<String, Long> entry : classifications.entrySet()) {
tp += entry.getValue(); if (klass.equals(entry.getKey())) {
} else { tp += entry.getValue();
fn += entry.getValue(); } else {
fn += entry.getValue();
}
} }
} }
return tp / (tp + fn); return tp + fn > 0 ? tp / (tp + fn) : 0;
} }
/** /**
@ -210,7 +214,7 @@ public class ConfusionMatrixGenerator {
public double getF1Measure(String klass) { public double getF1Measure(String klass) {
double recall = getRecall(klass); double recall = getRecall(klass);
double precision = getPrecision(klass); double precision = getPrecision(klass);
return 2 * precision * recall / (precision + recall); return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0;
} }
/** /**