diff --git a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
index 32e94881e3e..d20f722b8f5 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
@@ -16,12 +16,6 @@
*/
package org.apache.lucene.classification;
-import java.io.IOException;
-import java.io.StringReader;
-import java.util.Map;
-import java.util.SortedMap;
-import java.util.TreeMap;
-
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
@@ -33,6 +27,7 @@ import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IntsRef;
@@ -41,6 +36,11 @@ import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.PositiveIntOutputs;
import org.apache.lucene.util.fst.Util;
+import java.io.IOException;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
/**
* A perceptron (see http://en.wikipedia.org/wiki/Perceptron
) based
* Boolean
{@link org.apache.lucene.classification.Classifier}. The
@@ -113,7 +113,16 @@ public class BooleanPerceptronClassifier implements Classifier {
*/
@Override
public void train(AtomicReader atomicReader, String textFieldName,
- String classFieldName, Analyzer analyzer) throws IOException {
+ String classFieldName, Analyzer analyzer) throws IOException {
+ train(atomicReader, textFieldName, classFieldName, analyzer, null); // null query: the five-argument overload falls back to MatchAllDocsQuery
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public void train(AtomicReader atomicReader, String textFieldName,
+ String classFieldName, Analyzer analyzer, Query query) throws IOException {
this.textTerms = MultiFields.getTerms(atomicReader, textFieldName);
if (textTerms == null) {
@@ -151,8 +160,15 @@ public class BooleanPerceptronClassifier implements Classifier {
int batchCount = 0;
+ Query q;
+ if (query != null) {
+ q = query;
+ }
+ else {
+ q = new MatchAllDocsQuery();
+ }
// do a *:* search and use stored field values
- for (ScoreDoc scoreDoc : indexSearcher.search(new MatchAllDocsQuery(),
+ for (ScoreDoc scoreDoc : indexSearcher.search(q,
Integer.MAX_VALUE).scoreDocs) {
StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java b/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
index 4d0fe2e2b4d..e5d10973559 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
@@ -18,6 +18,7 @@ package org.apache.lucene.classification;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.AtomicReader;
+import org.apache.lucene.search.Query;
import java.io.IOException;
@@ -47,4 +48,16 @@ public interface Classifier {
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
throws IOException;
+ /**
+ * Train the classifier using the underlying Lucene index
+ * @param atomicReader the reader to use to access the Lucene index
+ * @param textFieldName the name of the field used to compare documents
+ * @param classFieldName the name of the field containing the class assigned to documents
+ * @param analyzer the analyzer used to tokenize / filter the unseen text
+ * @param query the query used to select which documents are used for training
+ * @throws IOException If there is a low-level I/O error.
+ */
+ public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
+ throws IOException;
+
}
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
index bbaa0566d5b..4084c611f72 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
@@ -19,6 +19,8 @@ package org.apache.lucene.classification;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.queries.mlt.MoreLikeThis;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
@@ -43,6 +45,7 @@ public class KNearestNeighborClassifier implements Classifier {
private String classFieldName;
private IndexSearcher indexSearcher;
private int k;
+ private Query query;
/**
* Create a {@link Classifier} using kNN algorithm
@@ -61,7 +64,16 @@ public class KNearestNeighborClassifier implements Classifier {
if (mlt == null) {
throw new IOException("You must first call Classifier#train");
}
- Query q = mlt.like(new StringReader(text), textFieldName);
+ Query q;
+ if (query != null) {
+ Query mltQuery = mlt.like(new StringReader(text), textFieldName);
+ BooleanQuery bq = new BooleanQuery();
+ bq.add(query, BooleanClause.Occur.MUST);
+ bq.add(mltQuery, BooleanClause.Occur.MUST);
+ q = bq;
+ } else {
+ q = mlt.like(new StringReader(text), textFieldName);
+ }
TopDocs topDocs = indexSearcher.search(q, k);
return selectClassFromNeighbors(topDocs);
}
@@ -96,11 +108,20 @@ public class KNearestNeighborClassifier implements Classifier {
*/
@Override
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
+ train(atomicReader, textFieldName, classFieldName, analyzer, null); // null query: no filter clause is combined with the MoreLikeThis query
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
this.textFieldName = textFieldName;
this.classFieldName = classFieldName;
mlt = new MoreLikeThis(atomicReader);
mlt.setAnalyzer(analyzer);
mlt.setFieldNames(new String[]{textFieldName});
indexSearcher = new IndexSearcher(atomicReader);
+ this.query = query;
}
}
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
index 652c599b867..fa6c637a7c9 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
@@ -27,6 +27,7 @@ import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.WildcardQuery;
@@ -49,6 +50,7 @@ public class SimpleNaiveBayesClassifier implements Classifier {
private int docsWithClassSize;
private Analyzer analyzer;
private IndexSearcher indexSearcher;
+ private Query query;
/**
* Creates a new NaiveBayes classifier.
@@ -62,7 +64,7 @@ public class SimpleNaiveBayesClassifier implements Classifier {
* {@inheritDoc}
*/
@Override
- public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
+ public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
throws IOException {
this.atomicReader = atomicReader;
this.indexSearcher = new IndexSearcher(this.atomicReader);
@@ -70,13 +72,29 @@ public class SimpleNaiveBayesClassifier implements Classifier {
this.classFieldName = classFieldName;
this.analyzer = analyzer;
+ this.query = query; // must be assigned before countDocsWithClass(), whose fallback search reads this.query
this.docsWithClassSize = countDocsWithClass();
+ }
+
+ @Override
+ public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
+ train(atomicReader, textFieldName, classFieldName, analyzer, null); // null query: no extra filter clause is added to the counting queries
}
private int countDocsWithClass() throws IOException {
int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
if (docCount == -1) { // in case codec doesn't support getDocCount
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
- indexSearcher.search(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))),
+ Query q;
+ if (query != null) {
+ BooleanQuery bq = new BooleanQuery();
+ WildcardQuery wq = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING)));
+ bq.add(wq, BooleanClause.Occur.MUST);
+ bq.add(query, BooleanClause.Occur.MUST);
+ q = bq;
+ } else {
+ q = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING)));
+ }
+ indexSearcher.search(q,
totalHitCountCollector);
docCount = totalHitCountCollector.getTotalHits();
}
@@ -157,6 +175,9 @@ public class SimpleNaiveBayesClassifier implements Classifier {
BooleanQuery booleanQuery = new BooleanQuery();
booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST));
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
+ if (query != null) {
+ booleanQuery.add(query, BooleanClause.Occur.MUST);
+ }
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
indexSearcher.search(booleanQuery, totalHitCountCollector);
return totalHitCountCollector.getTotalHits();
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
index c6b7b10b543..0ec84c9f147 100644
--- a/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
+++ b/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
@@ -17,6 +17,8 @@
package org.apache.lucene.classification;
import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.TermQuery;
import org.junit.Test;
/**
@@ -34,6 +36,11 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase extends LuceneTestCase {
dir.close();
}
-
protected void checkCorrectClassification(Classifier classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
+ checkCorrectClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
+ }
+
+ protected void checkCorrectClassification(Classifier classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
AtomicReader atomicReader = null;
try {
populateSampleIndex(analyzer);
atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
- classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
+ classifier.train(atomicReader, textFieldName, classFieldName, analyzer, query);
ClassificationResult classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
index 664750a0f9b..7e754adb560 100644
--- a/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
+++ b/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
@@ -17,6 +17,8 @@
package org.apache.lucene.classification;
import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
import org.junit.Test;
@@ -30,6 +32,11 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase