LUCENE-6854 - global confusion matrix precision and recall metrics

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1723729 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2016-01-08 14:23:12 +00:00
parent bcf2913b9f
commit d9f4eb84cf
5 changed files with 181 additions and 16 deletions

View File

@ -246,6 +246,54 @@ public class ConfusionMatrixGenerator {
return this.accuracy;
}
/**
* get the precision (see {@link #getPrecision(String)}) over all the classes.
*
* @return the precision as computed from the whole confusion matrix
*/
public double getPrecision() {
double tp = 0;
double fp = 0;
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
String klass = classification.getKey();
for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
if (klass.equals(entry.getKey())) {
tp += entry.getValue();
}
}
for (Map<String, Long> values : linearizedMatrix.values()) {
if (values.containsKey(klass)) {
fp += values.get(klass);
}
}
}
return tp + fp > 0 ? tp / (tp + fp) : 0;
}
/**
* get the recall (see {@link #getRecall(String)}) over all the classes
*
* @return the recall as computed from the whole confusion matrix
*/
public double getRecall() {
double tp = 0;
double fn = 0;
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
String klass = classification.getKey();
for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
if (klass.equals(entry.getKey())) {
tp += entry.getValue();
} else {
fn += entry.getValue();
}
}
}
return tp + fn > 0 ? tp / (tp + fn) : 0;
}
@Override
public String toString() {
return "ConfusionMatrix{" +

View File

@ -25,7 +25,10 @@ import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
import org.junit.Test;
@ -109,7 +112,32 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
assertTrue(5000 > avgClassificationTime);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy > 0d);
assertTrue(accuracy >= 0d);
assertTrue(accuracy <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
double precision = confusionMatrix.getPrecision();
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
Terms terms = MultiFields.getTerms(leafReader, categoryFieldName);
TermsEnum iterator = terms.iterator();
BytesRef term;
while ((term = iterator.next()) != null) {
String s = term.utf8ToString();
recall = confusionMatrix.getRecall(s);
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
precision = confusionMatrix.getPrecision(s);
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
double f1Measure = confusionMatrix.getF1Measure(s);
assertTrue(f1Measure >= 0d);
assertTrue(f1Measure <= 1d);
}
} finally {
leafReader.close();
}

View File

@ -23,7 +23,10 @@ import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.analysis.en.EnglishAnalyzer;
import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.LMDirichletSimilarity;
@ -139,7 +142,32 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
assertTrue(5000 > avgClassificationTime);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy > 0d);
assertTrue(accuracy >= 0d);
assertTrue(accuracy <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
double precision = confusionMatrix.getPrecision();
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
Terms terms = MultiFields.getTerms(leafReader, categoryFieldName);
TermsEnum iterator = terms.iterator();
BytesRef term;
while ((term = iterator.next()) != null) {
String s = term.utf8ToString();
recall = confusionMatrix.getRecall(s);
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
precision = confusionMatrix.getPrecision(s);
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
double f1Measure = confusionMatrix.getF1Measure(s);
assertTrue(f1Measure >= 0d);
assertTrue(f1Measure <= 1d);
}
} finally {
leafReader.close();
}

View File

@ -24,7 +24,10 @@ import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
import org.junit.Ignore;
@ -110,7 +113,33 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
assertTrue("avg classification time: " + avgClassificationTime, 5000 > avgClassificationTime);
double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy > 0d);
assertTrue(accuracy >= 0d);
assertTrue(accuracy <= 1d);
double recall = confusionMatrix.getRecall();
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
double precision = confusionMatrix.getPrecision();
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
Terms terms = MultiFields.getTerms(leafReader, categoryFieldName);
TermsEnum iterator = terms.iterator();
BytesRef term;
while ((term = iterator.next()) != null) {
String s = term.utf8ToString();
recall = confusionMatrix.getRecall(s);
assertTrue(recall >= 0d);
assertTrue(recall <= 1d);
precision = confusionMatrix.getPrecision(s);
assertTrue(precision >= 0d);
assertTrue(precision <= 1d);
double f1Measure = confusionMatrix.getF1Measure(s);
assertTrue(f1Measure >= 0d);
assertTrue(f1Measure <= 1d);
}
} finally {
leafReader.close();
}

