mirror of https://github.com/apache/lucene.git
LUCENE-7303 - avoid NPE in MultiFields.getTerms(leafReader, classFieldName), removed duplicated code in DocumentSNBC
(cherry picked from commit 8808cf5
)
This commit is contained in:
parent
b29eac852b
commit
8c64931517
|
@ -145,18 +145,19 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
|
List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
|
||||||
|
|
||||||
Terms classes = MultiFields.getTerms(leafReader, classFieldName);
|
Terms classes = MultiFields.getTerms(leafReader, classFieldName);
|
||||||
TermsEnum classesEnum = classes.iterator();
|
if (classes != null) {
|
||||||
BytesRef next;
|
TermsEnum classesEnum = classes.iterator();
|
||||||
String[] tokenizedText = tokenize(inputDocument);
|
BytesRef next;
|
||||||
int docsWithClassSize = countDocsWithClass();
|
String[] tokenizedText = tokenize(inputDocument);
|
||||||
while ((next = classesEnum.next()) != null) {
|
int docsWithClassSize = countDocsWithClass();
|
||||||
if (next.length > 0) {
|
while ((next = classesEnum.next()) != null) {
|
||||||
Term term = new Term(this.classFieldName, next);
|
if (next.length > 0) {
|
||||||
double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
|
Term term = new Term(this.classFieldName, next);
|
||||||
assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));
|
double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
|
||||||
|
assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalization; the values transforms to a 0-1 range
|
// normalization; the values transforms to a 0-1 range
|
||||||
return normClassificationResults(assignedClasses);
|
return normClassificationResults(assignedClasses);
|
||||||
}
|
}
|
||||||
|
@ -168,8 +169,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
* @throws IOException if accessing to term vectors or search fails
|
* @throws IOException if accessing to term vectors or search fails
|
||||||
*/
|
*/
|
||||||
protected int countDocsWithClass() throws IOException {
|
protected int countDocsWithClass() throws IOException {
|
||||||
int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount();
|
Terms terms = MultiFields.getTerms(this.leafReader, this.classFieldName);
|
||||||
if (docCount == -1) { // in case codec doesn't support getDocCount
|
int docCount;
|
||||||
|
if (terms == null || terms.getDocCount() == -1) { // in case codec doesn't support getDocCount
|
||||||
TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
|
TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
|
||||||
BooleanQuery.Builder q = new BooleanQuery.Builder();
|
BooleanQuery.Builder q = new BooleanQuery.Builder();
|
||||||
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
|
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
|
||||||
|
@ -179,6 +181,8 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
indexSearcher.search(q.build(),
|
indexSearcher.search(q.build(),
|
||||||
classQueryCountCollector);
|
classQueryCountCollector);
|
||||||
docCount = classQueryCountCollector.getTotalHits();
|
docCount = classQueryCountCollector.getTotalHits();
|
||||||
|
} else {
|
||||||
|
docCount = terms.getDocCount();
|
||||||
}
|
}
|
||||||
return docCount;
|
return docCount;
|
||||||
}
|
}
|
||||||
|
|
|
@ -168,28 +168,6 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Counts the number of documents in the index having at least a value for the 'class' field
|
|
||||||
*
|
|
||||||
* @return the no. of documents having a value for the 'class' field
|
|
||||||
* @throws java.io.IOException If accessing to term vectors or search fails
|
|
||||||
*/
|
|
||||||
protected int countDocsWithClass() throws IOException {
|
|
||||||
int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount();
|
|
||||||
if (docCount == -1) { // in case codec doesn't support getDocCount
|
|
||||||
TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
|
|
||||||
BooleanQuery.Builder q = new BooleanQuery.Builder();
|
|
||||||
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
|
|
||||||
if (query != null) {
|
|
||||||
q.add(query, BooleanClause.Occur.MUST);
|
|
||||||
}
|
|
||||||
indexSearcher.search(q.build(),
|
|
||||||
classQueryCountCollector);
|
|
||||||
docCount = classQueryCountCollector.getTotalHits();
|
|
||||||
}
|
|
||||||
return docCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a token array from the {@link org.apache.lucene.analysis.TokenStream} in input
|
* Returns a token array from the {@link org.apache.lucene.analysis.TokenStream} in input
|
||||||
*
|
*
|
||||||
|
|
Loading…
Reference in New Issue