LUCENE-5548 - improved testing for SNBC

This commit is contained in:
Tommaso Teofili 2017-04-07 10:58:49 +02:00
parent 276ccff751
commit f37fad206b
1 changed files with 9 additions and 2 deletions

View File

@ -59,8 +59,10 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
try { try {
MockAnalyzer analyzer = new MockAnalyzer(random()); MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = getSampleIndex(analyzer); leafReader = getSampleIndex(analyzer);
TermQuery query = new TermQuery(new Term(textFieldName, "it")); TermQuery query = new TermQuery(new Term(textFieldName, "a"));
checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); SimpleNaiveBayesClassifier classifier = new SimpleNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName);
checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(classifier, POLITICS_INPUT, POLITICS_RESULT);
} finally { } finally {
if (leafReader != null) { if (leafReader != null) {
leafReader.close(); leafReader.close();
@ -112,6 +114,11 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000); assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000);
double avgClassificationTime = confusionMatrix.getAvgClassificationTime(); double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
assertTrue("avg classification time: " + avgClassificationTime, 5000 > avgClassificationTime); assertTrue("avg classification time: " + avgClassificationTime, 5000 > avgClassificationTime);
double f1 = confusionMatrix.getF1Measure();
assertTrue(f1 >= 0d);
assertTrue(f1 <= 1d);
double accuracy = confusionMatrix.getAccuracy(); double accuracy = confusionMatrix.getAccuracy();
assertTrue(accuracy >= 0d); assertTrue(accuracy >= 0d);
assertTrue(accuracy <= 1d); assertTrue(accuracy <= 1d);