diff --git a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java index 32e94881e3e..d20f722b8f5 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java @@ -16,12 +16,6 @@ */ package org.apache.lucene.classification; -import java.io.IOException; -import java.io.StringReader; -import java.util.Map; -import java.util.SortedMap; -import java.util.TreeMap; - import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; @@ -33,6 +27,7 @@ import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IntsRef; @@ -41,6 +36,11 @@ import org.apache.lucene.util.fst.FST; import org.apache.lucene.util.fst.PositiveIntOutputs; import org.apache.lucene.util.fst.Util; +import java.io.IOException; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + /** * A perceptron (see http://en.wikipedia.org/wiki/Perceptron) based * Boolean {@link org.apache.lucene.classification.Classifier}. The @@ -113,7 +113,16 @@ public class BooleanPerceptronClassifier implements Classifier { */ @Override public void train(AtomicReader atomicReader, String textFieldName, - String classFieldName, Analyzer analyzer) throws IOException { + String classFieldName, Analyzer analyzer) throws IOException { + train(atomicReader, textFieldName, classFieldName, analyzer, null); + } + + /** + * {@inheritDoc} + */ + @Override + public void train(AtomicReader atomicReader, String textFieldName, + String classFieldName, Analyzer analyzer, Query query) throws IOException { this.textTerms = MultiFields.getTerms(atomicReader, textFieldName); if (textTerms == null) { @@ -151,8 +160,15 @@ public class BooleanPerceptronClassifier implements Classifier { int batchCount = 0; + Query q; + if (query != null) { + q = query; + } + else { + q = new MatchAllDocsQuery(); + } // do a *:* search and use stored field values - for (ScoreDoc scoreDoc : indexSearcher.search(new MatchAllDocsQuery(), + for (ScoreDoc scoreDoc : indexSearcher.search(q, Integer.MAX_VALUE).scoreDocs) { StoredDocument doc = indexSearcher.doc(scoreDoc.doc); diff --git a/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java b/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java index 4d0fe2e2b4d..e5d10973559 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java @@ -18,6 +18,7 @@ package org.apache.lucene.classification; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.index.AtomicReader; +import org.apache.lucene.search.Query; import java.io.IOException; @@ -47,4 +48,16 @@ public interface Classifier { public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException; + /** + * Train the classifier using the underlying Lucene index + * @param atomicReader the reader to use to access the Lucene index + * @param textFieldName the name of the field used to compare documents + * @param classFieldName the name of the field containing the class assigned to documents + * @param analyzer the analyzer used to tokenize / filter the unseen text + * @param query the query to filter which documents use for training + * @throws IOException If there is a low-level I/O error. + */ + public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) + throws IOException; + } diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java index bbaa0566d5b..4084c611f72 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java @@ -19,6 +19,8 @@ package org.apache.lucene.classification; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.index.AtomicReader; import org.apache.lucene.queries.mlt.MoreLikeThis; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; @@ -43,6 +45,7 @@ public class KNearestNeighborClassifier implements Classifier { private String classFieldName; private IndexSearcher indexSearcher; private int k; + private Query query; /** * Create a {@link Classifier} using kNN algorithm @@ -61,7 +64,16 @@ public class KNearestNeighborClassifier implements Classifier { if (mlt == null) { throw new IOException("You must first call Classifier#train"); } - Query q = mlt.like(new StringReader(text), textFieldName); + Query q; + if (query != null) { + Query mltQuery = mlt.like(new StringReader(text), textFieldName); + BooleanQuery bq = new BooleanQuery(); + bq.add(query, BooleanClause.Occur.MUST); + bq.add(mltQuery, BooleanClause.Occur.MUST); + q = bq; + } else { + q = mlt.like(new StringReader(text), textFieldName); + } TopDocs topDocs = indexSearcher.search(q, k); return selectClassFromNeighbors(topDocs); } @@ -96,11 +108,20 @@ public class KNearestNeighborClassifier implements Classifier { */ @Override public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException { + train(atomicReader, textFieldName, classFieldName, analyzer, null); + } + + /** + * {@inheritDoc} + */ + @Override + public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException { this.textFieldName = textFieldName; this.classFieldName = classFieldName; mlt = new MoreLikeThis(atomicReader); mlt.setAnalyzer(analyzer); mlt.setFieldNames(new String[]{textFieldName}); indexSearcher = new IndexSearcher(atomicReader); + this.query = query; } } diff --git a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java index 652c599b867..fa6c637a7c9 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java @@ -27,6 +27,7 @@ import org.apache.lucene.index.TermsEnum; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TotalHitCountCollector; import org.apache.lucene.search.WildcardQuery; @@ -49,6 +50,7 @@ public class SimpleNaiveBayesClassifier implements Classifier { private int docsWithClassSize; private Analyzer analyzer; private IndexSearcher indexSearcher; + private Query query; /** * Creates a new NaiveBayes classifier. @@ -62,7 +64,7 @@ public class SimpleNaiveBayesClassifier implements Classifier { * {@inheritDoc} */ @Override - public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) + public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException { this.atomicReader = atomicReader; this.indexSearcher = new IndexSearcher(this.atomicReader); @@ -70,13 +72,29 @@ public class SimpleNaiveBayesClassifier implements Classifier { this.classFieldName = classFieldName; this.analyzer = analyzer; this.docsWithClassSize = countDocsWithClass(); + this.query = query; + } + + @Override + public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException { + train(atomicReader, textFieldName, classFieldName, analyzer, null); } private int countDocsWithClass() throws IOException { int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount(); if (docCount == -1) { // in case codec doesn't support getDocCount TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector(); - indexSearcher.search(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), + Query q; + if (query != null) { + BooleanQuery bq = new BooleanQuery(); + WildcardQuery wq = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))); + bq.add(wq, BooleanClause.Occur.MUST); + bq.add(query, BooleanClause.Occur.MUST); + q = bq; + } else { + q = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))); + } + indexSearcher.search(q, totalHitCountCollector); docCount = totalHitCountCollector.getTotalHits(); } @@ -157,6 +175,9 @@ public class SimpleNaiveBayesClassifier implements Classifier { BooleanQuery booleanQuery = new BooleanQuery(); booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST)); booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST)); + if (query != null) { + booleanQuery.add(query, BooleanClause.Occur.MUST); + } TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector(); indexSearcher.search(booleanQuery, totalHitCountCollector); return totalHitCountCollector.getTotalHits(); diff --git a/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java index c6b7b10b543..0ec84c9f147 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java @@ -17,6 +17,8 @@ package org.apache.lucene.classification; import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.TermQuery; import org.junit.Test; /** @@ -34,6 +36,11 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase extends LuceneTestCase { dir.close(); } - protected void checkCorrectClassification(Classifier classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception { + checkCorrectClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null); + } + + protected void checkCorrectClassification(Classifier classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception { AtomicReader atomicReader = null; try { populateSampleIndex(analyzer); atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); - classifier.train(atomicReader, textFieldName, classFieldName, analyzer); + classifier.train(atomicReader, textFieldName, classFieldName, analyzer, query); ClassificationResult classificationResult = classifier.assignClass(inputDoc); assertNotNull(classificationResult.getAssignedClass()); assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass()); diff --git a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java index 664750a0f9b..7e754adb560 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java @@ -17,6 +17,8 @@ package org.apache.lucene.classification; import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.TermQuery; import org.apache.lucene.util.BytesRef; import org.junit.Test; @@ -30,6 +32,11 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase