LUCENE-6045 - refactor Classifier API to work better with multithreading

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1676997 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2015-04-30 14:12:03 +00:00
parent 7736e49b3e
commit 92842e7c34
13 changed files with 282 additions and 374 deletions

View File

@ -58,76 +58,14 @@ import org.apache.lucene.util.fst.Util;
*/ */
public class BooleanPerceptronClassifier implements Classifier<Boolean> { public class BooleanPerceptronClassifier implements Classifier<Boolean> {
private Double threshold; private final Double threshold;
private final Integer batchSize; private final Terms textTerms;
private Terms textTerms; private final Analyzer analyzer;
private Analyzer analyzer; private final String textFieldName;
private String textFieldName;
private FST<Long> fst; private FST<Long> fst;
/** public BooleanPerceptronClassifier(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer,
* Create a {@link BooleanPerceptronClassifier} Query query, Integer batchSize, Double threshold) throws IOException {
*
* @param threshold the binary threshold for perceptron output evaluation
*/
public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
this.threshold = threshold;
this.batchSize = batchSize;
}
/**
* Default constructor, no batch updates of FST, perceptron threshold is
* calculated via underlying index metrics during
* {@link #train(org.apache.lucene.index.LeafReader, String, String, org.apache.lucene.analysis.Analyzer)
* training}
*/
public BooleanPerceptronClassifier() {
batchSize = 1;
}
/**
* {@inheritDoc}
*/
@Override
public ClassificationResult<Boolean> assignClass(String text)
throws IOException {
if (textTerms == null) {
throw new IOException("You must first call Classifier#train");
}
Long output = 0l;
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
CharTermAttribute charTermAttribute = tokenStream
.addAttribute(CharTermAttribute.class);
tokenStream.reset();
while (tokenStream.incrementToken()) {
String s = charTermAttribute.toString();
Long d = Util.get(fst, new BytesRef(s));
if (d != null) {
output += d;
}
}
tokenStream.end();
}
double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
return new ClassificationResult<>(output >= threshold, score);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName,
String classFieldName, Analyzer analyzer) throws IOException {
train(leafReader, textFieldName, classFieldName, analyzer, null);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName,
String classFieldName, Analyzer analyzer, Query query) throws IOException {
this.textTerms = MultiFields.getTerms(leafReader, textFieldName); this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
if (textTerms == null) { if (textTerms == null) {
@ -144,9 +82,11 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
this.threshold = (double) sumDocFreq / 2d; this.threshold = (double) sumDocFreq / 2d;
} else { } else {
throw new IOException( throw new IOException(
"threshold cannot be assigned since term vectors for field " "threshold cannot be assigned since term vectors for field "
+ textFieldName + " do not exist"); + textFieldName + " do not exist");
} }
} else {
this.threshold = threshold;
} }
// TODO : remove this map as soon as we have a writable FST // TODO : remove this map as soon as we have a writable FST
@ -170,7 +110,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
} }
// run the search and use stored field values // run the search and use stored field values
for (ScoreDoc scoreDoc : indexSearcher.search(q, for (ScoreDoc scoreDoc : indexSearcher.search(q,
Integer.MAX_VALUE).scoreDocs) { Integer.MAX_VALUE).scoreDocs) {
StoredDocument doc = indexSearcher.doc(scoreDoc.doc); StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
StorableField textField = doc.getField(textFieldName); StorableField textField = doc.getField(textFieldName);
@ -187,7 +127,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
long modifier = correctClass.compareTo(assignedClass); long modifier = correctClass.compareTo(assignedClass);
if (modifier != 0) { if (modifier != 0) {
updateWeights(leafReader, scoreDoc.doc, assignedClass, updateWeights(leafReader, scoreDoc.doc, assignedClass,
weights, modifier, batchCount % batchSize == 0); weights, modifier, batchCount % batchSize == 0);
} }
batchCount++; batchCount++;
} }
@ -195,11 +135,6 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
weights.clear(); // free memory while waiting for GC weights.clear(); // free memory while waiting for GC
} }
@Override
public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
throw new IOException("training with multiple fields not supported by boolean perceptron classifier");
}
private void updateWeights(LeafReader leafReader, private void updateWeights(LeafReader leafReader,
int docId, Boolean assignedClass, SortedMap<String, Double> weights, int docId, Boolean assignedClass, SortedMap<String, Double> weights,
double modifier, boolean updateFST) throws IOException { double modifier, boolean updateFST) throws IOException {
@ -210,7 +145,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
if (terms == null) { if (terms == null) {
throw new IOException("term vectors must be stored for field " throw new IOException("term vectors must be stored for field "
+ textFieldName); + textFieldName);
} }
TermsEnum termsEnum = terms.iterator(); TermsEnum termsEnum = terms.iterator();
@ -240,17 +175,46 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
for (Map.Entry<String, Double> entry : weights.entrySet()) { for (Map.Entry<String, Double> entry : weights.entrySet()) {
scratchBytes.copyChars(entry.getKey()); scratchBytes.copyChars(entry.getKey());
fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry
.getValue().longValue()); .getValue().longValue());
} }
fst = fstBuilder.finish(); fst = fstBuilder.finish();
} }
/**
* {@inheritDoc}
*/
@Override
public ClassificationResult<Boolean> assignClass(String text)
throws IOException {
if (textTerms == null) {
throw new IOException("You must first call Classifier#train");
}
Long output = 0l;
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
CharTermAttribute charTermAttribute = tokenStream
.addAttribute(CharTermAttribute.class);
tokenStream.reset();
while (tokenStream.incrementToken()) {
String s = charTermAttribute.toString();
Long d = Util.get(fst, new BytesRef(s));
if (d != null) {
output += d;
}
}
tokenStream.end();
}
double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
return new ClassificationResult<>(output >= threshold, score);
}
/** /**
* {@inheritDoc} * {@inheritDoc}
*/ */
@Override @Override
public List<ClassificationResult<Boolean>> getClasses(String text) public List<ClassificationResult<Boolean>> getClasses(String text)
throws IOException { throws IOException {
throw new RuntimeException("not implemented"); throw new RuntimeException("not implemented");
} }
@ -259,7 +223,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
*/ */
@Override @Override
public List<ClassificationResult<Boolean>> getClasses(String text, int max) public List<ClassificationResult<Boolean>> getClasses(String text, int max)
throws IOException { throws IOException {
throw new RuntimeException("not implemented"); throw new RuntimeException("not implemented");
} }

