LUCENE-6854 - adjusted precision, recall, f1 measure metrics

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1710605 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2015-10-26 14:21:18 +00:00
parent 8e929f3468
commit 0deadf0671
3 changed files with 22 additions and 18 deletions

View File

@ -238,10 +238,11 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
@Override
public String toString() {
  // Human-readable summary of this classifier's configuration.
  // The first field is written without a leading ", " so the text directly after the
  // opening brace is well-formed, and textFieldNames appears exactly once.
  return "KNearestNeighborClassifier{" +
      "textFieldNames=" + Arrays.toString(textFieldNames) +
      ", classFieldName='" + classFieldName + '\'' +
      ", k=" + k +
      ", query=" + query +
      ", similarity=" + indexSearcher.getSimilarity(true) +
      '}';
}
}

View File

@ -151,7 +151,6 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
int docsWithClassSize = countDocsWithClass();
while ((next = classesEnum.next()) != null) {
if (next.length > 0) {
// We are passing the term to IndexSearcher so we need to make sure it will not change over time
Term term = new Term(this.classFieldName, next);
double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));

View File

@ -168,17 +168,19 @@ public class ConfusionMatrixGenerator {
Map<String, Long> classifications = linearizedMatrix.get(klass);
double tp = 0;
double fp = 0;
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
if (klass.equals(entry.getKey())) {
tp += entry.getValue();
if (classifications != null) {
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
if (klass.equals(entry.getKey())) {
tp += entry.getValue();
}
}
for (Map<String, Long> values : linearizedMatrix.values()) {
if (values.containsKey(klass)) {
fp += values.get(klass);
}
}
}
for (Map<String, Long> values : linearizedMatrix.values()) {
if (values.containsKey(klass)) {
fp += values.get(klass);
}
}
return tp / (tp + fp);
return tp + fp > 0 ? tp / (tp + fp) : 0;
}
/**
@ -191,14 +193,16 @@ public class ConfusionMatrixGenerator {
Map<String, Long> classifications = linearizedMatrix.get(klass);
double tp = 0;
double fn = 0;
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
if (klass.equals(entry.getKey())) {
tp += entry.getValue();
} else {
fn += entry.getValue();
if (classifications != null) {
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
if (klass.equals(entry.getKey())) {
tp += entry.getValue();
} else {
fn += entry.getValue();
}
}
}
return tp / (tp + fn);
return tp + fn > 0 ? tp / (tp + fn) : 0;
}
/**
@ -210,7 +214,7 @@ public class ConfusionMatrixGenerator {
public double getF1Measure(String klass) {
  // Combines the per-class recall and precision as 2 * p * r / (p + r).
  double recall = getRecall(klass);
  double precision = getPrecision(klass);
  // Divide only when both values are strictly positive: with precision == recall == 0 the
  // expression would be 0 / 0, which is NaN for doubles. In that case 0 is returned instead.
  // There is a single return statement so no code after it is unreachable.
  return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0;
}
/**