View File

@ -65,6 +65,12 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
assertTrue(avgClassificationTime >= 0d );
assertTrue(confusionMatrix.getAccuracy() >= 0d);
assertTrue(confusionMatrix.getAccuracy() <= 1d);
assertTrue(confusionMatrix.getPrecision() >= 0d);
assertTrue(confusionMatrix.getPrecision() <= 1d);
assertTrue(confusionMatrix.getRecall() >= 0d);
assertTrue(confusionMatrix.getRecall() <= 1d);
} finally {
if (reader != null) {
reader.close();
@ -83,8 +89,13 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
assertTrue(confusionMatrix.getAccuracy() > 0d);
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
assertTrue(confusionMatrix.getAccuracy() >= 0d);
assertTrue(confusionMatrix.getAccuracy() <= 1d);
assertTrue(confusionMatrix.getPrecision() >= 0d);
assertTrue(confusionMatrix.getPrecision() <= 1d);
assertTrue(confusionMatrix.getRecall() >= 0d);
assertTrue(confusionMatrix.getRecall() <= 1d);
} finally {
if (reader != null) {
reader.close();
@ -103,8 +114,13 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
assertTrue(confusionMatrix.getAccuracy() > 0d);
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
assertTrue(confusionMatrix.getAccuracy() >= 0d);
assertTrue(confusionMatrix.getAccuracy() <= 1d);
assertTrue(confusionMatrix.getPrecision() >= 0d);
assertTrue(confusionMatrix.getPrecision() <= 1d);
assertTrue(confusionMatrix.getRecall() >= 0d);
assertTrue(confusionMatrix.getRecall() <= 1d);
} finally {
if (reader != null) {
reader.close();
@ -123,8 +139,13 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix);
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
assertTrue(confusionMatrix.getAccuracy() > 0d);
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
assertTrue(confusionMatrix.getAccuracy() >= 0d);
assertTrue(confusionMatrix.getAccuracy() <= 1d);
assertTrue(confusionMatrix.getPrecision() >= 0d);
assertTrue(confusionMatrix.getPrecision() <= 1d);
assertTrue(confusionMatrix.getRecall() >= 0d);
assertTrue(confusionMatrix.getRecall() <= 1d);
} finally {
if (reader != null) {
reader.close();
@ -144,13 +165,24 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
assertTrue(confusionMatrix.getAccuracy() > 0d);
assertTrue(confusionMatrix.getPrecision("true") > 0d);
assertTrue(confusionMatrix.getPrecision("false") > 0d);
assertTrue(confusionMatrix.getRecall("true") > 0d);
assertTrue(confusionMatrix.getRecall("false") > 0d);
assertTrue(confusionMatrix.getF1Measure("true") > 0d);
assertTrue(confusionMatrix.getF1Measure("false") > 0d);
assertTrue(confusionMatrix.getAccuracy() >= 0d);
assertTrue(confusionMatrix.getAccuracy() <= 1d);
assertTrue(confusionMatrix.getPrecision() >= 0d);
assertTrue(confusionMatrix.getPrecision() <= 1d);
assertTrue(confusionMatrix.getRecall() >= 0d);
assertTrue(confusionMatrix.getRecall() <= 1d);
assertTrue(confusionMatrix.getPrecision("true") >= 0d);
assertTrue(confusionMatrix.getPrecision("true") <= 1d);
assertTrue(confusionMatrix.getPrecision("false") >= 0d);
assertTrue(confusionMatrix.getPrecision("false") <= 1d);
assertTrue(confusionMatrix.getRecall("true") >= 0d);
assertTrue(confusionMatrix.getRecall("true") <= 1d);
assertTrue(confusionMatrix.getRecall("false") >= 0d);
assertTrue(confusionMatrix.getRecall("false") <= 1d);
assertTrue(confusionMatrix.getF1Measure("true") >= 0d);
assertTrue(confusionMatrix.getF1Measure("true") <= 1d);
assertTrue(confusionMatrix.getF1Measure("false") >= 0d);
assertTrue(confusionMatrix.getF1Measure("false") <= 1d);
} finally {
if (reader != null) {
reader.close();