[LUCENE-4345] - starting incorporating Simon's suggestions: using BytesRef and TotalHitCountCollector

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1384657 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2012-09-14 06:55:12 +00:00
parent cf02188f2b
commit dc3f1d7b3d
2 changed files with 17 additions and 14 deletions

View File

@ -24,6 +24,7 @@ import java.io.IOException;
/** /**
* A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code> * A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>
* @lucene.experimental
*/ */
public interface Classifier { public interface Classifier {

View File

@ -29,6 +29,7 @@ import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import java.io.IOException; import java.io.IOException;
@ -38,6 +39,7 @@ import java.util.LinkedList;
/** /**
* A simplistic Lucene based NaiveBayes classifier, see <code>http://en.wikipedia.org/wiki/Naive_Bayes_classifier</code> * A simplistic Lucene based NaiveBayes classifier, see <code>http://en.wikipedia.org/wiki/Naive_Bayes_classifier</code>
* @lucene.experimental
*/ */
public class SimpleNaiveBayesClassifier implements Classifier { public class SimpleNaiveBayesClassifier implements Classifier {
@ -82,29 +84,27 @@ public class SimpleNaiveBayesClassifier implements Classifier {
if (atomicReader == null) { if (atomicReader == null) {
throw new RuntimeException("need to train the classifier first"); throw new RuntimeException("need to train the classifier first");
} }
Double max = 0d; double max = 0d;
String foundClass = null; String foundClass = null;
Terms terms = MultiFields.getTerms(atomicReader, classFieldName); Terms terms = MultiFields.getTerms(atomicReader, classFieldName);
TermsEnum termsEnum = terms.iterator(null); TermsEnum termsEnum = terms.iterator(null);
BytesRef t = termsEnum.next(); BytesRef next;
while (t != null) { while((next = termsEnum.next()) != null) {
String classValue = t.utf8ToString();
// TODO : turn it to be in log scale // TODO : turn it to be in log scale
Double clVal = calculatePrior(classValue) * calculateLikelihood(inputDocument, classValue); double clVal = calculatePrior(next) * calculateLikelihood(inputDocument, next);
if (clVal > max) { if (clVal > max) {
max = clVal; max = clVal;
foundClass = classValue; foundClass = next.utf8ToString();
} }
t = termsEnum.next();
} }
return foundClass; return foundClass;
} }
private Double calculateLikelihood(String document, String c) throws IOException { private double calculateLikelihood(String document, BytesRef c) throws IOException {
// for each word // for each word
Double result = 1d; double result = 1d;
for (String word : tokenizeDoc(document)) { for (String word : tokenizeDoc(document)) {
// search with text:word AND class:c // search with text:word AND class:c
int hits = getWordFreqForClass(word, c); int hits = getWordFreqForClass(word, c);
@ -124,7 +124,7 @@ public class SimpleNaiveBayesClassifier implements Classifier {
return result; return result;
} }
private double getTextTermFreqForClass(String c) throws IOException { private double getTextTermFreqForClass(BytesRef c) throws IOException {
Terms terms = MultiFields.getTerms(atomicReader, textFieldName); Terms terms = MultiFields.getTerms(atomicReader, textFieldName);
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
double avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc double avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
@ -132,18 +132,20 @@ public class SimpleNaiveBayesClassifier implements Classifier {
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c
} }
private int getWordFreqForClass(String word, String c) throws IOException { private int getWordFreqForClass(String word, BytesRef c) throws IOException {
BooleanQuery booleanQuery = new BooleanQuery(); BooleanQuery booleanQuery = new BooleanQuery();
booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST)); booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST));
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST)); booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
return indexSearcher.search(booleanQuery, 1).totalHits; TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
indexSearcher.search(booleanQuery, totalHitCountCollector);
return totalHitCountCollector.getTotalHits();
} }
private Double calculatePrior(String currentClass) throws IOException { private double calculatePrior(BytesRef currentClass) throws IOException {
return (double) docCount(currentClass) / docsWithClassSize; return (double) docCount(currentClass) / docsWithClassSize;
} }
private int docCount(String countedClass) throws IOException { private int docCount(BytesRef countedClass) throws IOException {
return atomicReader.docFreq(new Term(classFieldName, countedClass)); return atomicReader.docFreq(new Term(classFieldName, countedClass));
} }
} }