mirror of https://github.com/apache/lucene.git
[LUCENE-4345] - starting incorporating Simon's suggestions: using BytesRef and TotalHitCountCollector
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1384657 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
cf02188f2b
commit
dc3f1d7b3d
|
@ -24,6 +24,7 @@ import java.io.IOException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>
|
* A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>
|
||||||
|
* @lucene.experimental
|
||||||
*/
|
*/
|
||||||
public interface Classifier {
|
public interface Classifier {
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,7 @@ import org.apache.lucene.search.BooleanClause;
|
||||||
import org.apache.lucene.search.BooleanQuery;
|
import org.apache.lucene.search.BooleanQuery;
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
import org.apache.lucene.search.TermQuery;
|
import org.apache.lucene.search.TermQuery;
|
||||||
|
import org.apache.lucene.search.TotalHitCountCollector;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -38,6 +39,7 @@ import java.util.LinkedList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A simplistic Lucene based NaiveBayes classifier, see <code>http://en.wikipedia.org/wiki/Naive_Bayes_classifier</code>
|
* A simplistic Lucene based NaiveBayes classifier, see <code>http://en.wikipedia.org/wiki/Naive_Bayes_classifier</code>
|
||||||
|
* @lucene.experimental
|
||||||
*/
|
*/
|
||||||
public class SimpleNaiveBayesClassifier implements Classifier {
|
public class SimpleNaiveBayesClassifier implements Classifier {
|
||||||
|
|
||||||
|
@ -82,29 +84,27 @@ public class SimpleNaiveBayesClassifier implements Classifier {
|
||||||
if (atomicReader == null) {
|
if (atomicReader == null) {
|
||||||
throw new RuntimeException("need to train the classifier first");
|
throw new RuntimeException("need to train the classifier first");
|
||||||
}
|
}
|
||||||
Double max = 0d;
|
double max = 0d;
|
||||||
String foundClass = null;
|
String foundClass = null;
|
||||||
|
|
||||||
Terms terms = MultiFields.getTerms(atomicReader, classFieldName);
|
Terms terms = MultiFields.getTerms(atomicReader, classFieldName);
|
||||||
TermsEnum termsEnum = terms.iterator(null);
|
TermsEnum termsEnum = terms.iterator(null);
|
||||||
BytesRef t = termsEnum.next();
|
BytesRef next;
|
||||||
while (t != null) {
|
while((next = termsEnum.next()) != null) {
|
||||||
String classValue = t.utf8ToString();
|
|
||||||
// TODO : turn it to be in log scale
|
// TODO : turn it to be in log scale
|
||||||
Double clVal = calculatePrior(classValue) * calculateLikelihood(inputDocument, classValue);
|
double clVal = calculatePrior(next) * calculateLikelihood(inputDocument, next);
|
||||||
if (clVal > max) {
|
if (clVal > max) {
|
||||||
max = clVal;
|
max = clVal;
|
||||||
foundClass = classValue;
|
foundClass = next.utf8ToString();
|
||||||
}
|
}
|
||||||
t = termsEnum.next();
|
|
||||||
}
|
}
|
||||||
return foundClass;
|
return foundClass;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private Double calculateLikelihood(String document, String c) throws IOException {
|
private double calculateLikelihood(String document, BytesRef c) throws IOException {
|
||||||
// for each word
|
// for each word
|
||||||
Double result = 1d;
|
double result = 1d;
|
||||||
for (String word : tokenizeDoc(document)) {
|
for (String word : tokenizeDoc(document)) {
|
||||||
// search with text:word AND class:c
|
// search with text:word AND class:c
|
||||||
int hits = getWordFreqForClass(word, c);
|
int hits = getWordFreqForClass(word, c);
|
||||||
|
@ -124,7 +124,7 @@ public class SimpleNaiveBayesClassifier implements Classifier {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
private double getTextTermFreqForClass(String c) throws IOException {
|
private double getTextTermFreqForClass(BytesRef c) throws IOException {
|
||||||
Terms terms = MultiFields.getTerms(atomicReader, textFieldName);
|
Terms terms = MultiFields.getTerms(atomicReader, textFieldName);
|
||||||
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
|
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
|
||||||
double avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
|
double avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
|
||||||
|
@ -132,18 +132,20 @@ public class SimpleNaiveBayesClassifier implements Classifier {
|
||||||
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c
|
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c
|
||||||
}
|
}
|
||||||
|
|
||||||
private int getWordFreqForClass(String word, String c) throws IOException {
|
private int getWordFreqForClass(String word, BytesRef c) throws IOException {
|
||||||
BooleanQuery booleanQuery = new BooleanQuery();
|
BooleanQuery booleanQuery = new BooleanQuery();
|
||||||
booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST));
|
booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST));
|
||||||
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
|
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
|
||||||
return indexSearcher.search(booleanQuery, 1).totalHits;
|
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
|
||||||
|
indexSearcher.search(booleanQuery, totalHitCountCollector);
|
||||||
|
return totalHitCountCollector.getTotalHits();
|
||||||
}
|
}
|
||||||
|
|
||||||
private Double calculatePrior(String currentClass) throws IOException {
|
private double calculatePrior(BytesRef currentClass) throws IOException {
|
||||||
return (double) docCount(currentClass) / docsWithClassSize;
|
return (double) docCount(currentClass) / docsWithClassSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
private int docCount(String countedClass) throws IOException {
|
private int docCount(BytesRef countedClass) throws IOException {
|
||||||
return atomicReader.docFreq(new Term(classFieldName, countedClass));
|
return atomicReader.docFreq(new Term(classFieldName, countedClass));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue