mirror of https://github.com/apache/lucene.git
LUCENE-5338 - avoid considering unlabeled documents for training
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1540703 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
7315a90a7b
commit
debb363f6e
|
@ -23,12 +23,15 @@ import org.apache.lucene.index.AtomicReader;
|
|||
import org.apache.lucene.index.MultiFields;
|
||||
import org.apache.lucene.index.StorableField;
|
||||
import org.apache.lucene.index.StoredDocument;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.Terms;
|
||||
import org.apache.lucene.index.TermsEnum;
|
||||
import org.apache.lucene.search.BooleanClause;
|
||||
import org.apache.lucene.search.BooleanQuery;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.WildcardQuery;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.IntsRef;
|
||||
import org.apache.lucene.util.fst.Builder;
|
||||
|
@ -160,14 +163,12 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
|
||||
int batchCount = 0;
|
||||
|
||||
Query q;
|
||||
BooleanQuery q = new BooleanQuery();
|
||||
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, "*")), BooleanClause.Occur.MUST));
|
||||
if (query != null) {
|
||||
q = query;
|
||||
q.add(new BooleanClause(query, BooleanClause.Occur.MUST));
|
||||
}
|
||||
else {
|
||||
q = new MatchAllDocsQuery();
|
||||
}
|
||||
// do a *:* search and use stored field values
|
||||
// run the search and use stored field values
|
||||
for (ScoreDoc scoreDoc : indexSearcher.search(q,
|
||||
Integer.MAX_VALUE).scoreDocs) {
|
||||
StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.apache.lucene.classification;
|
|||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.index.AtomicReader;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.queries.mlt.MoreLikeThis;
|
||||
import org.apache.lucene.search.BooleanClause;
|
||||
import org.apache.lucene.search.BooleanQuery;
|
||||
|
@ -25,6 +26,7 @@ import org.apache.lucene.search.IndexSearcher;
|
|||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.WildcardQuery;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -64,20 +66,16 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
if (mlt == null) {
|
||||
throw new IOException("You must first call Classifier#train");
|
||||
}
|
||||
Query q;
|
||||
BooleanQuery mltQuery = new BooleanQuery();
|
||||
for (String textFieldName : textFieldNames) {
|
||||
mltQuery.add(new BooleanClause(mlt.like(new StringReader(text), textFieldName), BooleanClause.Occur.SHOULD));
|
||||
}
|
||||
Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*"));
|
||||
mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
|
||||
if (query != null) {
|
||||
BooleanQuery bq = new BooleanQuery();
|
||||
bq.add(query, BooleanClause.Occur.MUST);
|
||||
bq.add(mltQuery, BooleanClause.Occur.MUST);
|
||||
q = bq;
|
||||
} else {
|
||||
q = mltQuery;
|
||||
mltQuery.add(query, BooleanClause.Occur.MUST);
|
||||
}
|
||||
TopDocs topDocs = indexSearcher.search(q, k);
|
||||
TopDocs topDocs = indexSearcher.search(mltQuery, k);
|
||||
return selectClassFromNeighbors(topDocs);
|
||||
}
|
||||
|
||||
|
|
|
@ -102,15 +102,10 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
|
||||
if (docCount == -1) { // in case codec doesn't support getDocCount
|
||||
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
|
||||
Query q;
|
||||
BooleanQuery q = new BooleanQuery();
|
||||
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
|
||||
if (query != null) {
|
||||
BooleanQuery bq = new BooleanQuery();
|
||||
WildcardQuery wq = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING)));
|
||||
bq.add(wq, BooleanClause.Occur.MUST);
|
||||
bq.add(query, BooleanClause.Occur.MUST);
|
||||
q = bq;
|
||||
} else {
|
||||
q = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING)));
|
||||
q.add(query, BooleanClause.Occur.MUST);
|
||||
}
|
||||
indexSearcher.search(q,
|
||||
totalHitCountCollector);
|
||||
|
@ -191,7 +186,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
|
||||
}
|
||||
int docsWithC = atomicReader.docFreq(new Term(classFieldName, c));
|
||||
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c
|
||||
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
|
||||
}
|
||||
|
||||
private int getWordFreqForClass(String word, BytesRef c) throws IOException {
|
||||
|
|
|
@ -94,12 +94,11 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
|||
protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
|
||||
AtomicReader atomicReader = null;
|
||||
long trainStart = System.currentTimeMillis();
|
||||
long trainEnd = 0l;
|
||||
try {
|
||||
populatePerformanceIndex(analyzer);
|
||||
atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
|
||||
classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
|
||||
trainEnd = System.currentTimeMillis();
|
||||
long trainEnd = System.currentTimeMillis();
|
||||
long trainTime = trainEnd - trainStart;
|
||||
assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
|
||||
} finally {
|
||||
|
@ -212,6 +211,11 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
|||
doc.add(new Field(booleanFieldName, "false", ft));
|
||||
indexWriter.addDocument(doc, analyzer);
|
||||
|
||||
doc = new Document();
|
||||
text = "unlabeled doc";
|
||||
doc.add(new Field(textFieldName, text, ft));
|
||||
indexWriter.addDocument(doc, analyzer);
|
||||
|
||||
indexWriter.commit();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue