mirror of https://github.com/apache/lucene.git
LUCENE-6854 - adjusted precision, recall, f1 measure metrics
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1710605 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
8e929f3468
commit
0deadf0671
|
@ -238,10 +238,11 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
@Override
|
||||
public String toString() {
|
||||
return "KNearestNeighborClassifier{" +
|
||||
", textFieldNames=" + Arrays.toString(textFieldNames) +
|
||||
"textFieldNames=" + Arrays.toString(textFieldNames) +
|
||||
", classFieldName='" + classFieldName + '\'' +
|
||||
", k=" + k +
|
||||
", query=" + query +
|
||||
", similarity=" + indexSearcher.getSimilarity(true) +
|
||||
'}';
|
||||
}
|
||||
}
|
|
@ -151,7 +151,6 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
int docsWithClassSize = countDocsWithClass();
|
||||
while ((next = classesEnum.next()) != null) {
|
||||
if (next.length > 0) {
|
||||
// We are passing the term to IndexSearcher so we need to make sure it will not change over time
|
||||
Term term = new Term(this.classFieldName, next);
|
||||
double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
|
||||
assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));
|
||||
|
|
|
@ -168,17 +168,19 @@ public class ConfusionMatrixGenerator {
|
|||
Map<String, Long> classifications = linearizedMatrix.get(klass);
|
||||
double tp = 0;
|
||||
double fp = 0;
|
||||
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
|
||||
if (klass.equals(entry.getKey())) {
|
||||
tp += entry.getValue();
|
||||
if (classifications != null) {
|
||||
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
|
||||
if (klass.equals(entry.getKey())) {
|
||||
tp += entry.getValue();
|
||||
}
|
||||
}
|
||||
for (Map<String, Long> values : linearizedMatrix.values()) {
|
||||
if (values.containsKey(klass)) {
|
||||
fp += values.get(klass);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (Map<String, Long> values : linearizedMatrix.values()) {
|
||||
if (values.containsKey(klass)) {
|
||||
fp += values.get(klass);
|
||||
}
|
||||
}
|
||||
return tp / (tp + fp);
|
||||
return tp + fp > 0 ? tp / (tp + fp) : 0;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -191,14 +193,16 @@ public class ConfusionMatrixGenerator {
|
|||
Map<String, Long> classifications = linearizedMatrix.get(klass);
|
||||
double tp = 0;
|
||||
double fn = 0;
|
||||
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
|
||||
if (klass.equals(entry.getKey())) {
|
||||
tp += entry.getValue();
|
||||
} else {
|
||||
fn += entry.getValue();
|
||||
if (classifications != null) {
|
||||
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
|
||||
if (klass.equals(entry.getKey())) {
|
||||
tp += entry.getValue();
|
||||
} else {
|
||||
fn += entry.getValue();
|
||||
}
|
||||
}
|
||||
}
|
||||
return tp / (tp + fn);
|
||||
return tp + fn > 0 ? tp / (tp + fn) : 0;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -210,7 +214,7 @@ public class ConfusionMatrixGenerator {
|
|||
public double getF1Measure(String klass) {
|
||||
double recall = getRecall(klass);
|
||||
double precision = getPrecision(klass);
|
||||
return 2 * precision * recall / (precision + recall);
|
||||
return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
Loading…
Reference in New Issue