diff --git a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java index d1393523c09..923f695852a 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java @@ -64,23 +64,17 @@ public class SimpleNaiveBayesClassifier implements Classifier { * {@inheritDoc} */ @Override - public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) - throws IOException { - this.atomicReader = atomicReader; - this.indexSearcher = new IndexSearcher(this.atomicReader); - this.textFieldNames = new String[]{textFieldName}; - this.classFieldName = classFieldName; - this.analyzer = analyzer; - this.docsWithClassSize = countDocsWithClass(); - this.query = query; + public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException { + train(atomicReader, textFieldName, classFieldName, analyzer, null); } /** * {@inheritDoc} */ @Override - public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException { - train(atomicReader, textFieldName, classFieldName, analyzer, null); + public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) + throws IOException { + train(atomicReader, new String[]{textFieldName}, classFieldName, analyzer, query); } /** @@ -137,7 +131,7 @@ public class SimpleNaiveBayesClassifier implements Classifier { if (atomicReader == null) { throw new IOException("You must first call Classifier#train"); } - double max = 0d; + double max = - Double.MAX_VALUE; BytesRef foundClass = new BytesRef(); Terms terms = MultiFields.getTerms(atomicReader, classFieldName); @@ -145,20 +139,20 @@ public class SimpleNaiveBayesClassifier implements Classifier { BytesRef next; String[] tokenizedDoc = tokenizeDoc(inputDocument); while ((next = termsEnum.next()) != null) { - // TODO : turn it to be in log scale - double clVal = calculatePrior(next) * calculateLikelihood(tokenizedDoc, next); + double clVal = calculateLogPrior(next) + calculateLogLikelihood(tokenizedDoc, next); if (clVal > max) { max = clVal; foundClass = BytesRef.deepCopyOf(next); } } - return new ClassificationResult(foundClass, max); + double score = 10 / Math.abs(max); + return new ClassificationResult(foundClass, score); } - private double calculateLikelihood(String[] tokenizedDoc, BytesRef c) throws IOException { + private double calculateLogLikelihood(String[] tokenizedDoc, BytesRef c) throws IOException { // for each word - double result = 1d; + double result = 0d; for (String word : tokenizedDoc) { // search with text:word AND class:c int hits = getWordFreqForClass(word, c); @@ -171,10 +165,10 @@ public class SimpleNaiveBayesClassifier implements Classifier { // P(w|c) = num/den double wordProbability = num / den; - result *= wordProbability; + result += Math.log(wordProbability); } - // P(d|c) = P(w1|c)*...*P(wn|c) + // log(P(d|c)) = log(P(w1|c))+...+log(P(wn|c)) return result; } @@ -205,8 +199,8 @@ public class SimpleNaiveBayesClassifier implements Classifier { return totalHitCountCollector.getTotalHits(); } - private double calculatePrior(BytesRef currentClass) throws IOException { - return (double) docCount(currentClass) / docsWithClassSize; + private double calculateLogPrior(BytesRef currentClass) throws IOException { + return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize); } private int docCount(BytesRef countedClass) throws IOException {