LUCENE-5284 - added method for training with a Query filter

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1532983 13f79535-47bb-0310-9956-ffa450edef68
Tommaso Teofili 2013-10-17 07:02:59 +00:00
parent 9cda012557
commit e60f7af1d1
8 changed files with 109 additions and 13 deletions
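
In short, each Classifier implementation gains a train(...) overload that takes a Query, so training can be restricted to the subset of indexed documents matching that query; the existing four-argument method keeps its behaviour by delegating with a null query. A minimal usage sketch (not part of the commit), assuming an already-open AtomicReader and hypothetical "lang", "text" and "category" fields:

    import java.io.IOException;

    import org.apache.lucene.analysis.Analyzer;
    import org.apache.lucene.classification.ClassificationResult;
    import org.apache.lucene.classification.Classifier;
    import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
    import org.apache.lucene.index.AtomicReader;
    import org.apache.lucene.index.Term;
    import org.apache.lucene.search.Query;
    import org.apache.lucene.search.TermQuery;
    import org.apache.lucene.util.BytesRef;

    public class FilteredTrainingExample {

      // Train a classifier on the documents matching a filter query, then classify some unseen text.
      public static ClassificationResult<BytesRef> classify(AtomicReader reader, Analyzer analyzer,
          String unseenText) throws IOException {
        // hypothetical filter: only train on documents whose "lang" field is "en"
        Query trainingFilter = new TermQuery(new Term("lang", "en"));

        Classifier<BytesRef> classifier = new SimpleNaiveBayesClassifier();
        // the five-argument overload added by this commit; passing null trains on the whole index
        classifier.train(reader, "text", "category", analyzer, trainingFilter);
        return classifier.assignClass(unseenText);
      }
    }

The tests added below exercise the same path with a TermQuery filter on the text field.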

BooleanPerceptronClassifier.java

@@ -16,12 +16,6 @@
  */
 package org.apache.lucene.classification;
-import java.io.IOException;
-import java.io.StringReader;
-import java.util.Map;
-import java.util.SortedMap;
-import java.util.TreeMap;
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.analysis.TokenStream;
 import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
@@ -33,6 +27,7 @@ import org.apache.lucene.index.Terms;
 import org.apache.lucene.index.TermsEnum;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.IntsRef;
@@ -41,6 +36,11 @@ import org.apache.lucene.util.fst.FST;
 import org.apache.lucene.util.fst.PositiveIntOutputs;
 import org.apache.lucene.util.fst.Util;
+import java.io.IOException;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
 /**
  * A perceptron (see <code>http://en.wikipedia.org/wiki/Perceptron</code>) based
  * <code>Boolean</code> {@link org.apache.lucene.classification.Classifier}. The
@@ -113,7 +113,16 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
    */
   @Override
   public void train(AtomicReader atomicReader, String textFieldName,
       String classFieldName, Analyzer analyzer) throws IOException {
+    train(atomicReader, textFieldName, classFieldName, analyzer, null);
+  }
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public void train(AtomicReader atomicReader, String textFieldName,
+      String classFieldName, Analyzer analyzer, Query query) throws IOException {
     this.textTerms = MultiFields.getTerms(atomicReader, textFieldName);
     if (textTerms == null) {
@@ -151,8 +160,15 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
     int batchCount = 0;
+    Query q;
+    if (query != null) {
+      q = query;
+    }
+    else {
+      q = new MatchAllDocsQuery();
+    }
     // do a *:* search and use stored field values
-    for (ScoreDoc scoreDoc : indexSearcher.search(new MatchAllDocsQuery(),
+    for (ScoreDoc scoreDoc : indexSearcher.search(q,
         Integer.MAX_VALUE).scoreDocs) {
       StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
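
For the perceptron, a null query falls back to MatchAllDocsQuery, so the old four-argument train keeps its behaviour; a non-null query simply replaces the *:* search that gathers the stored field values used for training. A hedged usage sketch for this classifier (field names and the filter term are illustrative, mirroring the new test below):

    import java.io.IOException;

    import org.apache.lucene.analysis.Analyzer;
    import org.apache.lucene.classification.BooleanPerceptronClassifier;
    import org.apache.lucene.index.AtomicReader;
    import org.apache.lucene.index.Term;
    import org.apache.lucene.search.Query;
    import org.apache.lucene.search.TermQuery;

    public class PerceptronFilterExample {

      // Train the boolean perceptron only on documents containing "it" in the text field,
      // then classify one unseen document.
      public static Boolean classify(AtomicReader reader, Analyzer analyzer, String unseenText)
          throws IOException {
        Query filter = new TermQuery(new Term("text", "it"));
        BooleanPerceptronClassifier perceptron = new BooleanPerceptronClassifier();
        perceptron.train(reader, "text", "bool", analyzer, filter);
        return perceptron.assignClass(unseenText).getAssignedClass();
      }
    }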

Classifier.java

@@ -18,6 +18,7 @@ package org.apache.lucene.classification;
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.index.AtomicReader;
+import org.apache.lucene.search.Query;
 import java.io.IOException;
@@ -47,4 +48,16 @@
   public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
       throws IOException;
+  /**
+   * Train the classifier using the underlying Lucene index
+   * @param atomicReader the reader to use to access the Lucene index
+   * @param textFieldName the name of the field used to compare documents
+   * @param classFieldName the name of the field containing the class assigned to documents
+   * @param analyzer the analyzer used to tokenize / filter the unseen text
+   * @param query the query used to filter which documents to use for training
+   * @throws IOException If there is a low-level I/O error.
+   */
+  public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
+      throws IOException;
 }
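
The contract implied here (and followed by all the implementations in this commit) is that the original four-argument train delegates to the new overload with a null query, meaning "no filtering". A sketch of that wiring for a hypothetical implementation:

    import java.io.IOException;

    import org.apache.lucene.analysis.Analyzer;
    import org.apache.lucene.classification.Classifier;
    import org.apache.lucene.index.AtomicReader;
    import org.apache.lucene.search.Query;
    import org.apache.lucene.util.BytesRef;

    // hypothetical skeleton showing how an implementation can wire the two train methods together
    public abstract class FilterAwareClassifier implements Classifier<BytesRef> {

      protected Query query; // null means "train on every document"

      @Override
      public void train(AtomicReader atomicReader, String textFieldName, String classFieldName,
          Analyzer analyzer) throws IOException {
        // keep the old entry point working by delegating with a null filter
        train(atomicReader, textFieldName, classFieldName, analyzer, null);
      }

      @Override
      public void train(AtomicReader atomicReader, String textFieldName, String classFieldName,
          Analyzer analyzer, Query query) throws IOException {
        this.query = query;
        // concrete subclasses would build their model from the filtered documents here
      }
    }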

KNearestNeighborClassifier.java

@@ -19,6 +19,8 @@ package org.apache.lucene.classification;
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.index.AtomicReader;
 import org.apache.lucene.queries.mlt.MoreLikeThis;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreDoc;
@@ -43,6 +45,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
   private String classFieldName;
   private IndexSearcher indexSearcher;
   private int k;
+  private Query query;
   /**
    * Create a {@link Classifier} using kNN algorithm
@@ -61,7 +64,16 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
     if (mlt == null) {
       throw new IOException("You must first call Classifier#train");
     }
-    Query q = mlt.like(new StringReader(text), textFieldName);
+    Query q;
+    if (query != null) {
+      Query mltQuery = mlt.like(new StringReader(text), textFieldName);
+      BooleanQuery bq = new BooleanQuery();
+      bq.add(query, BooleanClause.Occur.MUST);
+      bq.add(mltQuery, BooleanClause.Occur.MUST);
+      q = bq;
+    } else {
+      q = mlt.like(new StringReader(text), textFieldName);
+    }
     TopDocs topDocs = indexSearcher.search(q, k);
     return selectClassFromNeighbors(topDocs);
   }
@@ -96,11 +108,20 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
    */
   @Override
   public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
+    train(atomicReader, textFieldName, classFieldName, analyzer, null);
+  }
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
     this.textFieldName = textFieldName;
     this.classFieldName = classFieldName;
     mlt = new MoreLikeThis(atomicReader);
     mlt.setAnalyzer(analyzer);
     mlt.setFieldNames(new String[]{textFieldName});
     indexSearcher = new IndexSearcher(atomicReader);
+    this.query = query;
   }
 }
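
For kNN the filter is applied at classification time: assignClass combines the MoreLikeThis query with the query stored by train in a BooleanQuery with two MUST clauses, so only documents matching the filter are eligible neighbors. The same intersection pattern, pulled out into a hypothetical helper for illustration (Lucene 4.x BooleanQuery API):

    import org.apache.lucene.search.BooleanClause;
    import org.apache.lucene.search.BooleanQuery;
    import org.apache.lucene.search.Query;

    public final class QueryCombiner {

      private QueryCombiner() {}

      // Intersect a similarity query with an optional filter; a null filter leaves the query untouched.
      public static Query withFilter(Query similarityQuery, Query filter) {
        if (filter == null) {
          return similarityQuery;
        }
        BooleanQuery combined = new BooleanQuery();
        combined.add(similarityQuery, BooleanClause.Occur.MUST); // must be similar to the input text
        combined.add(filter, BooleanClause.Occur.MUST);          // and must also match the filter
        return combined;
      }
    }

Because both clauses are MUST, a filter that matches few or no documents leaves few or no candidate neighbors, so the filter should be chosen to keep enough training documents for k neighbors to exist.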

SimpleNaiveBayesClassifier.java

@@ -27,6 +27,7 @@ import org.apache.lucene.index.TermsEnum;
 import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
 import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.search.TotalHitCountCollector;
 import org.apache.lucene.search.WildcardQuery;
@@ -49,6 +50,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
   private int docsWithClassSize;
   private Analyzer analyzer;
   private IndexSearcher indexSearcher;
+  private Query query;
   /**
    * Creates a new NaiveBayes classifier.
@@ -62,7 +64,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
    * {@inheritDoc}
    */
   @Override
-  public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
+  public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
       throws IOException {
     this.atomicReader = atomicReader;
     this.indexSearcher = new IndexSearcher(this.atomicReader);
@@ -70,13 +72,29 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
     this.classFieldName = classFieldName;
     this.analyzer = analyzer;
     this.docsWithClassSize = countDocsWithClass();
+    this.query = query;
+  }
+  @Override
+  public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
+    train(atomicReader, textFieldName, classFieldName, analyzer, null);
   }
   private int countDocsWithClass() throws IOException {
     int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
     if (docCount == -1) { // in case codec doesn't support getDocCount
       TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
-      indexSearcher.search(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))),
+      Query q;
+      if (query != null) {
+        BooleanQuery bq = new BooleanQuery();
+        WildcardQuery wq = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING)));
+        bq.add(wq, BooleanClause.Occur.MUST);
+        bq.add(query, BooleanClause.Occur.MUST);
+        q = bq;
+      } else {
+        q = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING)));
+      }
+      indexSearcher.search(q,
          totalHitCountCollector);
       docCount = totalHitCountCollector.getTotalHits();
     }
@@ -157,6 +175,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
     BooleanQuery booleanQuery = new BooleanQuery();
     booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST));
     booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
+    if (query != null) {
+      booleanQuery.add(query, BooleanClause.Occur.MUST);
+    }
     TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
     indexSearcher.search(booleanQuery, totalHitCountCollector);
     return totalHitCountCollector.getTotalHits();
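
The naive Bayes changes thread the filter through the counting queries: countDocsWithClass() intersects the wildcard query on the class field with the filter when the codec cannot report getDocCount, and the per-word document counts gain the filter as an extra MUST clause. A hedged sketch of that hit-counting pattern, using TotalHitCountCollector as the patch does (method and variable names are illustrative):

    import java.io.IOException;

    import org.apache.lucene.index.Term;
    import org.apache.lucene.search.BooleanClause;
    import org.apache.lucene.search.BooleanQuery;
    import org.apache.lucene.search.IndexSearcher;
    import org.apache.lucene.search.Query;
    import org.apache.lucene.search.TermQuery;
    import org.apache.lucene.search.TotalHitCountCollector;

    public final class FilteredCounts {

      private FilteredCounts() {}

      // Count documents that contain `word` in the text field and carry class `clazz`,
      // optionally restricted by an extra filter query (null means no restriction).
      public static int countWordInClass(IndexSearcher searcher, String textFieldName,
          String classFieldName, String word, String clazz, Query filter) throws IOException {
        BooleanQuery booleanQuery = new BooleanQuery();
        booleanQuery.add(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST);
        booleanQuery.add(new TermQuery(new Term(classFieldName, clazz)), BooleanClause.Occur.MUST);
        if (filter != null) {
          booleanQuery.add(filter, BooleanClause.Occur.MUST);
        }
        TotalHitCountCollector collector = new TotalHitCountCollector();
        searcher.search(booleanQuery, collector);
        return collector.getTotalHits();
      }
    }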

BooleanPerceptronClassifierTest.java

@@ -17,6 +17,8 @@
 package org.apache.lucene.classification;
 import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.TermQuery;
 import org.junit.Test;
 /**
@@ -34,6 +36,11 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Boolean> {
     checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
   }
+  @Test
+  public void testBasicUsageWithQuery() throws Exception {
+    checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName, new TermQuery(new Term(textFieldName, "it")));
+  }
   @Test
   public void testPerformance() throws Exception {
     checkPerformance(new BooleanPerceptronClassifier(), new MockAnalyzer(random()), booleanFieldName);

ClassificationTestBase.java

@@ -24,6 +24,7 @@ import org.apache.lucene.document.TextField;
 import org.apache.lucene.index.AtomicReader;
 import org.apache.lucene.index.RandomIndexWriter;
 import org.apache.lucene.index.SlowCompositeReaderWrapper;
+import org.apache.lucene.search.Query;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.LuceneTestCase;
@@ -70,13 +71,16 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
     dir.close();
   }
   protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
+    checkCorrectClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
+  }
+  protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
     AtomicReader atomicReader = null;
     try {
       populateSampleIndex(analyzer);
       atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
-      classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
+      classifier.train(atomicReader, textFieldName, classFieldName, analyzer, query);
       ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
       assertNotNull(classificationResult.getAssignedClass());
       assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());

KNearestNeighborClassifierTest.java

@@ -17,6 +17,8 @@
 package org.apache.lucene.classification;
 import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.util.BytesRef;
 import org.junit.Test;
@@ -30,6 +32,11 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<BytesRef> {
     checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
   }
+  @Test
+  public void testBasicUsageWithQuery() throws Exception {
+    checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
+  }
   @Test
   public void testPerformance() throws Exception {
     checkPerformance(new KNearestNeighborClassifier(100), new MockAnalyzer(random()), categoryFieldName);

SimpleNaiveBayesClassifierTest.java

@@ -22,6 +22,8 @@ import org.apache.lucene.analysis.Tokenizer;
 import org.apache.lucene.analysis.core.KeywordTokenizer;
 import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
 import org.apache.lucene.analysis.reverse.ReverseStringFilter;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.LuceneTestCase;
 import org.junit.Test;
@@ -41,6 +43,11 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<BytesRef> {
     checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
   }
+  @Test
+  public void testBasicUsageWithQuery() throws Exception {
+    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
+  }
   @Test
   public void testNGramUsage() throws Exception {
     checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName);