mirror of https://github.com/apache/lucene.git
LUCENE-7303 - avoid NPE in MultiFields.getTerms(leafReader, classFieldName), removed duplicated code in DocumentSNBC
(cherry picked from commit 8808cf5)
This commit is contained in:
parent
b29eac852b
commit
8c64931517
|
@ -145,18 +145,19 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
|
||||
|
||||
Terms classes = MultiFields.getTerms(leafReader, classFieldName);
|
||||
TermsEnum classesEnum = classes.iterator();
|
||||
BytesRef next;
|
||||
String[] tokenizedText = tokenize(inputDocument);
|
||||
int docsWithClassSize = countDocsWithClass();
|
||||
while ((next = classesEnum.next()) != null) {
|
||||
if (next.length > 0) {
|
||||
Term term = new Term(this.classFieldName, next);
|
||||
double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
|
||||
assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));
|
||||
if (classes != null) {
|
||||
TermsEnum classesEnum = classes.iterator();
|
||||
BytesRef next;
|
||||
String[] tokenizedText = tokenize(inputDocument);
|
||||
int docsWithClassSize = countDocsWithClass();
|
||||
while ((next = classesEnum.next()) != null) {
|
||||
if (next.length > 0) {
|
||||
Term term = new Term(this.classFieldName, next);
|
||||
double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
|
||||
assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// normalization; the values transforms to a 0-1 range
|
||||
return normClassificationResults(assignedClasses);
|
||||
}
|
||||
|
@ -168,8 +169,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
* @throws IOException if accessing to term vectors or search fails
|
||||
*/
|
||||
protected int countDocsWithClass() throws IOException {
|
||||
int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount();
|
||||
if (docCount == -1) { // in case codec doesn't support getDocCount
|
||||
Terms terms = MultiFields.getTerms(this.leafReader, this.classFieldName);
|
||||
int docCount;
|
||||
if (terms == null || terms.getDocCount() == -1) { // in case codec doesn't support getDocCount
|
||||
TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
|
||||
BooleanQuery.Builder q = new BooleanQuery.Builder();
|
||||
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
|
||||
|
@ -179,6 +181,8 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
indexSearcher.search(q.build(),
|
||||
classQueryCountCollector);
|
||||
docCount = classQueryCountCollector.getTotalHits();
|
||||
} else {
|
||||
docCount = terms.getDocCount();
|
||||
}
|
||||
return docCount;
|
||||
}
|
||||
|
|
|
@ -168,28 +168,6 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Counts the number of documents in the index having at least a value for the 'class' field
|
||||
*
|
||||
* @return the no. of documents having a value for the 'class' field
|
||||
* @throws java.io.IOException If accessing to term vectors or search fails
|
||||
*/
|
||||
protected int countDocsWithClass() throws IOException {
|
||||
int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount();
|
||||
if (docCount == -1) { // in case codec doesn't support getDocCount
|
||||
TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
|
||||
BooleanQuery.Builder q = new BooleanQuery.Builder();
|
||||
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
|
||||
if (query != null) {
|
||||
q.add(query, BooleanClause.Occur.MUST);
|
||||
}
|
||||
indexSearcher.search(q.build(),
|
||||
classQueryCountCollector);
|
||||
docCount = classQueryCountCollector.getTotalHits();
|
||||
}
|
||||
return docCount;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a token array from the {@link org.apache.lucene.analysis.TokenStream} in input
|
||||
*
|
||||
|
|
Loading…
Reference in New Issue