mirror of https://github.com/apache/lucene.git
LUCENE-4927 - switched to log prior/likelihood to avoid possible underflows
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1544433 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent eb1dcbaa70
commit f9b3e389b2
|
@@ -64,23 +64,17 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
|
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
|
||||||
throws IOException {
|
train(atomicReader, textFieldName, classFieldName, analyzer, null);
|
||||||
this.atomicReader = atomicReader;
|
|
||||||
this.indexSearcher = new IndexSearcher(this.atomicReader);
|
|
||||||
this.textFieldNames = new String[]{textFieldName};
|
|
||||||
this.classFieldName = classFieldName;
|
|
||||||
this.analyzer = analyzer;
|
|
||||||
this.docsWithClassSize = countDocsWithClass();
|
|
||||||
this.query = query;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
|
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
|
||||||
train(atomicReader, textFieldName, classFieldName, analyzer, null);
|
throws IOException {
|
||||||
|
train(atomicReader, new String[]{textFieldName}, classFieldName, analyzer, query);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@@ -137,7 +131,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
if (atomicReader == null) {
|
if (atomicReader == null) {
|
||||||
throw new IOException("You must first call Classifier#train");
|
throw new IOException("You must first call Classifier#train");
|
||||||
}
|
}
|
||||||
double max = 0d;
|
double max = - Double.MAX_VALUE;
|
||||||
BytesRef foundClass = new BytesRef();
|
BytesRef foundClass = new BytesRef();
|
||||||
|
|
||||||
Terms terms = MultiFields.getTerms(atomicReader, classFieldName);
|
Terms terms = MultiFields.getTerms(atomicReader, classFieldName);
|
||||||
|
@@ -145,20 +139,20 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
BytesRef next;
|
BytesRef next;
|
||||||
String[] tokenizedDoc = tokenizeDoc(inputDocument);
|
String[] tokenizedDoc = tokenizeDoc(inputDocument);
|
||||||
while ((next = termsEnum.next()) != null) {
|
while ((next = termsEnum.next()) != null) {
|
||||||
// TODO : turn it to be in log scale
|
double clVal = calculateLogPrior(next) + calculateLogLikelihood(tokenizedDoc, next);
|
||||||
double clVal = calculatePrior(next) * calculateLikelihood(tokenizedDoc, next);
|
|
||||||
if (clVal > max) {
|
if (clVal > max) {
|
||||||
max = clVal;
|
max = clVal;
|
||||||
foundClass = BytesRef.deepCopyOf(next);
|
foundClass = BytesRef.deepCopyOf(next);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return new ClassificationResult<BytesRef>(foundClass, max);
|
double score = 10 / Math.abs(max);
|
||||||
|
return new ClassificationResult<BytesRef>(foundClass, score);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private double calculateLikelihood(String[] tokenizedDoc, BytesRef c) throws IOException {
|
private double calculateLogLikelihood(String[] tokenizedDoc, BytesRef c) throws IOException {
|
||||||
// for each word
|
// for each word
|
||||||
double result = 1d;
|
double result = 0d;
|
||||||
for (String word : tokenizedDoc) {
|
for (String word : tokenizedDoc) {
|
||||||
// search with text:word AND class:c
|
// search with text:word AND class:c
|
||||||
int hits = getWordFreqForClass(word, c);
|
int hits = getWordFreqForClass(word, c);
|
||||||
|
@@ -171,10 +165,10 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
|
|
||||||
// P(w|c) = num/den
|
// P(w|c) = num/den
|
||||||
double wordProbability = num / den;
|
double wordProbability = num / den;
|
||||||
result *= wordProbability;
|
result += Math.log(wordProbability);
|
||||||
}
|
}
|
||||||
|
|
||||||
// P(d|c) = P(w1|c)*...*P(wn|c)
|
// log(P(d|c)) = log(P(w1|c))+...+log(P(wn|c))
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -205,8 +199,8 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
return totalHitCountCollector.getTotalHits();
|
return totalHitCountCollector.getTotalHits();
|
||||||
}
|
}
|
||||||
|
|
||||||
private double calculatePrior(BytesRef currentClass) throws IOException {
|
private double calculateLogPrior(BytesRef currentClass) throws IOException {
|
||||||
return (double) docCount(currentClass) / docsWithClassSize;
|
return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
private int docCount(BytesRef countedClass) throws IOException {
|
private int docCount(BytesRef countedClass) throws IOException {
|
||||||
|
|
Loading…
Reference in New Issue