LUCENE-7303 - avoid NPE in MultiFields.getTerms(leafReader, classFieldName), removed duplicated code in DocumentSNBC

This commit is contained in:
Tommaso Teofili 2016-05-26 15:59:12 +02:00
parent 2aabed4ab6
commit 8808cf5373
2 changed files with 16 additions and 34 deletions

View File

@ -145,18 +145,19 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
Terms classes = MultiFields.getTerms(leafReader, classFieldName);
TermsEnum classesEnum = classes.iterator();
BytesRef next;
String[] tokenizedText = tokenize(inputDocument);
int docsWithClassSize = countDocsWithClass();
while ((next = classesEnum.next()) != null) {
if (next.length > 0) {
Term term = new Term(this.classFieldName, next);
double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));
if (classes != null) {
TermsEnum classesEnum = classes.iterator();
BytesRef next;
String[] tokenizedText = tokenize(inputDocument);
int docsWithClassSize = countDocsWithClass();
while ((next = classesEnum.next()) != null) {
if (next.length > 0) {
Term term = new Term(this.classFieldName, next);
double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));
}
}
}
// normalization; the values are transformed to a 0-1 range
return normClassificationResults(assignedClasses);
}
@ -168,8 +169,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
* @throws IOException if accessing term vectors or searching fails
*/
protected int countDocsWithClass() throws IOException {
int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount();
if (docCount == -1) { // in case codec doesn't support getDocCount
Terms terms = MultiFields.getTerms(this.leafReader, this.classFieldName);
int docCount;
if (terms == null || terms.getDocCount() == -1) { // in case codec doesn't support getDocCount
TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
BooleanQuery.Builder q = new BooleanQuery.Builder();
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
@ -179,6 +181,8 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
indexSearcher.search(q.build(),
classQueryCountCollector);
docCount = classQueryCountCollector.getTotalHits();
} else {
docCount = terms.getDocCount();
}
return docCount;
}

View File

@ -168,28 +168,6 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
}
}
/**
 * Counts the number of documents in the index having at least a value for the 'class' field.
 * <p>
 * The count is taken from the doc count of the class field's terms when that is available.
 * When the field has no terms, or the doc count is reported as -1, the count is obtained by
 * running a wildcard query on the class field, additionally restricted by {@code query} when
 * one is set.
 *
 * @return the no. of documents having a value for the 'class' field
 * @throws java.io.IOException if accessing term vectors or searching fails
 */
protected int countDocsWithClass() throws IOException {
  // getTerms may return null (no terms for the class field); guard it instead of dereferencing
  // the result directly, which used to throw a NullPointerException.
  org.apache.lucene.index.Terms classTerms = MultiFields.getTerms(this.leafReader, this.classFieldName);
  int docCount = classTerms == null ? -1 : classTerms.getDocCount();
  if (docCount == -1) { // no terms for the field, or codec doesn't support getDocCount
    TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
    BooleanQuery.Builder q = new BooleanQuery.Builder();
    // match every document that has any value in the class field
    q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
    if (query != null) {
      q.add(query, BooleanClause.Occur.MUST);
    }
    indexSearcher.search(q.build(),
        classQueryCountCollector);
    docCount = classQueryCountCollector.getTotalHits();
  }
  return docCount;
}
/**
* Returns a token array from the {@link org.apache.lucene.analysis.TokenStream} in input
*