mirror of https://github.com/apache/lucene.git
LUCENE-5284 - added method for training with a Query filter
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1532983 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
9cda012557
commit
e60f7af1d1
|
@ -16,12 +16,6 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.lucene.classification;
|
package org.apache.lucene.classification;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.StringReader;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.SortedMap;
|
|
||||||
import java.util.TreeMap;
|
|
||||||
|
|
||||||
import org.apache.lucene.analysis.Analyzer;
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
import org.apache.lucene.analysis.TokenStream;
|
import org.apache.lucene.analysis.TokenStream;
|
||||||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||||
|
@ -33,6 +27,7 @@ import org.apache.lucene.index.Terms;
|
||||||
import org.apache.lucene.index.TermsEnum;
|
import org.apache.lucene.index.TermsEnum;
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.IntsRef;
|
import org.apache.lucene.util.IntsRef;
|
||||||
|
@ -41,6 +36,11 @@ import org.apache.lucene.util.fst.FST;
|
||||||
import org.apache.lucene.util.fst.PositiveIntOutputs;
|
import org.apache.lucene.util.fst.PositiveIntOutputs;
|
||||||
import org.apache.lucene.util.fst.Util;
|
import org.apache.lucene.util.fst.Util;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.SortedMap;
|
||||||
|
import java.util.TreeMap;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A perceptron (see <code>http://en.wikipedia.org/wiki/Perceptron</code>) based
|
* A perceptron (see <code>http://en.wikipedia.org/wiki/Perceptron</code>) based
|
||||||
* <code>Boolean</code> {@link org.apache.lucene.classification.Classifier}. The
|
* <code>Boolean</code> {@link org.apache.lucene.classification.Classifier}. The
|
||||||
|
@ -114,6 +114,15 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
@Override
|
@Override
|
||||||
public void train(AtomicReader atomicReader, String textFieldName,
|
public void train(AtomicReader atomicReader, String textFieldName,
|
||||||
String classFieldName, Analyzer analyzer) throws IOException {
|
String classFieldName, Analyzer analyzer) throws IOException {
|
||||||
|
train(atomicReader, textFieldName, classFieldName, analyzer, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void train(AtomicReader atomicReader, String textFieldName,
|
||||||
|
String classFieldName, Analyzer analyzer, Query query) throws IOException {
|
||||||
this.textTerms = MultiFields.getTerms(atomicReader, textFieldName);
|
this.textTerms = MultiFields.getTerms(atomicReader, textFieldName);
|
||||||
|
|
||||||
if (textTerms == null) {
|
if (textTerms == null) {
|
||||||
|
@ -151,8 +160,15 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
|
|
||||||
int batchCount = 0;
|
int batchCount = 0;
|
||||||
|
|
||||||
|
Query q;
|
||||||
|
if (query != null) {
|
||||||
|
q = query;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
q = new MatchAllDocsQuery();
|
||||||
|
}
|
||||||
// do a *:* search and use stored field values
|
// search with the supplied query (or *:* when none was given) and use stored field values
|
||||||
for (ScoreDoc scoreDoc : indexSearcher.search(new MatchAllDocsQuery(),
|
for (ScoreDoc scoreDoc : indexSearcher.search(q,
|
||||||
Integer.MAX_VALUE).scoreDocs) {
|
Integer.MAX_VALUE).scoreDocs) {
|
||||||
StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
|
StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.apache.lucene.classification;
|
||||||
|
|
||||||
import org.apache.lucene.analysis.Analyzer;
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
import org.apache.lucene.index.AtomicReader;
|
import org.apache.lucene.index.AtomicReader;
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
|
@ -47,4 +48,16 @@ public interface Classifier<T> {
|
||||||
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
|
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
|
||||||
throws IOException;
|
throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Train the classifier using the underlying Lucene index
|
||||||
|
* @param atomicReader the reader to use to access the Lucene index
|
||||||
|
* @param textFieldName the name of the field used to compare documents
|
||||||
|
* @param classFieldName the name of the field containing the class assigned to documents
|
||||||
|
* @param analyzer the analyzer used to tokenize / filter the unseen text
|
||||||
|
* @param query the query used to filter which documents are used for training
|
||||||
|
* @throws IOException If there is a low-level I/O error.
|
||||||
|
*/
|
||||||
|
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
|
||||||
|
throws IOException;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,8 @@ package org.apache.lucene.classification;
|
||||||
import org.apache.lucene.analysis.Analyzer;
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
import org.apache.lucene.index.AtomicReader;
|
import org.apache.lucene.index.AtomicReader;
|
||||||
import org.apache.lucene.queries.mlt.MoreLikeThis;
|
import org.apache.lucene.queries.mlt.MoreLikeThis;
|
||||||
|
import org.apache.lucene.search.BooleanClause;
|
||||||
|
import org.apache.lucene.search.BooleanQuery;
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
import org.apache.lucene.search.Query;
|
import org.apache.lucene.search.Query;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
|
@ -43,6 +45,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
||||||
private String classFieldName;
|
private String classFieldName;
|
||||||
private IndexSearcher indexSearcher;
|
private IndexSearcher indexSearcher;
|
||||||
private int k;
|
private int k;
|
||||||
|
private Query query;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a {@link Classifier} using kNN algorithm
|
* Create a {@link Classifier} using kNN algorithm
|
||||||
|
@ -61,7 +64,16 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
||||||
if (mlt == null) {
|
if (mlt == null) {
|
||||||
throw new IOException("You must first call Classifier#train");
|
throw new IOException("You must first call Classifier#train");
|
||||||
}
|
}
|
||||||
Query q = mlt.like(new StringReader(text), textFieldName);
|
Query q;
|
||||||
|
if (query != null) {
|
||||||
|
Query mltQuery = mlt.like(new StringReader(text), textFieldName);
|
||||||
|
BooleanQuery bq = new BooleanQuery();
|
||||||
|
bq.add(query, BooleanClause.Occur.MUST);
|
||||||
|
bq.add(mltQuery, BooleanClause.Occur.MUST);
|
||||||
|
q = bq;
|
||||||
|
} else {
|
||||||
|
q = mlt.like(new StringReader(text), textFieldName);
|
||||||
|
}
|
||||||
TopDocs topDocs = indexSearcher.search(q, k);
|
TopDocs topDocs = indexSearcher.search(q, k);
|
||||||
return selectClassFromNeighbors(topDocs);
|
return selectClassFromNeighbors(topDocs);
|
||||||
}
|
}
|
||||||
|
@ -96,11 +108,20 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
|
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
|
||||||
|
train(atomicReader, textFieldName, classFieldName, analyzer, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
|
||||||
this.textFieldName = textFieldName;
|
this.textFieldName = textFieldName;
|
||||||
this.classFieldName = classFieldName;
|
this.classFieldName = classFieldName;
|
||||||
mlt = new MoreLikeThis(atomicReader);
|
mlt = new MoreLikeThis(atomicReader);
|
||||||
mlt.setAnalyzer(analyzer);
|
mlt.setAnalyzer(analyzer);
|
||||||
mlt.setFieldNames(new String[]{textFieldName});
|
mlt.setFieldNames(new String[]{textFieldName});
|
||||||
indexSearcher = new IndexSearcher(atomicReader);
|
indexSearcher = new IndexSearcher(atomicReader);
|
||||||
|
this.query = query;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,7 @@ import org.apache.lucene.index.TermsEnum;
|
||||||
import org.apache.lucene.search.BooleanClause;
|
import org.apache.lucene.search.BooleanClause;
|
||||||
import org.apache.lucene.search.BooleanQuery;
|
import org.apache.lucene.search.BooleanQuery;
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
import org.apache.lucene.search.TermQuery;
|
import org.apache.lucene.search.TermQuery;
|
||||||
import org.apache.lucene.search.TotalHitCountCollector;
|
import org.apache.lucene.search.TotalHitCountCollector;
|
||||||
import org.apache.lucene.search.WildcardQuery;
|
import org.apache.lucene.search.WildcardQuery;
|
||||||
|
@ -49,6 +50,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
private int docsWithClassSize;
|
private int docsWithClassSize;
|
||||||
private Analyzer analyzer;
|
private Analyzer analyzer;
|
||||||
private IndexSearcher indexSearcher;
|
private IndexSearcher indexSearcher;
|
||||||
|
private Query query;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new NaiveBayes classifier.
|
* Creates a new NaiveBayes classifier.
|
||||||
|
@ -62,7 +64,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
|
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
this.atomicReader = atomicReader;
|
this.atomicReader = atomicReader;
|
||||||
this.indexSearcher = new IndexSearcher(this.atomicReader);
|
this.indexSearcher = new IndexSearcher(this.atomicReader);
|
||||||
|
@ -70,13 +72,29 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
this.classFieldName = classFieldName;
|
this.classFieldName = classFieldName;
|
||||||
this.analyzer = analyzer;
|
this.analyzer = analyzer;
|
||||||
this.docsWithClassSize = countDocsWithClass();
|
this.docsWithClassSize = countDocsWithClass();
|
||||||
|
this.query = query;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
|
||||||
|
train(atomicReader, textFieldName, classFieldName, analyzer, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
private int countDocsWithClass() throws IOException {
|
private int countDocsWithClass() throws IOException {
|
||||||
int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
|
int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
|
||||||
if (docCount == -1) { // in case codec doesn't support getDocCount
|
if (docCount == -1) { // in case codec doesn't support getDocCount
|
||||||
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
|
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
|
||||||
indexSearcher.search(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))),
|
Query q;
|
||||||
|
if (query != null) {
|
||||||
|
BooleanQuery bq = new BooleanQuery();
|
||||||
|
WildcardQuery wq = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING)));
|
||||||
|
bq.add(wq, BooleanClause.Occur.MUST);
|
||||||
|
bq.add(query, BooleanClause.Occur.MUST);
|
||||||
|
q = bq;
|
||||||
|
} else {
|
||||||
|
q = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING)));
|
||||||
|
}
|
||||||
|
indexSearcher.search(q,
|
||||||
totalHitCountCollector);
|
totalHitCountCollector);
|
||||||
docCount = totalHitCountCollector.getTotalHits();
|
docCount = totalHitCountCollector.getTotalHits();
|
||||||
}
|
}
|
||||||
|
@ -157,6 +175,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
BooleanQuery booleanQuery = new BooleanQuery();
|
BooleanQuery booleanQuery = new BooleanQuery();
|
||||||
booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST));
|
booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST));
|
||||||
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
|
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
|
||||||
|
if (query != null) {
|
||||||
|
booleanQuery.add(query, BooleanClause.Occur.MUST);
|
||||||
|
}
|
||||||
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
|
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
|
||||||
indexSearcher.search(booleanQuery, totalHitCountCollector);
|
indexSearcher.search(booleanQuery, totalHitCountCollector);
|
||||||
return totalHitCountCollector.getTotalHits();
|
return totalHitCountCollector.getTotalHits();
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
package org.apache.lucene.classification;
|
package org.apache.lucene.classification;
|
||||||
|
|
||||||
import org.apache.lucene.analysis.MockAnalyzer;
|
import org.apache.lucene.analysis.MockAnalyzer;
|
||||||
|
import org.apache.lucene.index.Term;
|
||||||
|
import org.apache.lucene.search.TermQuery;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -34,6 +36,11 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
|
||||||
checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
|
checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBasicUsageWithQuery() throws Exception {
|
||||||
|
checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName, new TermQuery(new Term(textFieldName, "it")));
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testPerformance() throws Exception {
|
public void testPerformance() throws Exception {
|
||||||
checkPerformance(new BooleanPerceptronClassifier(), new MockAnalyzer(random()), booleanFieldName);
|
checkPerformance(new BooleanPerceptronClassifier(), new MockAnalyzer(random()), booleanFieldName);
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.apache.lucene.document.TextField;
|
||||||
import org.apache.lucene.index.AtomicReader;
|
import org.apache.lucene.index.AtomicReader;
|
||||||
import org.apache.lucene.index.RandomIndexWriter;
|
import org.apache.lucene.index.RandomIndexWriter;
|
||||||
import org.apache.lucene.index.SlowCompositeReaderWrapper;
|
import org.apache.lucene.index.SlowCompositeReaderWrapper;
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.LuceneTestCase;
|
import org.apache.lucene.util.LuceneTestCase;
|
||||||
|
@ -70,13 +71,16 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
||||||
dir.close();
|
dir.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
|
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
|
||||||
|
checkCorrectClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
|
||||||
AtomicReader atomicReader = null;
|
AtomicReader atomicReader = null;
|
||||||
try {
|
try {
|
||||||
populateSampleIndex(analyzer);
|
populateSampleIndex(analyzer);
|
||||||
atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
|
atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
|
||||||
classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
|
classifier.train(atomicReader, textFieldName, classFieldName, analyzer, query);
|
||||||
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
|
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
|
||||||
assertNotNull(classificationResult.getAssignedClass());
|
assertNotNull(classificationResult.getAssignedClass());
|
||||||
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
|
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
package org.apache.lucene.classification;
|
package org.apache.lucene.classification;
|
||||||
|
|
||||||
import org.apache.lucene.analysis.MockAnalyzer;
|
import org.apache.lucene.analysis.MockAnalyzer;
|
||||||
|
import org.apache.lucene.index.Term;
|
||||||
|
import org.apache.lucene.search.TermQuery;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
|
@ -30,6 +32,11 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
|
||||||
checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
|
checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBasicUsageWithQuery() throws Exception {
|
||||||
|
checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testPerformance() throws Exception {
|
public void testPerformance() throws Exception {
|
||||||
checkPerformance(new KNearestNeighborClassifier(100), new MockAnalyzer(random()), categoryFieldName);
|
checkPerformance(new KNearestNeighborClassifier(100), new MockAnalyzer(random()), categoryFieldName);
|
||||||
|
|
|
@ -22,6 +22,8 @@ import org.apache.lucene.analysis.Tokenizer;
|
||||||
import org.apache.lucene.analysis.core.KeywordTokenizer;
|
import org.apache.lucene.analysis.core.KeywordTokenizer;
|
||||||
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
|
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
|
||||||
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
|
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
|
||||||
|
import org.apache.lucene.index.Term;
|
||||||
|
import org.apache.lucene.search.TermQuery;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.LuceneTestCase;
|
import org.apache.lucene.util.LuceneTestCase;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
@ -41,6 +43,11 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
|
||||||
checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
|
checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBasicUsageWithQuery() throws Exception {
|
||||||
|
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNGramUsage() throws Exception {
|
public void testNGramUsage() throws Exception {
|
||||||
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName);
|
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName);
|
||||||
|
|
Loading…
Reference in New Issue