diff --git a/lucene/classification/build.xml b/lucene/classification/build.xml index 39cc28d3753..3e4103d9ee0 100644 --- a/lucene/classification/build.xml +++ b/lucene/classification/build.xml @@ -38,7 +38,7 @@ - + diff --git a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java index 17fa9d87a0f..06d5b83aeb0 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java @@ -29,6 +29,7 @@ import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TotalHitCountCollector; +import org.apache.lucene.search.WildcardQuery; import org.apache.lucene.util.BytesRef; import java.io.IOException; @@ -69,7 +70,18 @@ public class SimpleNaiveBayesClassifier implements Classifier { this.textFieldName = textFieldName; this.classFieldName = classFieldName; this.analyzer = analyzer; - this.docsWithClassSize = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount(); + this.docsWithClassSize = countDocsWithClass(); + } + + private int countDocsWithClass() throws IOException { + int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount(); + if (docCount == -1) { // in case codec doesn't support getDocCount + TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector(); + indexSearcher.search(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), + totalHitCountCollector); + docCount = totalHitCountCollector.getTotalHits(); + } + return docCount; } private String[] tokenizeDoc(String doc) throws IOException {