LUCENE-5338 - avoid considering unlabeled documents for training

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1540703 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2013-11-11 13:17:36 +00:00
parent 7315a90a7b
commit debb363f6e
4 changed files with 24 additions and 26 deletions

View File

@@ -23,12 +23,15 @@ import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.StorableField;
import org.apache.lucene.index.StoredDocument;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.fst.Builder;
@@ -160,14 +163,12 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
int batchCount = 0;
Query q;
BooleanQuery q = new BooleanQuery();
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, "*")), BooleanClause.Occur.MUST));
if (query != null) {
q = query;
q.add(new BooleanClause(query, BooleanClause.Occur.MUST));
}
else {
q = new MatchAllDocsQuery();
}
// do a *:* search and use stored field values
// run the search and use stored field values
for (ScoreDoc scoreDoc : indexSearcher.search(q,
Integer.MAX_VALUE).scoreDocs) {
StoredDocument doc = indexSearcher.doc(scoreDoc.doc);

View File

@@ -18,6 +18,7 @@ package org.apache.lucene.classification;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.mlt.MoreLikeThis;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
@@ -25,6 +26,7 @@ import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;
import java.io.IOException;
@@ -64,20 +66,16 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
if (mlt == null) {
throw new IOException("You must first call Classifier#train");
}
Query q;
BooleanQuery mltQuery = new BooleanQuery();
for (String textFieldName : textFieldNames) {
mltQuery.add(new BooleanClause(mlt.like(new StringReader(text), textFieldName), BooleanClause.Occur.SHOULD));
}
Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*"));
mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
if (query != null) {
BooleanQuery bq = new BooleanQuery();
bq.add(query, BooleanClause.Occur.MUST);
bq.add(mltQuery, BooleanClause.Occur.MUST);
q = bq;
} else {
q = mltQuery;
mltQuery.add(query, BooleanClause.Occur.MUST);
}
TopDocs topDocs = indexSearcher.search(q, k);
TopDocs topDocs = indexSearcher.search(mltQuery, k);
return selectClassFromNeighbors(topDocs);
}

View File

@@ -102,15 +102,10 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
if (docCount == -1) { // in case codec doesn't support getDocCount
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
Query q;
BooleanQuery q = new BooleanQuery();
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
if (query != null) {
BooleanQuery bq = new BooleanQuery();
WildcardQuery wq = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING)));
bq.add(wq, BooleanClause.Occur.MUST);
bq.add(query, BooleanClause.Occur.MUST);
q = bq;
} else {
q = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING)));
q.add(query, BooleanClause.Occur.MUST);
}
indexSearcher.search(q,
totalHitCountCollector);
@@ -191,7 +186,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
}
int docsWithC = atomicReader.docFreq(new Term(classFieldName, c));
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
}
private int getWordFreqForClass(String word, BytesRef c) throws IOException {

View File

@@ -94,12 +94,11 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
AtomicReader atomicReader = null;
long trainStart = System.currentTimeMillis();
long trainEnd = 0l;
try {
populatePerformanceIndex(analyzer);
atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
trainEnd = System.currentTimeMillis();
long trainEnd = System.currentTimeMillis();
long trainTime = trainEnd - trainStart;
assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
} finally {
@@ -212,6 +211,11 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
text = "unlabeled doc";
doc.add(new Field(textFieldName, text, ft));
indexWriter.addDocument(doc, analyzer);
indexWriter.commit();
}
}