mirror of https://github.com/apache/lucene.git
LUCENE-6854 - global confusion matrix precision and recall metrics
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1723729 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
bcf2913b9f
commit
d9f4eb84cf
|
@ -246,6 +246,54 @@ public class ConfusionMatrixGenerator {
|
|||
return this.accuracy;
|
||||
}
|
||||
|
||||
/**
|
||||
* get the precision (see {@link #getPrecision(String)}) over all the classes.
|
||||
*
|
||||
* @return the precision as computed from the whole confusion matrix
|
||||
*/
|
||||
public double getPrecision() {
|
||||
double tp = 0;
|
||||
double fp = 0;
|
||||
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
|
||||
String klass = classification.getKey();
|
||||
for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
|
||||
if (klass.equals(entry.getKey())) {
|
||||
tp += entry.getValue();
|
||||
}
|
||||
}
|
||||
for (Map<String, Long> values : linearizedMatrix.values()) {
|
||||
if (values.containsKey(klass)) {
|
||||
fp += values.get(klass);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tp + fp > 0 ? tp / (tp + fp) : 0;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* get the recall (see {@link #getRecall(String)}) over all the classes
|
||||
*
|
||||
* @return the recall as computed from the whole confusion matrix
|
||||
*/
|
||||
public double getRecall() {
|
||||
double tp = 0;
|
||||
double fn = 0;
|
||||
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
|
||||
String klass = classification.getKey();
|
||||
for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
|
||||
if (klass.equals(entry.getKey())) {
|
||||
tp += entry.getValue();
|
||||
} else {
|
||||
fn += entry.getValue();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tp + fn > 0 ? tp / (tp + fn) : 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "ConfusionMatrix{" +
|
||||
|
|
|
@ -25,7 +25,10 @@ import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
|
|||
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
|
||||
import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.MultiFields;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.Terms;
|
||||
import org.apache.lucene.index.TermsEnum;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.junit.Test;
|
||||
|
@ -109,7 +112,32 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte
|
|||
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
|
||||
assertTrue(5000 > avgClassificationTime);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy > 0d);
|
||||
assertTrue(accuracy >= 0d);
|
||||
assertTrue(accuracy <= 1d);
|
||||
|
||||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
|
||||
double precision = confusionMatrix.getPrecision();
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
|
||||
Terms terms = MultiFields.getTerms(leafReader, categoryFieldName);
|
||||
TermsEnum iterator = terms.iterator();
|
||||
BytesRef term;
|
||||
while ((term = iterator.next()) != null) {
|
||||
String s = term.utf8ToString();
|
||||
recall = confusionMatrix.getRecall(s);
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
precision = confusionMatrix.getPrecision(s);
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
double f1Measure = confusionMatrix.getF1Measure(s);
|
||||
assertTrue(f1Measure >= 0d);
|
||||
assertTrue(f1Measure <= 1d);
|
||||
}
|
||||
} finally {
|
||||
leafReader.close();
|
||||
}
|
||||
|
|
|
@ -23,7 +23,10 @@ import org.apache.lucene.analysis.MockAnalyzer;
|
|||
import org.apache.lucene.analysis.en.EnglishAnalyzer;
|
||||
import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.MultiFields;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.Terms;
|
||||
import org.apache.lucene.index.TermsEnum;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.search.similarities.BM25Similarity;
|
||||
import org.apache.lucene.search.similarities.LMDirichletSimilarity;
|
||||
|
@ -139,7 +142,32 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
|
|||
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
|
||||
assertTrue(5000 > avgClassificationTime);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy > 0d);
|
||||
assertTrue(accuracy >= 0d);
|
||||
assertTrue(accuracy <= 1d);
|
||||
|
||||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
|
||||
double precision = confusionMatrix.getPrecision();
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
|
||||
Terms terms = MultiFields.getTerms(leafReader, categoryFieldName);
|
||||
TermsEnum iterator = terms.iterator();
|
||||
BytesRef term;
|
||||
while ((term = iterator.next()) != null) {
|
||||
String s = term.utf8ToString();
|
||||
recall = confusionMatrix.getRecall(s);
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
precision = confusionMatrix.getPrecision(s);
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
double f1Measure = confusionMatrix.getF1Measure(s);
|
||||
assertTrue(f1Measure >= 0d);
|
||||
assertTrue(f1Measure <= 1d);
|
||||
}
|
||||
} finally {
|
||||
leafReader.close();
|
||||
}
|
||||
|
|
|
@ -24,7 +24,10 @@ import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
|
|||
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
|
||||
import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.MultiFields;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.Terms;
|
||||
import org.apache.lucene.index.TermsEnum;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.junit.Ignore;
|
||||
|
@ -110,7 +113,33 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
|
|||
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
|
||||
assertTrue("avg classification time: " + avgClassificationTime, 5000 > avgClassificationTime);
|
||||
double accuracy = confusionMatrix.getAccuracy();
|
||||
assertTrue(accuracy > 0d);
|
||||
assertTrue(accuracy >= 0d);
|
||||
assertTrue(accuracy <= 1d);
|
||||
|
||||
double recall = confusionMatrix.getRecall();
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
|
||||
double precision = confusionMatrix.getPrecision();
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
|
||||
Terms terms = MultiFields.getTerms(leafReader, categoryFieldName);
|
||||
TermsEnum iterator = terms.iterator();
|
||||
BytesRef term;
|
||||
while ((term = iterator.next()) != null) {
|
||||
String s = term.utf8ToString();
|
||||
recall = confusionMatrix.getRecall(s);
|
||||
assertTrue(recall >= 0d);
|
||||
assertTrue(recall <= 1d);
|
||||
precision = confusionMatrix.getPrecision(s);
|
||||
assertTrue(precision >= 0d);
|
||||
assertTrue(precision <= 1d);
|
||||
double f1Measure = confusionMatrix.getF1Measure(s);
|
||||
assertTrue(f1Measure >= 0d);
|
||||
assertTrue(f1Measure <= 1d);
|
||||
}
|
||||
|
||||
} finally {
|
||||
leafReader.close();
|
||||
}
|
||||
|
|
|
@ -65,6 +65,12 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
|
||||
assertTrue(avgClassificationTime >= 0d );
|
||||
assertTrue(confusionMatrix.getAccuracy() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision() >= 0d);
|
||||
assertTrue(confusionMatrix.getPrecision() <= 1d);
|
||||
assertTrue(confusionMatrix.getRecall() >= 0d);
|
||||
assertTrue(confusionMatrix.getRecall() <= 1d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -83,8 +89,13 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
assertNotNull(confusionMatrix);
|
||||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() > 0d);
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision() >= 0d);
|
||||
assertTrue(confusionMatrix.getPrecision() <= 1d);
|
||||
assertTrue(confusionMatrix.getRecall() >= 0d);
|
||||
assertTrue(confusionMatrix.getRecall() <= 1d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -103,8 +114,13 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
assertNotNull(confusionMatrix);
|
||||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() > 0d);
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision() >= 0d);
|
||||
assertTrue(confusionMatrix.getPrecision() <= 1d);
|
||||
assertTrue(confusionMatrix.getRecall() >= 0d);
|
||||
assertTrue(confusionMatrix.getRecall() <= 1d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -123,8 +139,13 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
assertNotNull(confusionMatrix);
|
||||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() > 0d);
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision() >= 0d);
|
||||
assertTrue(confusionMatrix.getPrecision() <= 1d);
|
||||
assertTrue(confusionMatrix.getRecall() >= 0d);
|
||||
assertTrue(confusionMatrix.getRecall() <= 1d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
@ -144,13 +165,24 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
|
|||
assertNotNull(confusionMatrix.getLinearizedMatrix());
|
||||
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
|
||||
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() > 0d);
|
||||
assertTrue(confusionMatrix.getPrecision("true") > 0d);
|
||||
assertTrue(confusionMatrix.getPrecision("false") > 0d);
|
||||
assertTrue(confusionMatrix.getRecall("true") > 0d);
|
||||
assertTrue(confusionMatrix.getRecall("false") > 0d);
|
||||
assertTrue(confusionMatrix.getF1Measure("true") > 0d);
|
||||
assertTrue(confusionMatrix.getF1Measure("false") > 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() >= 0d);
|
||||
assertTrue(confusionMatrix.getAccuracy() <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision() >= 0d);
|
||||
assertTrue(confusionMatrix.getPrecision() <= 1d);
|
||||
assertTrue(confusionMatrix.getRecall() >= 0d);
|
||||
assertTrue(confusionMatrix.getRecall() <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision("true") >= 0d);
|
||||
assertTrue(confusionMatrix.getPrecision("true") <= 1d);
|
||||
assertTrue(confusionMatrix.getPrecision("false") >= 0d);
|
||||
assertTrue(confusionMatrix.getPrecision("false") <= 1d);
|
||||
assertTrue(confusionMatrix.getRecall("true") >= 0d);
|
||||
assertTrue(confusionMatrix.getRecall("true") <= 1d);
|
||||
assertTrue(confusionMatrix.getRecall("false") >= 0d);
|
||||
assertTrue(confusionMatrix.getRecall("false") <= 1d);
|
||||
assertTrue(confusionMatrix.getF1Measure("true") >= 0d);
|
||||
assertTrue(confusionMatrix.getF1Measure("true") <= 1d);
|
||||
assertTrue(confusionMatrix.getF1Measure("false") >= 0d);
|
||||
assertTrue(confusionMatrix.getF1Measure("false") <= 1d);
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
reader.close();
|
||||
|
|
Loading…
Reference in New Issue