LUCENE-7303 - avoid NPE in MultiFields.getTerms(leafReader, classFieldName), removed duplicated code in DocumentSNBC

(cherry picked from commit 8808cf5)
This commit is contained in:
Tommaso Teofili 2016-05-26 15:59:12 +02:00
parent b29eac852b
commit 8c64931517
2 changed files with 16 additions and 34 deletions

View File

@ -145,18 +145,19 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>(); List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
Terms classes = MultiFields.getTerms(leafReader, classFieldName); Terms classes = MultiFields.getTerms(leafReader, classFieldName);
TermsEnum classesEnum = classes.iterator(); if (classes != null) {
BytesRef next; TermsEnum classesEnum = classes.iterator();
String[] tokenizedText = tokenize(inputDocument); BytesRef next;
int docsWithClassSize = countDocsWithClass(); String[] tokenizedText = tokenize(inputDocument);
while ((next = classesEnum.next()) != null) { int docsWithClassSize = countDocsWithClass();
if (next.length > 0) { while ((next = classesEnum.next()) != null) {
Term term = new Term(this.classFieldName, next); if (next.length > 0) {
double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize); Term term = new Term(this.classFieldName, next);
assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal)); double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));
}
} }
} }
// normalization; the values transforms to a 0-1 range // normalization; the values transforms to a 0-1 range
return normClassificationResults(assignedClasses); return normClassificationResults(assignedClasses);
} }
@ -168,8 +169,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
* @throws IOException if accessing to term vectors or search fails * @throws IOException if accessing to term vectors or search fails
*/ */
protected int countDocsWithClass() throws IOException { protected int countDocsWithClass() throws IOException {
int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount(); Terms terms = MultiFields.getTerms(this.leafReader, this.classFieldName);
if (docCount == -1) { // in case codec doesn't support getDocCount int docCount;
if (terms == null || terms.getDocCount() == -1) { // in case codec doesn't support getDocCount
TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector(); TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
BooleanQuery.Builder q = new BooleanQuery.Builder(); BooleanQuery.Builder q = new BooleanQuery.Builder();
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST)); q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
@ -179,6 +181,8 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
indexSearcher.search(q.build(), indexSearcher.search(q.build(),
classQueryCountCollector); classQueryCountCollector);
docCount = classQueryCountCollector.getTotalHits(); docCount = classQueryCountCollector.getTotalHits();
} else {
docCount = terms.getDocCount();
} }
return docCount; return docCount;
} }

View File

@ -168,28 +168,6 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
} }
} }
/**
 * Counts the number of documents in the index having at least a value for the 'class' field.
 *
 * <p>The doc count reported by the terms of the class field is used when it is available.
 * When the field has no terms at all, or the reported count is -1, the count is obtained
 * by running a wildcard query over the class field (combined with {@code query} when that
 * is not null) and taking the total number of hits.
 *
 * @return the no. of documents having a value for the 'class' field
 * @throws java.io.IOException If accessing to term vectors or search fails
 */
protected int countDocsWithClass() throws IOException {
  // Fully qualified so this method does not rely on an import of Terms being present.
  org.apache.lucene.index.Terms terms = MultiFields.getTerms(this.leafReader, this.classFieldName);
  // getTerms yields null when the field has no terms; treat that like an unsupported
  // doc count (-1) instead of dereferencing null.
  int docCount = terms == null ? -1 : terms.getDocCount();
  if (docCount == -1) { // no terms for the field, or codec doesn't support getDocCount
    TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
    BooleanQuery.Builder q = new BooleanQuery.Builder();
    q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
    if (query != null) {
      q.add(query, BooleanClause.Occur.MUST);
    }
    indexSearcher.search(q.build(),
        classQueryCountCollector);
    docCount = classQueryCountCollector.getTotalHits();
  }
  return docCount;
}
/** /**
* Returns a token array from the {@link org.apache.lucene.analysis.TokenStream} in input * Returns a token array from the {@link org.apache.lucene.analysis.TokenStream} in input
* *