View File

@ -49,50 +49,30 @@ import org.apache.lucene.util.BytesRef;
*/ */
public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier { public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
//for caching classes this will be the classification class list //for caching classes this will be the classification class list
private ArrayList<BytesRef> cclasses = new ArrayList<>(); private final ArrayList<BytesRef> cclasses = new ArrayList<>();
// it's a term-inmap style map, where the inmap contains class-hit pairs to the // it's a term-inmap style map, where the inmap contains class-hit pairs to the
// upper term // upper term
private Map<String, Map<BytesRef, Integer>> termCClassHitCache = new HashMap<>(); private final Map<String, Map<BytesRef, Integer>> termCClassHitCache = new HashMap<>();
// the term frequency in classes // the term frequency in classes
private Map<BytesRef, Double> classTermFreq = new HashMap<>(); private final Map<BytesRef, Double> classTermFreq = new HashMap<>();
private boolean justCachedTerms; private boolean justCachedTerms;
private int docsWithClassSize; private int docsWithClassSize;
/** /**
* Creates a new NaiveBayes classifier with inside caching. Note that you must * Creates a new NaiveBayes classifier with inside caching. If you want less memory usage you could
* call {@link #train(org.apache.lucene.index.LeafReader, String, String, Analyzer) train()} before
* you can classify any documents. If you want less memory usage you could
* call {@link #reInitCache(int, boolean) reInitCache()}. * call {@link #reInitCache(int, boolean) reInitCache()}.
*/ */
public CachingNaiveBayesClassifier() { public CachingNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
} super(leafReader, analyzer, query, classFieldName, textFieldNames);
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
train(leafReader, textFieldName, classFieldName, analyzer, null);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
super.train(leafReader, textFieldNames, classFieldName, analyzer, query);
// building the cache // building the cache
reInitCache(0, true); try {
reInitCache(0, true);
} catch (IOException e) {
throw new RuntimeException(e);
}
} }
private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException { private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
if (leafReader == null) { if (leafReader == null) {
throw new IOException("You must first call Classifier#train"); throw new IOException("You must first call Classifier#train");

View File

@ -18,17 +18,19 @@ package org.apache.lucene.classification;
/** /**
* The result of a call to {@link Classifier#assignClass(String)} holding an assigned class of type <code>T</code> and a score. * The result of a call to {@link Classifier#assignClass(String)} holding an assigned class of type <code>T</code> and a score.
*
* @lucene.experimental * @lucene.experimental
*/ */
public class ClassificationResult<T> implements Comparable<ClassificationResult<T>>{ public class ClassificationResult<T> implements Comparable<ClassificationResult<T>> {
private final T assignedClass; private final T assignedClass;
private double score; private double score;
/** /**
* Constructor * Constructor
*
* @param assignedClass the class <code>T</code> assigned by a {@link Classifier} * @param assignedClass the class <code>T</code> assigned by a {@link Classifier}
* @param score the score for the assignedClass as a <code>double</code> * @param score the score for the assignedClass as a <code>double</code>
*/ */
public ClassificationResult(T assignedClass, double score) { public ClassificationResult(T assignedClass, double score) {
this.assignedClass = assignedClass; this.assignedClass = assignedClass;
@ -37,6 +39,7 @@ public class ClassificationResult<T> implements Comparable<ClassificationResult<
/** /**
* retrieve the result class * retrieve the result class
*
* @return a <code>T</code> representing an assigned class * @return a <code>T</code> representing an assigned class
*/ */
public T getAssignedClass() { public T getAssignedClass() {
@ -45,14 +48,16 @@ public class ClassificationResult<T> implements Comparable<ClassificationResult<
/** /**
* retrieve the result score * retrieve the result score
*
* @return a <code>double</code> representing a result score * @return a <code>double</code> representing a result score
*/ */
public double getScore() { public double getScore() {
return score; return score;
} }
/** /**
* set the score value * set the score value
*
* @param score the score for the assignedClass as a <code>double</code> * @param score the score for the assignedClass as a <code>double</code>
*/ */
public void setScore(double score) { public void setScore(double score) {

View File

@ -22,7 +22,6 @@ import java.util.List;
import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
/** /**
* A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which assign classes of type * A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which assign classes of type
@ -39,7 +38,7 @@ public interface Classifier<T> {
* @return a {@link ClassificationResult} holding assigned class of type <code>T</code> and score * @return a {@link ClassificationResult} holding assigned class of type <code>T</code> and score
* @throws IOException If there is a low-level I/O error. * @throws IOException If there is a low-level I/O error.
*/ */
public ClassificationResult<T> assignClass(String text) throws IOException; ClassificationResult<T> assignClass(String text) throws IOException;
/** /**
* Get all the classes (sorted by score, descending) assigned to the given text String. * Get all the classes (sorted by score, descending) assigned to the given text String.
@ -48,7 +47,7 @@ public interface Classifier<T> {
* @return the whole list of {@link ClassificationResult}, the classes and scores. Returns <code>null</code> if the classifier can't make lists. * @return the whole list of {@link ClassificationResult}, the classes and scores. Returns <code>null</code> if the classifier can't make lists.
* @throws IOException If there is a low-level I/O error. * @throws IOException If there is a low-level I/O error.
*/ */
public List<ClassificationResult<T>> getClasses(String text) throws IOException; List<ClassificationResult<T>> getClasses(String text) throws IOException;
/** /**
* Get the first <code>max</code> classes (sorted by score, descending) assigned to the given text String. * Get the first <code>max</code> classes (sorted by score, descending) assigned to the given text String.
@ -58,44 +57,6 @@ public interface Classifier<T> {
* @return the whole list of {@link ClassificationResult}, the classes and scores. Cut for "max" number of elements. Returns <code>null</code> if the classifier can't make lists. * @return the whole list of {@link ClassificationResult}, the classes and scores. Cut for "max" number of elements. Returns <code>null</code> if the classifier can't make lists.
* @throws IOException If there is a low-level I/O error. * @throws IOException If there is a low-level I/O error.
*/ */
public List<ClassificationResult<T>> getClasses(String text, int max) throws IOException; List<ClassificationResult<T>> getClasses(String text, int max) throws IOException;
/**
* Train the classifier using the underlying Lucene index
*
* @param leafReader the reader to use to access the Lucene index
* @param textFieldName the name of the field used to compare documents
* @param classFieldName the name of the field containing the class assigned to documents
* @param analyzer the analyzer used to tokenize / filter the unseen text
* @throws IOException If there is a low-level I/O error.
*/
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer)
throws IOException;
/**
* Train the classifier using the underlying Lucene index
*
* @param leafReader the reader to use to access the Lucene index
* @param textFieldName the name of the field used to compare documents
* @param classFieldName the name of the field containing the class assigned to documents
* @param analyzer the analyzer used to tokenize / filter the unseen text
* @param query the query to filter which documents use for training
* @throws IOException If there is a low-level I/O error.
*/
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
throws IOException;
/**
* Train the classifier using the underlying Lucene index
*
* @param leafReader the reader to use to access the Lucene index
* @param textFieldNames the names of the fields to be used to compare documents
* @param classFieldName the name of the field containing the class assigned to documents
* @param analyzer the analyzer used to tokenize / filter the unseen text
* @param query the query to filter which documents use for training
* @throws IOException If there is a low-level I/O error.
*/
public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
throws IOException;
} }

View File

@ -26,6 +26,7 @@ import java.util.Map;
import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.StorableField;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.queries.mlt.MoreLikeThis; import org.apache.lucene.queries.mlt.MoreLikeThis;
import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanClause;
@ -45,37 +46,31 @@ import org.apache.lucene.util.BytesRef;
*/ */
public class KNearestNeighborClassifier implements Classifier<BytesRef> { public class KNearestNeighborClassifier implements Classifier<BytesRef> {
private MoreLikeThis mlt; private final MoreLikeThis mlt;
private String[] textFieldNames; private final String[] textFieldNames;
private String classFieldName; private final String classFieldName;
private IndexSearcher indexSearcher; private final IndexSearcher indexSearcher;
private final int k; private final int k;
private Query query; private final Query query;
private int minDocsFreq; public KNearestNeighborClassifier(LeafReader leafReader, Analyzer analyzer, Query query, int k, int minDocsFreq,
private int minTermFreq; int minTermFreq, String classFieldName, String... textFieldNames) {
this.textFieldNames = textFieldNames;
/** this.classFieldName = classFieldName;
* Create a {@link Classifier} using kNN algorithm this.mlt = new MoreLikeThis(leafReader);
* this.mlt.setAnalyzer(analyzer);
* @param k the number of neighbors to analyze as an <code>int</code> this.mlt.setFieldNames(textFieldNames);
*/ this.indexSearcher = new IndexSearcher(leafReader);
public KNearestNeighborClassifier(int k) { if (minDocsFreq > 0) {
mlt.setMinDocFreq(minDocsFreq);
}
if (minTermFreq > 0) {
mlt.setMinTermFreq(minTermFreq);
}
this.query = query;
this.k = k; this.k = k;
} }
/**
* Create a {@link Classifier} using kNN algorithm
*
* @param k the number of neighbors to analyze as an <code>int</code>
* @param minDocsFreq the minimum number of docs frequency for MLT to be set with {@link MoreLikeThis#setMinDocFreq(int)}
* @param minTermFreq the minimum number of term frequency for MLT to be set with {@link MoreLikeThis#setMinTermFreq(int)}
*/
public KNearestNeighborClassifier(int k, int minDocsFreq, int minTermFreq) {
this.k = k;
this.minDocsFreq = minDocsFreq;
this.minTermFreq = minTermFreq;
}
/** /**
* {@inheritDoc} * {@inheritDoc}
@ -136,12 +131,15 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException { private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
Map<BytesRef, Integer> classCounts = new HashMap<>(); Map<BytesRef, Integer> classCounts = new HashMap<>();
for (ScoreDoc scoreDoc : topDocs.scoreDocs) { for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue()); StorableField storableField = indexSearcher.doc(scoreDoc.doc).getField(classFieldName);
Integer count = classCounts.get(cl); if (storableField != null) {
if (count != null) { BytesRef cl = new BytesRef(storableField.stringValue());
classCounts.put(cl, count + 1); Integer count = classCounts.get(cl);
} else { if (count != null) {
classCounts.put(cl, 1); classCounts.put(cl, count + 1);
} else {
classCounts.put(cl, 1);
}
} }
} }
List<ClassificationResult<BytesRef>> returnList = new ArrayList<>(); List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
@ -161,39 +159,4 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
return returnList; return returnList;
} }
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
train(leafReader, textFieldName, classFieldName, analyzer, null);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
this.textFieldNames = textFieldNames;
this.classFieldName = classFieldName;
mlt = new MoreLikeThis(leafReader);
mlt.setAnalyzer(analyzer);
mlt.setFieldNames(textFieldNames);
indexSearcher = new IndexSearcher(leafReader);
if (minDocsFreq > 0) {
mlt.setMinDocFreq(minDocsFreq);
}
if (minTermFreq > 0) {
mlt.setMinTermFreq(minTermFreq);
}
this.query = query;
}
} }

View File

@ -51,64 +51,38 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
* {@link org.apache.lucene.index.LeafReader} used to access the {@link org.apache.lucene.classification.Classifier}'s * {@link org.apache.lucene.index.LeafReader} used to access the {@link org.apache.lucene.classification.Classifier}'s
* index * index
*/ */
protected LeafReader leafReader; protected final LeafReader leafReader;
/** /**
* names of the fields to be used as input text * names of the fields to be used as input text
*/ */
protected String[] textFieldNames; protected final String[] textFieldNames;
/** /**
* name of the field to be used as a class / category output * name of the field to be used as a class / category output
*/ */
protected String classFieldName; protected final String classFieldName;
/** /**
* {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing unseen input text * {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing unseen input text
*/ */
protected Analyzer analyzer; protected final Analyzer analyzer;
/** /**
* {@link org.apache.lucene.search.IndexSearcher} to run searches on the index for retrieving frequencies * {@link org.apache.lucene.search.IndexSearcher} to run searches on the index for retrieving frequencies
*/ */
protected IndexSearcher indexSearcher; protected final IndexSearcher indexSearcher;
/** /**
* {@link org.apache.lucene.search.Query} used to eventually filter the document set to be used to classify * {@link org.apache.lucene.search.Query} used to eventually filter the document set to be used to classify
*/ */
protected Query query; protected final Query query;
/** /**
* Creates a new NaiveBayes classifier. * Creates a new NaiveBayes classifier.
* Note that you must call {@link #train(org.apache.lucene.index.LeafReader, String, String, Analyzer) train()} before you can
* classify any documents. * classify any documents.
*/ */
public SimpleNaiveBayesClassifier() { public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
train(leafReader, textFieldName, classFieldName, analyzer, null);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
throws IOException {
train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
}
/**
* {@inheritDoc}
*/
@Override
public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
throws IOException {
this.leafReader = leafReader; this.leafReader = leafReader;
this.indexSearcher = new IndexSearcher(this.leafReader); this.indexSearcher = new IndexSearcher(this.leafReader);
this.textFieldNames = textFieldNames; this.textFieldNames = textFieldNames;

View File

@ -18,7 +18,7 @@
/** /**
* Uses already seen data (the indexed documents) to classify new documents. * Uses already seen data (the indexed documents) to classify new documents.
* <p> * <p>
* Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest * Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
* Neighbor classifier and a Perceptron based classifier. * Neighbor classifier and a Perceptron based classifier.
*/ */
package org.apache.lucene.classification; package org.apache.lucene.classification;

View File

@ -33,7 +33,8 @@ public class DocToDoubleVectorUtils {
/** /**
* create a sparse <code>Double</code> vector given doc and field term vectors using local frequency of the terms in the doc * create a sparse <code>Double</code> vector given doc and field term vectors using local frequency of the terms in the doc
* @param docTerms term vectors for a given document *
* @param docTerms term vectors for a given document
* @param fieldTerms field term vectors * @param fieldTerms field term vectors
* @return a sparse vector of <code>Double</code>s as an array * @return a sparse vector of <code>Double</code>s as an array
* @throws IOException in case accessing the underlying index fails * @throws IOException in case accessing the underlying index fails
@ -54,8 +55,7 @@ public class DocToDoubleVectorUtils {
if (seekStatus.equals(TermsEnum.SeekStatus.FOUND)) { if (seekStatus.equals(TermsEnum.SeekStatus.FOUND)) {
long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
freqVector[i] = Long.valueOf(termFreqLocal).doubleValue(); freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
} } else {
else {
freqVector[i] = 0d; freqVector[i] = 0d;
} }
i++; i++;
@ -66,6 +66,7 @@ public class DocToDoubleVectorUtils {
/** /**
* create a dense <code>Double</code> vector given doc and field term vectors using local frequency of the terms in the doc * create a dense <code>Double</code> vector given doc and field term vectors using local frequency of the terms in the doc
*
* @param docTerms term vectors for a given document * @param docTerms term vectors for a given document
* @return a dense vector of <code>Double</code>s as an array * @return a dense vector of <code>Double</code>s as an array
* @throws IOException in case accessing the underlying index fails * @throws IOException in case accessing the underlying index fails
@ -73,16 +74,16 @@ public class DocToDoubleVectorUtils {
public static Double[] toDenseLocalFreqDoubleArray(Terms docTerms) throws IOException { public static Double[] toDenseLocalFreqDoubleArray(Terms docTerms) throws IOException {
Double[] freqVector = null; Double[] freqVector = null;
if (docTerms != null) { if (docTerms != null) {
freqVector = new Double[(int) docTerms.size()]; freqVector = new Double[(int) docTerms.size()];
int i = 0; int i = 0;
TermsEnum docTermsEnum = docTerms.iterator(); TermsEnum docTermsEnum = docTerms.iterator();
while (docTermsEnum.next() != null) { while (docTermsEnum.next() != null) {
long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
freqVector[i] = Long.valueOf(termFreqLocal).doubleValue(); freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
i++; i++;
} }
} }
return freqVector; return freqVector;
} }
} }

View File

@ -17,6 +17,8 @@
package org.apache.lucene.classification; package org.apache.lucene.classification;
import org.apache.lucene.analysis.MockAnalyzer; import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TermQuery;
import org.junit.Test; import org.junit.Test;
@ -28,22 +30,45 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
@Test @Test
public void testBasicUsage() throws Exception { public void testBasicUsage() throws Exception {
checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName); LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, null, 1, null), TECHNOLOGY_INPUT, false);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
} }
@Test @Test
public void testExplicitThreshold() throws Exception { public void testExplicitThreshold() throws Exception {
checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName); LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, null, 1, 100d), TECHNOLOGY_INPUT, false);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
} }
@Test @Test
public void testBasicUsageWithQuery() throws Exception { public void testBasicUsageWithQuery() throws Exception {
checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName, new TermQuery(new Term(textFieldName, "it"))); TermQuery query = new TermQuery(new Term(textFieldName, "it"));
} LeafReader leafReader = null;
try {
@Test MockAnalyzer analyzer = new MockAnalyzer(random());
public void testPerformance() throws Exception { leafReader = populateSampleIndex(analyzer);
checkPerformance(new BooleanPerceptronClassifier(), new MockAnalyzer(random()), booleanFieldName); checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, query, 1, null), TECHNOLOGY_INPUT, false);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
} }
} }

View File

@ -23,6 +23,8 @@ import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.core.KeywordTokenizer; import org.apache.lucene.analysis.core.KeywordTokenizer;
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter; import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
import org.apache.lucene.analysis.reverse.ReverseStringFilter; import org.apache.lucene.analysis.reverse.ReverseStringFilter;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
@ -35,18 +37,46 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte
@Test @Test
public void testBasicUsage() throws Exception { public void testBasicUsage() throws Exception {
checkCorrectClassification(new CachingNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName); LeafReader leafReader = null;
checkCorrectClassification(new CachingNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName); try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
} }
@Test @Test
public void testBasicUsageWithQuery() throws Exception { public void testBasicUsageWithQuery() throws Exception {
checkCorrectClassification(new CachingNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it"))); LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
} }
@Test @Test
public void testNGramUsage() throws Exception { public void testNGramUsage() throws Exception {
checkCorrectClassification(new CachingNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName); LeafReader leafReader = null;
try {
NGramAnalyzer analyzer = new NGramAnalyzer();
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
} }
private class NGramAnalyzer extends Analyzer { private class NGramAnalyzer extends Analyzer {
@ -57,9 +87,4 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte
} }
} }
@Test
public void testPerformance() throws Exception {
checkPerformance(new CachingNaiveBayesClassifier(), new MockAnalyzer(random()), categoryFieldName);
}
} }

View File

@ -41,14 +41,14 @@ import org.junit.Before;
*/ */
public abstract class ClassificationTestBase<T> extends LuceneTestCase { public abstract class ClassificationTestBase<T> extends LuceneTestCase {
public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. " + public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. " +
"If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section."; "If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
public static final BytesRef POLITICS_RESULT = new BytesRef("politics"); public static final BytesRef POLITICS_RESULT = new BytesRef("politics");
public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." + public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
" Truth is, Amazon may know more."; " Truth is, Amazon may know more.";
public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology"); public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
private RandomIndexWriter indexWriter; protected RandomIndexWriter indexWriter;
private Directory dir; private Directory dir;
private FieldType ft; private FieldType ft;
@ -79,53 +79,34 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
dir.close(); dir.close();
} }
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception { protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult) throws Exception {
checkCorrectClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null); ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
double score = classificationResult.getScore();
assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
} }
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
LeafReader leafReader = null;
try {
populateSampleIndex(analyzer);
leafReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
classifier.train(leafReader, textFieldName, classFieldName, analyzer, query);
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
double score = classificationResult.getScore();
assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
} finally {
if (leafReader != null)
leafReader.close();
}
}
protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception { protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
checkOnlineClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null); checkOnlineClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
} }
protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception { protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
LeafReader leafReader = null; populateSampleIndex(analyzer);
try {
populateSampleIndex(analyzer); ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
leafReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); assertNotNull(classificationResult.getAssignedClass());
classifier.train(leafReader, textFieldName, classFieldName, analyzer, query); assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc); double score = classificationResult.getScore();
assertNotNull(classificationResult.getAssignedClass()); assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0);
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass()); updateSampleIndex();
double score = classificationResult.getScore(); ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0); assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
updateSampleIndex(); assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore()));
ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore()));
} finally {
if (leafReader != null)
leafReader.close();
}
} }
private void populateSampleIndex(Analyzer analyzer) throws IOException { protected LeafReader populateSampleIndex(Analyzer analyzer) throws IOException {
indexWriter.close(); indexWriter.close();
indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE)); indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
indexWriter.commit(); indexWriter.commit();
@ -134,8 +115,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
Document doc = new Document(); Document doc = new Document();
text = "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " + text = "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
"who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " + "who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
"the Unknown Soldier in Warsaw Tuesday."; "the Unknown Soldier in Warsaw Tuesday.";
doc.add(new Field(textFieldName, text, ft)); doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft)); doc.add(new Field(categoryFieldName, "politics", ft));
doc.add(new Field(booleanFieldName, "true", ft)); doc.add(new Field(booleanFieldName, "true", ft));
@ -144,7 +125,7 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
doc = new Document(); doc = new Document();
text = "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" + text = "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
" States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama."; " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.";
doc.add(new Field(textFieldName, text, ft)); doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft)); doc.add(new Field(categoryFieldName, "politics", ft));
doc.add(new Field(booleanFieldName, "true", ft)); doc.add(new Field(booleanFieldName, "true", ft));
@ -152,8 +133,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
doc = new Document(); doc = new Document();
text = "And there's a threshold question that he has to answer for the American people and " + text = "And there's a threshold question that he has to answer for the American people and " +
"that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " + "that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
"know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\""; "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"";
doc.add(new Field(textFieldName, text, ft)); doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft)); doc.add(new Field(categoryFieldName, "politics", ft));
doc.add(new Field(booleanFieldName, "true", ft)); doc.add(new Field(booleanFieldName, "true", ft));
@ -161,8 +142,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
doc = new Document(); doc = new Document();
text = "Still, when it comes to gun policy, many congressional Democrats have \"decided to " + text = "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
"keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " + "keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
"Albany's School of Criminal Justice."; "Albany's School of Criminal Justice.";
doc.add(new Field(textFieldName, text, ft)); doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft)); doc.add(new Field(categoryFieldName, "politics", ft));
doc.add(new Field(booleanFieldName, "true", ft)); doc.add(new Field(booleanFieldName, "true", ft));
@ -170,8 +151,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
doc = new Document(); doc = new Document();
text = "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " + text = "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
"technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " + "technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
"world through the Internet."; "world through the Internet.";
doc.add(new Field(textFieldName, text, ft)); doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "technology", ft)); doc.add(new Field(categoryFieldName, "technology", ft));
doc.add(new Field(booleanFieldName, "false", ft)); doc.add(new Field(booleanFieldName, "false", ft));
@ -179,7 +160,7 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
doc = new Document(); doc = new Document();
text = "So, about all those experts and analysts who've spent the past year or so saying " + text = "So, about all those experts and analysts who've spent the past year or so saying " +
"Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen."; "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.";
doc.add(new Field(textFieldName, text, ft)); doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "technology", ft)); doc.add(new Field(categoryFieldName, "technology", ft));
doc.add(new Field(booleanFieldName, "false", ft)); doc.add(new Field(booleanFieldName, "false", ft));
@ -187,8 +168,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
doc = new Document(); doc = new Document();
text = "More than 400 million people trust Google with their e-mail, and 50 million store files" + text = "More than 400 million people trust Google with their e-mail, and 50 million store files" +
" in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " + " in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
"generally transfer or store huge volumes of personal data online."; "generally transfer or store huge volumes of personal data online.";
doc.add(new Field(textFieldName, text, ft)); doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "technology", ft)); doc.add(new Field(categoryFieldName, "technology", ft));
doc.add(new Field(booleanFieldName, "false", ft)); doc.add(new Field(booleanFieldName, "false", ft));
@ -200,22 +181,15 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
indexWriter.addDocument(doc); indexWriter.addDocument(doc);
indexWriter.commit(); indexWriter.commit();
return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
} }
protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception { protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
LeafReader leafReader = null;
long trainStart = System.currentTimeMillis(); long trainStart = System.currentTimeMillis();
try { populatePerformanceIndex(analyzer);
populatePerformanceIndex(analyzer); long trainEnd = System.currentTimeMillis();
leafReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); long trainTime = trainEnd - trainStart;
classifier.train(leafReader, textFieldName, classFieldName, analyzer); assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
long trainEnd = System.currentTimeMillis();
long trainTime = trainEnd - trainStart;
assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
} finally {
if (leafReader != null)
leafReader.close();
}
} }
private void populatePerformanceIndex(Analyzer analyzer) throws IOException { private void populatePerformanceIndex(Analyzer analyzer) throws IOException {

View File

@ -17,6 +17,8 @@
package org.apache.lucene.classification; package org.apache.lucene.classification;
import org.apache.lucene.analysis.MockAnalyzer; import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
@ -29,20 +31,32 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
@Test @Test
public void testBasicUsage() throws Exception { public void testBasicUsage() throws Exception {
// usage with default MLT min docs / term freq LeafReader leafReader = null;
checkCorrectClassification(new KNearestNeighborClassifier(3), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName); try {
// usage without custom min docs / term freq for MLT MockAnalyzer analyzer = new MockAnalyzer(random());
checkCorrectClassification(new KNearestNeighborClassifier(3, 2, 1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName); leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, null, 3, 2, 1, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
} }
@Test @Test
public void testBasicUsageWithQuery() throws Exception { public void testBasicUsageWithQuery() throws Exception {
checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it"))); LeafReader leafReader = null;
} try {
MockAnalyzer analyzer = new MockAnalyzer(random());
@Test leafReader = populateSampleIndex(analyzer);
public void testPerformance() throws Exception { TermQuery query = new TermQuery(new Term(textFieldName, "it"));
checkPerformance(new KNearestNeighborClassifier(100), new MockAnalyzer(random()), categoryFieldName); checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, query, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
} }
} }

View File

@ -22,14 +22,13 @@ import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.core.KeywordTokenizer; import org.apache.lucene.analysis.core.KeywordTokenizer;
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter; import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
import org.apache.lucene.analysis.reverse.ReverseStringFilter; import org.apache.lucene.analysis.reverse.ReverseStringFilter;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.junit.Test; import org.junit.Test;
import java.io.Reader;
/** /**
* Testcase for {@link SimpleNaiveBayesClassifier} * Testcase for {@link SimpleNaiveBayesClassifier}
*/ */
@ -37,18 +36,46 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
@Test @Test
public void testBasicUsage() throws Exception { public void testBasicUsage() throws Exception {
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName); LeafReader leafReader = null;
checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName); try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
} }
@Test @Test
public void testBasicUsageWithQuery() throws Exception { public void testBasicUsageWithQuery() throws Exception {
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it"))); LeafReader leafReader = null;
try {
MockAnalyzer analyzer = new MockAnalyzer(random());
leafReader = populateSampleIndex(analyzer);
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
} }
@Test @Test
public void testNGramUsage() throws Exception { public void testNGramUsage() throws Exception {
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName); LeafReader leafReader = null;
try {
Analyzer analyzer = new NGramAnalyzer();
leafReader = populateSampleIndex(analyzer);
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
} finally {
if (leafReader != null) {
leafReader.close();
}
}
} }
private class NGramAnalyzer extends Analyzer { private class NGramAnalyzer extends Analyzer {
@ -59,9 +86,4 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
} }
} }
@Test
public void testPerformance() throws Exception {
checkPerformance(new SimpleNaiveBayesClassifier(), new MockAnalyzer(random()), categoryFieldName);
}
} }