mirror of
https://github.com/apache/lucene.git
synced 2025-02-09 11:35:14 +00:00
LUCENE-6854 - adjusted precision, recall, f1 measure metrics
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1710605 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
8e929f3468
commit
0deadf0671
@ -238,10 +238,11 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return "KNearestNeighborClassifier{" +
|
return "KNearestNeighborClassifier{" +
|
||||||
", textFieldNames=" + Arrays.toString(textFieldNames) +
|
"textFieldNames=" + Arrays.toString(textFieldNames) +
|
||||||
", classFieldName='" + classFieldName + '\'' +
|
", classFieldName='" + classFieldName + '\'' +
|
||||||
", k=" + k +
|
", k=" + k +
|
||||||
", query=" + query +
|
", query=" + query +
|
||||||
|
", similarity=" + indexSearcher.getSimilarity(true) +
|
||||||
'}';
|
'}';
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -151,7 +151,6 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||||||
int docsWithClassSize = countDocsWithClass();
|
int docsWithClassSize = countDocsWithClass();
|
||||||
while ((next = classesEnum.next()) != null) {
|
while ((next = classesEnum.next()) != null) {
|
||||||
if (next.length > 0) {
|
if (next.length > 0) {
|
||||||
// We are passing the term to IndexSearcher so we need to make sure it will not change over time
|
|
||||||
Term term = new Term(this.classFieldName, next);
|
Term term = new Term(this.classFieldName, next);
|
||||||
double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
|
double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
|
||||||
assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));
|
assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));
|
||||||
|
@ -168,17 +168,19 @@ public class ConfusionMatrixGenerator {
|
|||||||
Map<String, Long> classifications = linearizedMatrix.get(klass);
|
Map<String, Long> classifications = linearizedMatrix.get(klass);
|
||||||
double tp = 0;
|
double tp = 0;
|
||||||
double fp = 0;
|
double fp = 0;
|
||||||
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
|
if (classifications != null) {
|
||||||
if (klass.equals(entry.getKey())) {
|
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
|
||||||
tp += entry.getValue();
|
if (klass.equals(entry.getKey())) {
|
||||||
|
tp += entry.getValue();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (Map<String, Long> values : linearizedMatrix.values()) {
|
||||||
|
if (values.containsKey(klass)) {
|
||||||
|
fp += values.get(klass);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (Map<String, Long> values : linearizedMatrix.values()) {
|
return tp + fp > 0 ? tp / (tp + fp) : 0;
|
||||||
if (values.containsKey(klass)) {
|
|
||||||
fp += values.get(klass);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return tp / (tp + fp);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -191,14 +193,16 @@ public class ConfusionMatrixGenerator {
|
|||||||
Map<String, Long> classifications = linearizedMatrix.get(klass);
|
Map<String, Long> classifications = linearizedMatrix.get(klass);
|
||||||
double tp = 0;
|
double tp = 0;
|
||||||
double fn = 0;
|
double fn = 0;
|
||||||
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
|
if (classifications != null) {
|
||||||
if (klass.equals(entry.getKey())) {
|
for (Map.Entry<String, Long> entry : classifications.entrySet()) {
|
||||||
tp += entry.getValue();
|
if (klass.equals(entry.getKey())) {
|
||||||
} else {
|
tp += entry.getValue();
|
||||||
fn += entry.getValue();
|
} else {
|
||||||
|
fn += entry.getValue();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return tp / (tp + fn);
|
return tp + fn > 0 ? tp / (tp + fn) : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -210,7 +214,7 @@ public class ConfusionMatrixGenerator {
|
|||||||
public double getF1Measure(String klass) {
|
public double getF1Measure(String klass) {
|
||||||
double recall = getRecall(klass);
|
double recall = getRecall(klass);
|
||||||
double precision = getPrecision(klass);
|
double precision = getPrecision(klass);
|
||||||
return 2 * precision * recall / (precision + recall);
|
return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
Loading…
x
Reference in New Issue
Block a user