LUCENE-7950 - fixed potential NPE when no docs have the class field

This commit is contained in:
Tommaso Teofili 2017-09-02 14:43:59 +02:00
parent cd471cc98d
commit c2c2e8a85e
1 changed files with 16 additions and 14 deletions

View File

@ -113,24 +113,26 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>(); Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>();
Map<String, Float> fieldName2boost = new LinkedHashMap<>(); Map<String, Float> fieldName2boost = new LinkedHashMap<>();
Terms classes = MultiFields.getTerms(indexReader, classFieldName); Terms classes = MultiFields.getTerms(indexReader, classFieldName);
TermsEnum classesEnum = classes.iterator(); if (classes != null) {
BytesRef c; TermsEnum classesEnum = classes.iterator();
BytesRef c;
analyzeSeedDocument(inputDocument, fieldName2tokensArray, fieldName2boost); analyzeSeedDocument(inputDocument, fieldName2tokensArray, fieldName2boost);
int docsWithClassSize = countDocsWithClass(); int docsWithClassSize = countDocsWithClass();
while ((c = classesEnum.next()) != null) { while ((c = classesEnum.next()) != null) {
double classScore = 0; double classScore = 0;
Term term = new Term(this.classFieldName, c); Term term = new Term(this.classFieldName, c);
for (String fieldName : textFieldNames) { for (String fieldName : textFieldNames) {
List<String[]> tokensArrays = fieldName2tokensArray.get(fieldName); List<String[]> tokensArrays = fieldName2tokensArray.get(fieldName);
double fieldScore = 0; double fieldScore = 0;
for (String[] fieldTokensArray : tokensArrays) { for (String[] fieldTokensArray : tokensArrays) {
fieldScore += calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(fieldTokensArray, fieldName, term, docsWithClassSize) * fieldName2boost.get(fieldName); fieldScore += calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(fieldTokensArray, fieldName, term, docsWithClassSize) * fieldName2boost.get(fieldName);
}
classScore += fieldScore;
} }
classScore += fieldScore; assignedClasses.add(new ClassificationResult<>(term.bytes(), classScore));
} }
assignedClasses.add(new ClassificationResult<>(term.bytes(), classScore));
} }
return normClassificationResults(assignedClasses); return normClassificationResults(assignedClasses);
} }