LUCENE-4927 - switched to log prior/likelihood to avoid possible underflows

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1544433 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2013-11-22 08:29:16 +00:00
parent eb1dcbaa70
commit f9b3e389b2
1 changed file with 15 additions and 21 deletions

View File

@@ -64,23 +64,17 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
* {@inheritDoc} * {@inheritDoc}
*/ */
@Override @Override
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
throws IOException { train(atomicReader, textFieldName, classFieldName, analyzer, null);
this.atomicReader = atomicReader;
this.indexSearcher = new IndexSearcher(this.atomicReader);
this.textFieldNames = new String[]{textFieldName};
this.classFieldName = classFieldName;
this.analyzer = analyzer;
this.docsWithClassSize = countDocsWithClass();
this.query = query;
} }
/** /**
* {@inheritDoc} * {@inheritDoc}
*/ */
@Override @Override
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException { public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
train(atomicReader, textFieldName, classFieldName, analyzer, null); throws IOException {
train(atomicReader, new String[]{textFieldName}, classFieldName, analyzer, query);
} }
/** /**
@@ -137,7 +131,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
if (atomicReader == null) { if (atomicReader == null) {
throw new IOException("You must first call Classifier#train"); throw new IOException("You must first call Classifier#train");
} }
double max = 0d; double max = - Double.MAX_VALUE;
BytesRef foundClass = new BytesRef(); BytesRef foundClass = new BytesRef();
Terms terms = MultiFields.getTerms(atomicReader, classFieldName); Terms terms = MultiFields.getTerms(atomicReader, classFieldName);
@@ -145,20 +139,20 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
BytesRef next; BytesRef next;
String[] tokenizedDoc = tokenizeDoc(inputDocument); String[] tokenizedDoc = tokenizeDoc(inputDocument);
while ((next = termsEnum.next()) != null) { while ((next = termsEnum.next()) != null) {
// TODO : turn it to be in log scale double clVal = calculateLogPrior(next) + calculateLogLikelihood(tokenizedDoc, next);
double clVal = calculatePrior(next) * calculateLikelihood(tokenizedDoc, next);
if (clVal > max) { if (clVal > max) {
max = clVal; max = clVal;
foundClass = BytesRef.deepCopyOf(next); foundClass = BytesRef.deepCopyOf(next);
} }
} }
return new ClassificationResult<BytesRef>(foundClass, max); double score = 10 / Math.abs(max);
return new ClassificationResult<BytesRef>(foundClass, score);
} }
private double calculateLikelihood(String[] tokenizedDoc, BytesRef c) throws IOException { private double calculateLogLikelihood(String[] tokenizedDoc, BytesRef c) throws IOException {
// for each word // for each word
double result = 1d; double result = 0d;
for (String word : tokenizedDoc) { for (String word : tokenizedDoc) {
// search with text:word AND class:c // search with text:word AND class:c
int hits = getWordFreqForClass(word, c); int hits = getWordFreqForClass(word, c);
@@ -171,10 +165,10 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
// P(w|c) = num/den // P(w|c) = num/den
double wordProbability = num / den; double wordProbability = num / den;
result *= wordProbability; result += Math.log(wordProbability);
} }
// P(d|c) = P(w1|c)*...*P(wn|c) // log(P(d|c)) = log(P(w1|c))+...+log(P(wn|c))
return result; return result;
} }
@@ -205,8 +199,8 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
return totalHitCountCollector.getTotalHits(); return totalHitCountCollector.getTotalHits();
} }
private double calculatePrior(BytesRef currentClass) throws IOException { private double calculateLogPrior(BytesRef currentClass) throws IOException {
return (double) docCount(currentClass) / docsWithClassSize; return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize);
} }
private int docCount(BytesRef countedClass) throws IOException { private int docCount(BytesRef countedClass) throws IOException {