LUCENE-6045 - refactor Classifier API to work better with multithreading

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1676997 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2015-04-30 14:12:03 +00:00
parent 7736e49b3e
commit 92842e7c34
13 changed files with 282 additions and 374 deletions

View File

@ -58,76 +58,14 @@ import org.apache.lucene.util.fst.Util;
*/
public class BooleanPerceptronClassifier implements Classifier<Boolean> {
private Double threshold;
private final Integer batchSize;
private Terms textTerms;
private Analyzer analyzer;
private String textFieldName;
private final Double threshold;
private final Terms textTerms;
private final Analyzer analyzer;
private final String textFieldName;
private FST<Long> fst;
/**
* Create a {@link BooleanPerceptronClassifier}
*
* @param threshold the binary threshold for perceptron output evaluation
*/
public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
this.threshold = threshold;
this.batchSize = batchSize;
}
/**
* Default constructor, no batch updates of FST, perceptron threshold is
* calculated via underlying index metrics during
* {@link #train(org.apache.lucene.index.LeafReader, String, String, org.apache.lucene.analysis.Analyzer)
* training}
*/
public BooleanPerceptronClassifier() {
batchSize = 1;
}
/**
* {@inheritDoc}
*/
@Override
public ClassificationResult<Boolean> assignClass(String text)
throws IOException {
if (textTerms == null) {
throw new IOException("You must first call Classifier#train");
}
Long output = 0l;
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
CharTermAttribute charTermAttribute = tokenStream
.addAttribute(CharTermAttribute.class);
tokenStream.reset();
while (tokenStream.incrementToken()) {
String s = charTermAttribute.toString();
Long d = Util.get(fst, new BytesRef(s));
if (d != null) {
output += d;
}
}
tokenStream.end();
}
double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
return new ClassificationResult<>(output >= threshold, score);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName,
String classFieldName, Analyzer analyzer) throws IOException {
train(leafReader, textFieldName, classFieldName, analyzer, null);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName,
String classFieldName, Analyzer analyzer, Query query) throws IOException {
public BooleanPerceptronClassifier(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer,
Query query, Integer batchSize, Double threshold) throws IOException {
this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
if (textTerms == null) {
@ -144,9 +82,11 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
this.threshold = (double) sumDocFreq / 2d;
} else {
throw new IOException(
"threshold cannot be assigned since term vectors for field "
+ textFieldName + " do not exist");
"threshold cannot be assigned since term vectors for field "
+ textFieldName + " do not exist");
}
} else {
this.threshold = threshold;
}
// TODO : remove this map as soon as we have a writable FST
@ -170,7 +110,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
}
// run the search and use stored field values
for (ScoreDoc scoreDoc : indexSearcher.search(q,
Integer.MAX_VALUE).scoreDocs) {
Integer.MAX_VALUE).scoreDocs) {
StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
StorableField textField = doc.getField(textFieldName);
@ -187,7 +127,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
long modifier = correctClass.compareTo(assignedClass);
if (modifier != 0) {
updateWeights(leafReader, scoreDoc.doc, assignedClass,
weights, modifier, batchCount % batchSize == 0);
weights, modifier, batchCount % batchSize == 0);
}
batchCount++;
}
@ -195,11 +135,6 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
weights.clear(); // free memory while waiting for GC
}
@Override
public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
throw new IOException("training with multiple fields not supported by boolean perceptron classifier");
}
private void updateWeights(LeafReader leafReader,
int docId, Boolean assignedClass, SortedMap<String, Double> weights,
double modifier, boolean updateFST) throws IOException {
@ -210,7 +145,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
if (terms == null) {
throw new IOException("term vectors must be stored for field "
+ textFieldName);
+ textFieldName);
}
TermsEnum termsEnum = terms.iterator();
@ -240,17 +175,46 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
for (Map.Entry<String, Double> entry : weights.entrySet()) {
scratchBytes.copyChars(entry.getKey());
fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry
.getValue().longValue());
.getValue().longValue());
}
fst = fstBuilder.finish();
}
/**
* {@inheritDoc}
*/
@Override
public ClassificationResult<Boolean> assignClass(String text)
throws IOException {
if (textTerms == null) {
throw new IOException("You must first call Classifier#train");
}
Long output = 0l;
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
CharTermAttribute charTermAttribute = tokenStream
.addAttribute(CharTermAttribute.class);
tokenStream.reset();
while (tokenStream.incrementToken()) {
String s = charTermAttribute.toString();
Long d = Util.get(fst, new BytesRef(s));
if (d != null) {
output += d;
}
}
tokenStream.end();
}
double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
return new ClassificationResult<>(output >= threshold, score);
}
/**
* {@inheritDoc}
*/
@Override
public List<ClassificationResult<Boolean>> getClasses(String text)
throws IOException {
throws IOException {
throw new RuntimeException("not implemented");
}
@ -259,7 +223,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
*/
@Override
public List<ClassificationResult<Boolean>> getClasses(String text, int max)
throws IOException {
throws IOException {
throw new RuntimeException("not implemented");
}

View File

@ -49,50 +49,30 @@ import org.apache.lucene.util.BytesRef;
*/
public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
//for caching classes this will be the classification class list
private ArrayList<BytesRef> cclasses = new ArrayList<>();
private final ArrayList<BytesRef> cclasses = new ArrayList<>();
// it's a term-inmap style map, where the inmap contains class-hit pairs to the
// upper term
private Map<String, Map<BytesRef, Integer>> termCClassHitCache = new HashMap<>();
private final Map<String, Map<BytesRef, Integer>> termCClassHitCache = new HashMap<>();
// the term frequency in classes
private Map<BytesRef, Double> classTermFreq = new HashMap<>();
private final Map<BytesRef, Double> classTermFreq = new HashMap<>();
private boolean justCachedTerms;
private int docsWithClassSize;
/**
* Creates a new NaiveBayes classifier with inside caching. Note that you must
* call {@link #train(org.apache.lucene.index.LeafReader, String, String, Analyzer) train()} before
* you can classify any documents. If you want less memory usage you could
* Creates a new NaiveBayes classifier with inside caching. If you want less memory usage you could
* call {@link #reInitCache(int, boolean) reInitCache()}.
*/
public CachingNaiveBayesClassifier() {
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
train(leafReader, textFieldName, classFieldName, analyzer, null);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
super.train(leafReader, textFieldNames, classFieldName, analyzer, query);
public CachingNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
super(leafReader, analyzer, query, classFieldName, textFieldNames);
// building the cache
reInitCache(0, true);
try {
reInitCache(0, true);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
if (leafReader == null) {
throw new IOException("You must first call Classifier#train");

View File

@ -18,17 +18,19 @@ package org.apache.lucene.classification;
/**
* The result of a call to {@link Classifier#assignClass(String)} holding an assigned class of type <code>T</code> and a score.
*
* @lucene.experimental
*/
public class ClassificationResult<T> implements Comparable<ClassificationResult<T>>{
public class ClassificationResult<T> implements Comparable<ClassificationResult<T>> {
private final T assignedClass;
private double score;
/**
* Constructor
*
* @param assignedClass the class <code>T</code> assigned by a {@link Classifier}
* @param score the score for the assignedClass as a <code>double</code>
* @param score the score for the assignedClass as a <code>double</code>
*/
public ClassificationResult(T assignedClass, double score) {
this.assignedClass = assignedClass;
@ -37,6 +39,7 @@ public class ClassificationResult<T> implements Comparable<ClassificationResult<
/**
* retrieve the result class
*
* @return a <code>T</code> representing an assigned class
*/
public T getAssignedClass() {
@ -45,14 +48,16 @@ public class ClassificationResult<T> implements Comparable<ClassificationResult<
/**
* retrieve the result score
*
* @return a <code>double</code> representing a result score
*/
public double getScore() {
return score;
}
/**
* set the score value
*
* @param score the score for the assignedClass as a <code>double</code>
*/
public void setScore(double score) {

View File

@ -22,7 +22,6 @@ import java.util.List;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
/**
* A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which assign classes of type
@ -39,7 +38,7 @@ public interface Classifier<T> {
* @return a {@link ClassificationResult} holding assigned class of type <code>T</code> and score
* @throws IOException If there is a low-level I/O error.
*/
public ClassificationResult<T> assignClass(String text) throws IOException;
ClassificationResult<T> assignClass(String text) throws IOException;
/**
* Get all the classes (sorted by score, descending) assigned to the given text String.
@ -48,7 +47,7 @@ public interface Classifier<T> {
* @return the whole list of {@link ClassificationResult}, the classes and scores. Returns <code>null</code> if the classifier can't make lists.
* @throws IOException If there is a low-level I/O error.
*/
public List<ClassificationResult<T>> getClasses(String text) throws IOException;
List<ClassificationResult<T>> getClasses(String text) throws IOException;
/**
* Get the first <code>max</code> classes (sorted by score, descending) assigned to the given text String.
@ -58,44 +57,6 @@ public interface Classifier<T> {
* @return the whole list of {@link ClassificationResult}, the classes and scores. Cut for "max" number of elements. Returns <code>null</code> if the classifier can't make lists.
* @throws IOException If there is a low-level I/O error.
*/
public List<ClassificationResult<T>> getClasses(String text, int max) throws IOException;
/**
* Train the classifier using the underlying Lucene index
*
* @param leafReader the reader to use to access the Lucene index
* @param textFieldName the name of the field used to compare documents
* @param classFieldName the name of the field containing the class assigned to documents
* @param analyzer the analyzer used to tokenize / filter the unseen text
* @throws IOException If there is a low-level I/O error.
*/
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer)
throws IOException;
/**
* Train the classifier using the underlying Lucene index
*
* @param leafReader the reader to use to access the Lucene index
* @param textFieldName the name of the field used to compare documents
* @param classFieldName the name of the field containing the class assigned to documents
* @param analyzer the analyzer used to tokenize / filter the unseen text
* @param query the query to filter which documents use for training
* @throws IOException If there is a low-level I/O error.
*/
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
throws IOException;
/**
* Train the classifier using the underlying Lucene index
*
* @param leafReader the reader to use to access the Lucene index
* @param textFieldNames the names of the fields to be used to compare documents
* @param classFieldName the name of the field containing the class assigned to documents
* @param analyzer the analyzer used to tokenize / filter the unseen text
* @param query the query to filter which documents use for training
* @throws IOException If there is a low-level I/O error.
*/
public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
throws IOException;
List<ClassificationResult<T>> getClasses(String text, int max) throws IOException;
}

View File

@ -26,6 +26,7 @@ import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.StorableField;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.mlt.MoreLikeThis;
import org.apache.lucene.search.BooleanClause;
@ -45,37 +46,31 @@ import org.apache.lucene.util.BytesRef;
*/
public class KNearestNeighborClassifier implements Classifier<BytesRef> {
private MoreLikeThis mlt;
private String[] textFieldNames;
private String classFieldName;
private IndexSearcher indexSearcher;
private final MoreLikeThis mlt;
private final String[] textFieldNames;
private final String classFieldName;
private final IndexSearcher indexSearcher;
private final int k;
private Query query;
private final Query query;
private int minDocsFreq;
private int minTermFreq;
/**
* Create a {@link Classifier} using kNN algorithm
*
* @param k the number of neighbors to analyze as an <code>int</code>
*/
public KNearestNeighborClassifier(int k) {
public KNearestNeighborClassifier(LeafReader leafReader, Analyzer analyzer, Query query, int k, int minDocsFreq,
int minTermFreq, String classFieldName, String... textFieldNames) {
this.textFieldNames = textFieldNames;
this.classFieldName = classFieldName;
this.mlt = new MoreLikeThis(leafReader);
this.mlt.setAnalyzer(analyzer);
this.mlt.setFieldNames(textFieldNames);
this.indexSearcher = new IndexSearcher(leafReader);
if (minDocsFreq > 0) {
mlt.setMinDocFreq(minDocsFreq);
}
if (minTermFreq > 0) {
mlt.setMinTermFreq(minTermFreq);
}
this.query = query;
this.k = k;
}
/**
* Create a {@link Classifier} using kNN algorithm
*
* @param k the number of neighbors to analyze as an <code>int</code>
* @param minDocsFreq the minimum number of docs frequency for MLT to be set with {@link MoreLikeThis#setMinDocFreq(int)}
* @param minTermFreq the minimum number of term frequency for MLT to be set with {@link MoreLikeThis#setMinTermFreq(int)}
*/
public KNearestNeighborClassifier(int k, int minDocsFreq, int minTermFreq) {
this.k = k;
this.minDocsFreq = minDocsFreq;
this.minTermFreq = minTermFreq;
}
/**
* {@inheritDoc}
@ -136,12 +131,15 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
Map<BytesRef, Integer> classCounts = new HashMap<>();
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
Integer count = classCounts.get(cl);
if (count != null) {
classCounts.put(cl, count + 1);
} else {
classCounts.put(cl, 1);
StorableField storableField = indexSearcher.doc(scoreDoc.doc).getField(classFieldName);
if (storableField != null) {
BytesRef cl = new BytesRef(storableField.stringValue());
Integer count = classCounts.get(cl);
if (count != null) {
classCounts.put(cl, count + 1);
} else {
classCounts.put(cl, 1);
}
}
}
List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
@ -161,39 +159,4 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
return returnList;
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
train(leafReader, textFieldName, classFieldName, analyzer, null);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
this.textFieldNames = textFieldNames;
this.classFieldName = classFieldName;
mlt = new MoreLikeThis(leafReader);
mlt.setAnalyzer(analyzer);
mlt.setFieldNames(textFieldNames);
indexSearcher = new IndexSearcher(leafReader);
if (minDocsFreq > 0) {
mlt.setMinDocFreq(minDocsFreq);
}
if (minTermFreq > 0) {
mlt.setMinTermFreq(minTermFreq);
}
this.query = query;
}
}

View File

@ -51,64 +51,38 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
* {@link org.apache.lucene.index.LeafReader} used to access the {@link org.apache.lucene.classification.Classifier}'s
* index
*/
protected LeafReader leafReader;
protected final LeafReader leafReader;
/**
* names of the fields to be used as input text
*/
protected String[] textFieldNames;
protected final String[] textFieldNames;
/**
* name of the field to be used as a class / category output
*/
protected String classFieldName;
protected final String classFieldName;
/**
* {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing unseen input text
*/
protected Analyzer analyzer;
protected final Analyzer analyzer;
/**
* {@link org.apache.lucene.search.IndexSearcher} to run searches on the index for retrieving frequencies
*/
protected IndexSearcher indexSearcher;
protected final IndexSearcher indexSearcher;
/**
* {@link org.apache.lucene.search.Query} used to eventually filter the document set to be used to classify
*/
protected Query query;
protected final Query query;
/**
* Creates a new NaiveBayes classifier.
* Note that you must call {@link #train(org.apache.lucene.index.LeafReader, String, String, Analyzer) train()} before you can
* classify any documents.
*/
public SimpleNaiveBayesClassifier() {
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
train(leafReader, textFieldName, classFieldName, analyzer, null);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
throws IOException {
train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
throws IOException {
public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
this.leafReader = leafReader;
this.indexSearcher = new IndexSearcher(this.leafReader);
this.textFieldNames = textFieldNames;

View File

@ -18,7 +18,7 @@
/**
* Uses already seen data (the indexed documents) to classify new documents.
* <p>
* Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
* Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
* Neighbor classifier and a Perceptron based classifier.
*/
package org.apache.lucene.classification;

View File

@ -33,7 +33,8 @@ public class DocToDoubleVectorUtils {
/**
* create a sparse <code>Double</code> vector given doc and field term vectors using local frequency of the terms in the doc
* @param docTerms term vectors for a given document
*
* @param docTerms term vectors for a given document
* @param fieldTerms field term vectors
* @return a sparse vector of <code>Double</code>s as an array
* @throws IOException in case accessing the underlying index fails
@ -54,8 +55,7 @@ public class DocToDoubleVectorUtils {
if (seekStatus.equals(TermsEnum.SeekStatus.FOUND)) {
long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
}
else {
} else {
freqVector[i] = 0d;
}
i++;
@ -66,6 +66,7 @@ public class DocToDoubleVectorUtils {
/**
* create a dense <code>Double</code> vector given doc and field term vectors using local frequency of the terms in the doc
*
* @param docTerms term vectors for a given document
* @return a dense vector of <code>Double</code>s as an array
* @throws IOException in case accessing the underlying index fails
@ -73,16 +74,16 @@ public class DocToDoubleVectorUtils {
public static Double[] toDenseLocalFreqDoubleArray(Terms docTerms) throws IOException {
Double[] freqVector = null;
if (docTerms != null) {
freqVector = new Double[(int) docTerms.size()];
int i = 0;
TermsEnum docTermsEnum = docTerms.iterator();
freqVector = new Double[(int) docTerms.size()];
int i = 0;
TermsEnum docTermsEnum = docTerms.iterator();
while (docTermsEnum.next() != null) {
long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
i++;
}
while (docTermsEnum.next() != null) {
long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
i++;
}
}
return freqVector;
}
}
}

View File

@ -17,6 +17,8 @@
package org.apache.lucene.classification;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
import org.junit.Test;
@ -28,22 +30,45 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
@Test
public void testBasicUsage() throws Exception {
checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, null, 1, null), TECHNOLOGY_INPUT, false);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
@Test
public void testExplicitThreshold() throws Exception {
checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, null, 1, 100d), TECHNOLOGY_INPUT, false);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
@Test
public void testBasicUsageWithQuery() throws Exception {
checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName, new TermQuery(new Term(textFieldName, "it")));
}
@Test
public void testPerformance() throws Exception {
checkPerformance(new BooleanPerceptronClassifier(), new MockAnalyzer(random()), booleanFieldName);
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, query, 1, null), TECHNOLOGY_INPUT, false);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
}

View File

@ -23,6 +23,8 @@ import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.core.KeywordTokenizer;
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
@ -35,18 +37,46 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte
@Test
public void testBasicUsage() throws Exception {
checkCorrectClassification(new CachingNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
checkCorrectClassification(new CachingNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
@Test
public void testBasicUsageWithQuery() throws Exception {
checkCorrectClassification(new CachingNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
@Test
public void testNGramUsage() throws Exception {
checkCorrectClassification(new CachingNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName);
LeafReader leafReader = null;
try {
NGramAnalyzer analyzer = new NGramAnalyzer();
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
private class NGramAnalyzer extends Analyzer {
@ -57,9 +87,4 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte
}
}
@Test
public void testPerformance() throws Exception {
checkPerformance(new CachingNaiveBayesClassifier(), new MockAnalyzer(random()), categoryFieldName);
}
}

View File

@ -41,14 +41,14 @@ import org.junit.Before;
*/
public abstract class ClassificationTestBase<T> extends LuceneTestCase {
public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. " +
"If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
"If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
public static final BytesRef POLITICS_RESULT = new BytesRef("politics");
public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
" Truth is, Amazon may know more.";
" Truth is, Amazon may know more.";
public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
private RandomIndexWriter indexWriter;
protected RandomIndexWriter indexWriter;
private Directory dir;
private FieldType ft;
@ -79,53 +79,34 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
dir.close();
}
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
checkCorrectClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult) throws Exception {
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
double score = classificationResult.getScore();
assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
}
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
LeafReader leafReader = null;
try {
populateSampleIndex(analyzer);
leafReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
classifier.train(leafReader, textFieldName, classFieldName, analyzer, query);
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
double score = classificationResult.getScore();
assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
} finally {
if (leafReader != null)
leafReader.close();
}
}
protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
checkOnlineClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
}
protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
LeafReader leafReader = null;
try {
populateSampleIndex(analyzer);
leafReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
classifier.train(leafReader, textFieldName, classFieldName, analyzer, query);
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
double score = classificationResult.getScore();
assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0);
updateSampleIndex();
ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore()));
populateSampleIndex(analyzer);
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
double score = classificationResult.getScore();
assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0);
updateSampleIndex();
ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore()));
} finally {
if (leafReader != null)
leafReader.close();
}
}
private void populateSampleIndex(Analyzer analyzer) throws IOException {
protected LeafReader populateSampleIndex(Analyzer analyzer) throws IOException {
indexWriter.close();
indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
indexWriter.commit();
@ -134,8 +115,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
Document doc = new Document();
text = "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
"who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
"the Unknown Soldier in Warsaw Tuesday.";
"who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
"the Unknown Soldier in Warsaw Tuesday.";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
doc.add(new Field(booleanFieldName, "true", ft));
@ -144,7 +125,7 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
doc = new Document();
text = "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
" States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.";
" States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
doc.add(new Field(booleanFieldName, "true", ft));
@ -152,8 +133,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
doc = new Document();
text = "And there's a threshold question that he has to answer for the American people and " +
"that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
"know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"";
"that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
"know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
doc.add(new Field(booleanFieldName, "true", ft));
@ -161,8 +142,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
doc = new Document();
text = "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
"keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
"Albany's School of Criminal Justice.";
"keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
"Albany's School of Criminal Justice.";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
doc.add(new Field(booleanFieldName, "true", ft));
@ -170,8 +151,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
doc = new Document();
text = "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
"technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
"world through the Internet.";
"technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
"world through the Internet.";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "technology", ft));
doc.add(new Field(booleanFieldName, "false", ft));
@ -179,7 +160,7 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
doc = new Document();
text = "So, about all those experts and analysts who've spent the past year or so saying " +
"Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.";
"Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "technology", ft));
doc.add(new Field(booleanFieldName, "false", ft));
@ -187,8 +168,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
doc = new Document();
text = "More than 400 million people trust Google with their e-mail, and 50 million store files" +
" in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
"generally transfer or store huge volumes of personal data online.";
" in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
"generally transfer or store huge volumes of personal data online.";
doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "technology", ft));
doc.add(new Field(booleanFieldName, "false", ft));
@ -200,22 +181,15 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
indexWriter.addDocument(doc);
indexWriter.commit();
return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
}
protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
LeafReader leafReader = null;
long trainStart = System.currentTimeMillis();
try {
populatePerformanceIndex(analyzer);
leafReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
classifier.train(leafReader, textFieldName, classFieldName, analyzer);
long trainEnd = System.currentTimeMillis();
long trainTime = trainEnd - trainStart;
assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
} finally {
if (leafReader != null)
leafReader.close();
}
populatePerformanceIndex(analyzer);
long trainEnd = System.currentTimeMillis();
long trainTime = trainEnd - trainStart;
assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
}
private void populatePerformanceIndex(Analyzer analyzer) throws IOException {

View File

@ -17,6 +17,8 @@
package org.apache.lucene.classification;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
@ -29,20 +31,32 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
@Test
public void testBasicUsage() throws Exception {
// usage with default MLT min docs / term freq
checkCorrectClassification(new KNearestNeighborClassifier(3), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
// usage without custom min docs / term freq for MLT
checkCorrectClassification(new KNearestNeighborClassifier(3, 2, 1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, null, 3, 2, 1, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
@Test
public void testBasicUsageWithQuery() throws Exception {
checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
}
@Test
public void testPerformance() throws Exception {
checkPerformance(new KNearestNeighborClassifier(100), new MockAnalyzer(random()), categoryFieldName);
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, query, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
}

View File

@ -22,14 +22,13 @@ import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.core.KeywordTokenizer;
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.junit.Test;
import java.io.Reader;
/**
* Testcase for {@link SimpleNaiveBayesClassifier}
*/
@ -37,18 +36,46 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
@Test
public void testBasicUsage() throws Exception {
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
@Test
public void testBasicUsageWithQuery() throws Exception {
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
@Test
public void testNGramUsage() throws Exception {
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName);
LeafReader leafReader = null;
try {
Analyzer analyzer = new NGramAnalyzer();
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
}
private class NGramAnalyzer extends Analyzer {
@ -59,9 +86,4 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
}
}
@Test
public void testPerformance() throws Exception {
checkPerformance(new SimpleNaiveBayesClassifier(), new MockAnalyzer(random()), categoryFieldName);
}
}