mirror of https://github.com/apache/lucene.git
LUCENE-6045 - refactor Classifier API to work better with multithreading
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1676997 13f79535-47bb-0310-9956-ffa450edef68
parent 7736e49b3e
commit 92842e7c34
@@ -58,76 +58,14 @@ import org.apache.lucene.util.fst.Util;
 */
public class BooleanPerceptronClassifier implements Classifier<Boolean> {

  private Double threshold;
  private final Integer batchSize;
  private Terms textTerms;
  private Analyzer analyzer;
  private String textFieldName;
  private final Double threshold;
  private final Terms textTerms;
  private final Analyzer analyzer;
  private final String textFieldName;
  private FST<Long> fst;

  /**
   * Create a {@link BooleanPerceptronClassifier}
   *
   * @param threshold the binary threshold for perceptron output evaluation
   */
  public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
    this.threshold = threshold;
    this.batchSize = batchSize;
  }

  /**
   * Default constructor, no batch updates of FST, perceptron threshold is
   * calculated via underlying index metrics during
   * {@link #train(org.apache.lucene.index.LeafReader, String, String, org.apache.lucene.analysis.Analyzer)
   * training}
   */
  public BooleanPerceptronClassifier() {
    batchSize = 1;
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public ClassificationResult<Boolean> assignClass(String text)
      throws IOException {
    if (textTerms == null) {
      throw new IOException("You must first call Classifier#train");
    }
    Long output = 0l;
    try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
      CharTermAttribute charTermAttribute = tokenStream
          .addAttribute(CharTermAttribute.class);
      tokenStream.reset();
      while (tokenStream.incrementToken()) {
        String s = charTermAttribute.toString();
        Long d = Util.get(fst, new BytesRef(s));
        if (d != null) {
          output += d;
        }
      }
      tokenStream.end();
    }

    double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
    return new ClassificationResult<>(output >= threshold, score);
  }
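Reading the scoring line above off in closed form: with output being the sum of the FST weights for the tokenized input, the returned confidence is

    \text{score} = 1 - e^{-\,|\,\text{threshold} - \text{output}\,|\,/\,\text{threshold}}

so an output that lands exactly on the decision threshold yields a score of 0, and the score approaches 1 as the output moves away from the threshold in either direction; the boolean class itself is simply output >= threshold.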
  /**
   * {@inheritDoc}
   */
  @Override
  public void train(LeafReader leafReader, String textFieldName,
      String classFieldName, Analyzer analyzer) throws IOException {
    train(leafReader, textFieldName, classFieldName, analyzer, null);
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public void train(LeafReader leafReader, String textFieldName,
      String classFieldName, Analyzer analyzer, Query query) throws IOException {
  public BooleanPerceptronClassifier(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer,
      Query query, Integer batchSize, Double threshold) throws IOException {
    this.textTerms = MultiFields.getTerms(leafReader, textFieldName);

    if (textTerms == null) {

@@ -144,9 +82,11 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
        this.threshold = (double) sumDocFreq / 2d;
      } else {
        throw new IOException(
            "threshold cannot be assigned since term vectors for field "
                + textFieldName + " do not exist");
      }
    } else {
      this.threshold = threshold;
    }

    // TODO : remove this map as soon as we have a writable FST

@@ -170,7 +110,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
    }
    // run the search and use stored field values
    for (ScoreDoc scoreDoc : indexSearcher.search(q,
        Integer.MAX_VALUE).scoreDocs) {
      StoredDocument doc = indexSearcher.doc(scoreDoc.doc);

      StorableField textField = doc.getField(textFieldName);

@@ -187,7 +127,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
      long modifier = correctClass.compareTo(assignedClass);
      if (modifier != 0) {
        updateWeights(leafReader, scoreDoc.doc, assignedClass,
            weights, modifier, batchCount % batchSize == 0);
      }
      batchCount++;
    }

@@ -195,11 +135,6 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
    weights.clear(); // free memory while waiting for GC
  }

  @Override
  public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
    throw new IOException("training with multiple fields not supported by boolean perceptron classifier");
  }

  private void updateWeights(LeafReader leafReader,
      int docId, Boolean assignedClass, SortedMap<String, Double> weights,
      double modifier, boolean updateFST) throws IOException {

@@ -210,7 +145,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {

    if (terms == null) {
      throw new IOException("term vectors must be stored for field "
          + textFieldName);
    }

    TermsEnum termsEnum = terms.iterator();

@@ -240,17 +175,46 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
    for (Map.Entry<String, Double> entry : weights.entrySet()) {
      scratchBytes.copyChars(entry.getKey());
      fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry
          .getValue().longValue());
    }
    fst = fstBuilder.finish();
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public ClassificationResult<Boolean> assignClass(String text)
      throws IOException {
    if (textTerms == null) {
      throw new IOException("You must first call Classifier#train");
    }
    Long output = 0l;
    try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
      CharTermAttribute charTermAttribute = tokenStream
          .addAttribute(CharTermAttribute.class);
      tokenStream.reset();
      while (tokenStream.incrementToken()) {
        String s = charTermAttribute.toString();
        Long d = Util.get(fst, new BytesRef(s));
        if (d != null) {
          output += d;
        }
      }
      tokenStream.end();
    }

    double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
    return new ClassificationResult<>(output >= threshold, score);
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public List<ClassificationResult<Boolean>> getClasses(String text)
      throws IOException {
    throw new RuntimeException("not implemented");
  }

@@ -259,7 +223,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
   */
  @Override
  public List<ClassificationResult<Boolean>> getClasses(String text, int max)
      throws IOException {
    throw new RuntimeException("not implemented");
  }
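To make the shape of the refactored API concrete, here is a minimal usage sketch, not part of the patch: the index path and the "text"/"bool" field names are illustrative assumptions, and the index is expected to store term vectors for the text field, mirroring the test fixtures later in this commit. Training now runs once inside the constructor, so assignClass is a read-only call afterwards.

import java.nio.file.Paths;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.classification.BooleanPerceptronClassifier;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;

public class PerceptronUsageSketch {
  public static void main(String[] args) throws Exception {
    // Hypothetical index location; documents are assumed to carry a
    // term-vectored text field and a boolean class field.
    Directory dir = FSDirectory.open(Paths.get("/path/to/index"));
    LeafReader reader = SlowCompositeReaderWrapper.wrap(DirectoryReader.open(dir));
    Analyzer analyzer = new StandardAnalyzer();
    try {
      // After this commit, training runs once in the constructor...
      BooleanPerceptronClassifier classifier = new BooleanPerceptronClassifier(
          reader, "text", "bool", analyzer,
          null,  // no filter query
          1,     // batchSize: update the FST after every weight change
          null); // threshold: null lets the classifier derive it from index stats
      // ...so classification is a read-only call on a fully built instance.
      ClassificationResult<Boolean> result = classifier.assignClass("some unseen text");
      System.out.println(result.getAssignedClass() + " / " + result.getScore());
    } finally {
      reader.close();
    }
  }
}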
@@ -49,50 +49,30 @@ import org.apache.lucene.util.BytesRef;
 */
public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
  // for caching classes this will be the classification class list
  private ArrayList<BytesRef> cclasses = new ArrayList<>();
  private final ArrayList<BytesRef> cclasses = new ArrayList<>();
  // it's a term-inmap style map, where the inmap contains class-hit pairs to the
  // upper term
  private Map<String, Map<BytesRef, Integer>> termCClassHitCache = new HashMap<>();
  private final Map<String, Map<BytesRef, Integer>> termCClassHitCache = new HashMap<>();
  // the term frequency in classes
  private Map<BytesRef, Double> classTermFreq = new HashMap<>();
  private final Map<BytesRef, Double> classTermFreq = new HashMap<>();
  private boolean justCachedTerms;
  private int docsWithClassSize;

  /**
   * Creates a new NaiveBayes classifier with inside caching. Note that you must
   * call {@link #train(org.apache.lucene.index.LeafReader, String, String, Analyzer) train()} before
   * you can classify any documents. If you want less memory usage you could
   * Creates a new NaiveBayes classifier with inside caching. If you want less memory usage you could
   * call {@link #reInitCache(int, boolean) reInitCache()}.
   */
  public CachingNaiveBayesClassifier() {
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
    train(leafReader, textFieldName, classFieldName, analyzer, null);
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
    train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
    super.train(leafReader, textFieldNames, classFieldName, analyzer, query);
  public CachingNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
    super(leafReader, analyzer, query, classFieldName, textFieldNames);
    // building the cache
    reInitCache(0, true);
    try {
      reInitCache(0, true);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
    if (leafReader == null) {
      throw new IOException("You must first call Classifier#train");
@@ -18,17 +18,19 @@ package org.apache.lucene.classification;

/**
 * The result of a call to {@link Classifier#assignClass(String)} holding an assigned class of type <code>T</code> and a score.
 *
 * @lucene.experimental
 */
public class ClassificationResult<T> implements Comparable<ClassificationResult<T>>{
public class ClassificationResult<T> implements Comparable<ClassificationResult<T>> {

  private final T assignedClass;
  private double score;

  /**
   * Constructor
   *
   * @param assignedClass the class <code>T</code> assigned by a {@link Classifier}
   * @param score the score for the assignedClass as a <code>double</code>
   */
  public ClassificationResult(T assignedClass, double score) {
    this.assignedClass = assignedClass;

@@ -37,6 +39,7 @@ public class ClassificationResult<T> implements Comparable<ClassificationResult<

  /**
   * retrieve the result class
   *
   * @return a <code>T</code> representing an assigned class
   */
  public T getAssignedClass() {

@@ -45,14 +48,16 @@ public class ClassificationResult<T> implements Comparable<ClassificationResult<

  /**
   * retrieve the result score
   *
   * @return a <code>double</code> representing a result score
   */
  public double getScore() {
    return score;
  }

  /**
   * set the score value
   *
   * @param score the score for the assignedClass as a <code>double</code>
   */
  public void setScore(double score) {
@@ -22,7 +22,6 @@ import java.util.List;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;

/**
 * A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which assign classes of type

@@ -39,7 +38,7 @@ public interface Classifier<T> {
   * @return a {@link ClassificationResult} holding assigned class of type <code>T</code> and score
   * @throws IOException If there is a low-level I/O error.
   */
  public ClassificationResult<T> assignClass(String text) throws IOException;
  ClassificationResult<T> assignClass(String text) throws IOException;

  /**
   * Get all the classes (sorted by score, descending) assigned to the given text String.

@@ -48,7 +47,7 @@ public interface Classifier<T> {
   * @return the whole list of {@link ClassificationResult}, the classes and scores. Returns <code>null</code> if the classifier can't make lists.
   * @throws IOException If there is a low-level I/O error.
   */
  public List<ClassificationResult<T>> getClasses(String text) throws IOException;
  List<ClassificationResult<T>> getClasses(String text) throws IOException;

  /**
   * Get the first <code>max</code> classes (sorted by score, descending) assigned to the given text String.

@@ -58,44 +57,6 @@ public interface Classifier<T> {
   * @return the whole list of {@link ClassificationResult}, the classes and scores. Cut for "max" number of elements. Returns <code>null</code> if the classifier can't make lists.
   * @throws IOException If there is a low-level I/O error.
   */
  public List<ClassificationResult<T>> getClasses(String text, int max) throws IOException;

  /**
   * Train the classifier using the underlying Lucene index
   *
   * @param leafReader the reader to use to access the Lucene index
   * @param textFieldName the name of the field used to compare documents
   * @param classFieldName the name of the field containing the class assigned to documents
   * @param analyzer the analyzer used to tokenize / filter the unseen text
   * @throws IOException If there is a low-level I/O error.
   */
  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer)
      throws IOException;

  /**
   * Train the classifier using the underlying Lucene index
   *
   * @param leafReader the reader to use to access the Lucene index
   * @param textFieldName the name of the field used to compare documents
   * @param classFieldName the name of the field containing the class assigned to documents
   * @param analyzer the analyzer used to tokenize / filter the unseen text
   * @param query the query to filter which documents use for training
   * @throws IOException If there is a low-level I/O error.
   */
  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
      throws IOException;

  /**
   * Train the classifier using the underlying Lucene index
   *
   * @param leafReader the reader to use to access the Lucene index
   * @param textFieldNames the names of the fields to be used to compare documents
   * @param classFieldName the name of the field containing the class assigned to documents
   * @param analyzer the analyzer used to tokenize / filter the unseen text
   * @param query the query to filter which documents use for training
   * @throws IOException If there is a low-level I/O error.
   */
  public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
      throws IOException;
  List<ClassificationResult<T>> getClasses(String text, int max) throws IOException;

}
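The interface is thus reduced to the read-only classification calls, with all training moved into the implementations' constructors. Per the issue title, that is what makes instances shareable across threads: once constructed, nothing can retrain or mutate them. A hedged sketch of the pattern this enables follows; the executor setup, field names, and corpus are assumptions, and actual thread-safety still rests on the underlying IndexSearcher and Analyzer, which Lucene documents as safe for concurrent use.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.util.BytesRef;

public class ParallelClassificationSketch {
  // 'reader' is assumed to be an already opened LeafReader over the corpus;
  // "category" and "text" are hypothetical field names.
  static List<ClassificationResult<BytesRef>> classifyAll(
      LeafReader reader, List<String> unseenDocs) throws Exception {
    Analyzer analyzer = new StandardAnalyzer();
    // One fully trained, effectively immutable classifier...
    final Classifier<BytesRef> classifier =
        new SimpleNaiveBayesClassifier(reader, analyzer, null, "category", "text");
    ExecutorService pool = Executors.newFixedThreadPool(4);
    try {
      List<Future<ClassificationResult<BytesRef>>> futures = new ArrayList<>();
      // ...shared by many classification tasks, since no train() call can
      // mutate its state anymore.
      for (final String doc : unseenDocs) {
        futures.add(pool.submit(new Callable<ClassificationResult<BytesRef>>() {
          @Override
          public ClassificationResult<BytesRef> call() throws Exception {
            return classifier.assignClass(doc);
          }
        }));
      }
      List<ClassificationResult<BytesRef>> results = new ArrayList<>();
      for (Future<ClassificationResult<BytesRef>> f : futures) {
        results.add(f.get());
      }
      return results;
    } finally {
      pool.shutdown();
    }
  }
}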
@@ -26,6 +26,7 @@ import java.util.Map;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.StorableField;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.mlt.MoreLikeThis;
import org.apache.lucene.search.BooleanClause;

@@ -45,37 +46,31 @@ import org.apache.lucene.util.BytesRef;
 */
public class KNearestNeighborClassifier implements Classifier<BytesRef> {

  private MoreLikeThis mlt;
  private String[] textFieldNames;
  private String classFieldName;
  private IndexSearcher indexSearcher;
  private final MoreLikeThis mlt;
  private final String[] textFieldNames;
  private final String classFieldName;
  private final IndexSearcher indexSearcher;
  private final int k;
  private Query query;
  private final Query query;

  private int minDocsFreq;
  private int minTermFreq;

  /**
   * Create a {@link Classifier} using kNN algorithm
   *
   * @param k the number of neighbors to analyze as an <code>int</code>
   */
  public KNearestNeighborClassifier(int k) {
  public KNearestNeighborClassifier(LeafReader leafReader, Analyzer analyzer, Query query, int k, int minDocsFreq,
      int minTermFreq, String classFieldName, String... textFieldNames) {
    this.textFieldNames = textFieldNames;
    this.classFieldName = classFieldName;
    this.mlt = new MoreLikeThis(leafReader);
    this.mlt.setAnalyzer(analyzer);
    this.mlt.setFieldNames(textFieldNames);
    this.indexSearcher = new IndexSearcher(leafReader);
    if (minDocsFreq > 0) {
      mlt.setMinDocFreq(minDocsFreq);
    }
    if (minTermFreq > 0) {
      mlt.setMinTermFreq(minTermFreq);
    }
    this.query = query;
    this.k = k;
  }

  /**
   * Create a {@link Classifier} using kNN algorithm
   *
   * @param k the number of neighbors to analyze as an <code>int</code>
   * @param minDocsFreq the minimum number of docs frequency for MLT to be set with {@link MoreLikeThis#setMinDocFreq(int)}
   * @param minTermFreq the minimum number of term frequency for MLT to be set with {@link MoreLikeThis#setMinTermFreq(int)}
   */
  public KNearestNeighborClassifier(int k, int minDocsFreq, int minTermFreq) {
    this.k = k;
    this.minDocsFreq = minDocsFreq;
    this.minTermFreq = minTermFreq;
  }

  /**
   * {@inheritDoc}

@@ -136,12 +131,15 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
  private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
    Map<BytesRef, Integer> classCounts = new HashMap<>();
    for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
      BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
      Integer count = classCounts.get(cl);
      if (count != null) {
        classCounts.put(cl, count + 1);
      } else {
        classCounts.put(cl, 1);
      StorableField storableField = indexSearcher.doc(scoreDoc.doc).getField(classFieldName);
      if (storableField != null) {
        BytesRef cl = new BytesRef(storableField.stringValue());
        Integer count = classCounts.get(cl);
        if (count != null) {
          classCounts.put(cl, count + 1);
        } else {
          classCounts.put(cl, 1);
        }
      }
    }
    List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();

@@ -161,39 +159,4 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
    return returnList;
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
    train(leafReader, textFieldName, classFieldName, analyzer, null);
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
    train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
    this.textFieldNames = textFieldNames;
    this.classFieldName = classFieldName;
    mlt = new MoreLikeThis(leafReader);
    mlt.setAnalyzer(analyzer);
    mlt.setFieldNames(textFieldNames);
    indexSearcher = new IndexSearcher(leafReader);
    if (minDocsFreq > 0) {
      mlt.setMinDocFreq(minDocsFreq);
    }
    if (minTermFreq > 0) {
      mlt.setMinTermFreq(minTermFreq);
    }
    this.query = query;
  }
}
@@ -51,64 +51,38 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
   * {@link org.apache.lucene.index.LeafReader} used to access the {@link org.apache.lucene.classification.Classifier}'s
   * index
   */
  protected LeafReader leafReader;
  protected final LeafReader leafReader;

  /**
   * names of the fields to be used as input text
   */
  protected String[] textFieldNames;
  protected final String[] textFieldNames;

  /**
   * name of the field to be used as a class / category output
   */
  protected String classFieldName;
  protected final String classFieldName;

  /**
   * {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing unseen input text
   */
  protected Analyzer analyzer;
  protected final Analyzer analyzer;

  /**
   * {@link org.apache.lucene.search.IndexSearcher} to run searches on the index for retrieving frequencies
   */
  protected IndexSearcher indexSearcher;
  protected final IndexSearcher indexSearcher;

  /**
   * {@link org.apache.lucene.search.Query} used to eventually filter the document set to be used to classify
   */
  protected Query query;
  protected final Query query;

  /**
   * Creates a new NaiveBayes classifier.
   * Note that you must call {@link #train(org.apache.lucene.index.LeafReader, String, String, Analyzer) train()} before you can
   * classify any documents.
   */
  public SimpleNaiveBayesClassifier() {
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
    train(leafReader, textFieldName, classFieldName, analyzer, null);
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
      throws IOException {
    train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
      throws IOException {
  public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
    this.leafReader = leafReader;
    this.indexSearcher = new IndexSearcher(this.leafReader);
    this.textFieldNames = textFieldNames;
@@ -18,7 +18,7 @@
/**
 * Uses already seen data (the indexed documents) to classify new documents.
 * <p>
 * Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
 * Neighbor classifier and a Perceptron based classifier.
 */
package org.apache.lucene.classification;
@@ -33,7 +33,8 @@ public class DocToDoubleVectorUtils {

  /**
   * create a sparse <code>Double</code> vector given doc and field term vectors using local frequency of the terms in the doc
   * @param docTerms term vectors for a given document
   *
   * @param docTerms term vectors for a given document
   * @param fieldTerms field term vectors
   * @return a sparse vector of <code>Double</code>s as an array
   * @throws IOException in case accessing the underlying index fails

@@ -54,8 +55,7 @@ public class DocToDoubleVectorUtils {
      if (seekStatus.equals(TermsEnum.SeekStatus.FOUND)) {
        long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
        freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
      }
      else {
      } else {
        freqVector[i] = 0d;
      }
      i++;

@@ -66,6 +66,7 @@ public class DocToDoubleVectorUtils {

  /**
   * create a dense <code>Double</code> vector given doc and field term vectors using local frequency of the terms in the doc
   *
   * @param docTerms term vectors for a given document
   * @return a dense vector of <code>Double</code>s as an array
   * @throws IOException in case accessing the underlying index fails

@@ -73,16 +74,16 @@ public class DocToDoubleVectorUtils {
  public static Double[] toDenseLocalFreqDoubleArray(Terms docTerms) throws IOException {
    Double[] freqVector = null;
    if (docTerms != null) {
      freqVector = new Double[(int) docTerms.size()];
      int i = 0;
      TermsEnum docTermsEnum = docTerms.iterator();

      while (docTermsEnum.next() != null) {
        long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
        freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
        i++;
      }
    }
    return freqVector;
  }
}
@@ -17,6 +17,8 @@
package org.apache.lucene.classification;

import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
import org.junit.Test;

@@ -28,22 +30,45 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool

  @Test
  public void testBasicUsage() throws Exception {
    checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
    LeafReader leafReader = null;
    try {
      MockAnalyzer analyzer = new MockAnalyzer(random());
      leafReader = populateSampleIndex(analyzer);
      checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, null, 1, null), TECHNOLOGY_INPUT, false);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

  @Test
  public void testExplicitThreshold() throws Exception {
    checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
    LeafReader leafReader = null;
    try {
      MockAnalyzer analyzer = new MockAnalyzer(random());
      leafReader = populateSampleIndex(analyzer);
      checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, null, 1, 100d), TECHNOLOGY_INPUT, false);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

  @Test
  public void testBasicUsageWithQuery() throws Exception {
    checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName, new TermQuery(new Term(textFieldName, "it")));
  }

  @Test
  public void testPerformance() throws Exception {
    checkPerformance(new BooleanPerceptronClassifier(), new MockAnalyzer(random()), booleanFieldName);
    TermQuery query = new TermQuery(new Term(textFieldName, "it"));
    LeafReader leafReader = null;
    try {
      MockAnalyzer analyzer = new MockAnalyzer(random());
      leafReader = populateSampleIndex(analyzer);
      checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, query, 1, null), TECHNOLOGY_INPUT, false);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

}
@@ -23,6 +23,8 @@ import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.core.KeywordTokenizer;
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;

@@ -35,18 +37,46 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte

  @Test
  public void testBasicUsage() throws Exception {
    checkCorrectClassification(new CachingNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
    checkCorrectClassification(new CachingNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
    LeafReader leafReader = null;
    try {
      MockAnalyzer analyzer = new MockAnalyzer(random());
      leafReader = populateSampleIndex(analyzer);
      checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
      checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

  @Test
  public void testBasicUsageWithQuery() throws Exception {
    checkCorrectClassification(new CachingNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
    LeafReader leafReader = null;
    try {
      MockAnalyzer analyzer = new MockAnalyzer(random());
      leafReader = populateSampleIndex(analyzer);
      TermQuery query = new TermQuery(new Term(textFieldName, "it"));
      checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

  @Test
  public void testNGramUsage() throws Exception {
    checkCorrectClassification(new CachingNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName);
    LeafReader leafReader = null;
    try {
      NGramAnalyzer analyzer = new NGramAnalyzer();
      leafReader = populateSampleIndex(analyzer);
      checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

  private class NGramAnalyzer extends Analyzer {

@@ -57,9 +87,4 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte
    }
  }

  @Test
  public void testPerformance() throws Exception {
    checkPerformance(new CachingNaiveBayesClassifier(), new MockAnalyzer(random()), categoryFieldName);
  }

}
@@ -41,14 +41,14 @@ import org.junit.Before;
 */
public abstract class ClassificationTestBase<T> extends LuceneTestCase {
  public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. " +
      "If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
  public static final BytesRef POLITICS_RESULT = new BytesRef("politics");

  public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
      " Truth is, Amazon may know more.";
  public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");

  private RandomIndexWriter indexWriter;
  protected RandomIndexWriter indexWriter;
  private Directory dir;
  private FieldType ft;

@@ -79,53 +79,34 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
    dir.close();
  }

  protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
    checkCorrectClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
  protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult) throws Exception {
    ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
    assertNotNull(classificationResult.getAssignedClass());
    assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
    double score = classificationResult.getScore();
    assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
  }

  protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
    LeafReader leafReader = null;
    try {
      populateSampleIndex(analyzer);
      leafReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
      classifier.train(leafReader, textFieldName, classFieldName, analyzer, query);
      ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
      assertNotNull(classificationResult.getAssignedClass());
      assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
      double score = classificationResult.getScore();
      assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
    } finally {
      if (leafReader != null)
        leafReader.close();
    }
  }

  protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
    checkOnlineClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
  }

  protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
    LeafReader leafReader = null;
    try {
      populateSampleIndex(analyzer);
      leafReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
      classifier.train(leafReader, textFieldName, classFieldName, analyzer, query);
      ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
      assertNotNull(classificationResult.getAssignedClass());
      assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
      double score = classificationResult.getScore();
      assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0);
      updateSampleIndex();
      ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
      assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
      assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore()));
      populateSampleIndex(analyzer);

      ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
      assertNotNull(classificationResult.getAssignedClass());
      assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
      double score = classificationResult.getScore();
      assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0);
      updateSampleIndex();
      ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
      assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
      assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore()));

    } finally {
      if (leafReader != null)
        leafReader.close();
    }
  }

  private void populateSampleIndex(Analyzer analyzer) throws IOException {
  protected LeafReader populateSampleIndex(Analyzer analyzer) throws IOException {
    indexWriter.close();
    indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
    indexWriter.commit();

@@ -134,8 +115,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {

    Document doc = new Document();
    text = "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
        "who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
        "the Unknown Soldier in Warsaw Tuesday.";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(categoryFieldName, "politics", ft));
    doc.add(new Field(booleanFieldName, "true", ft));

@@ -144,7 +125,7 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {

    doc = new Document();
    text = "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
        " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(categoryFieldName, "politics", ft));
    doc.add(new Field(booleanFieldName, "true", ft));

@@ -152,8 +133,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {

    doc = new Document();
    text = "And there's a threshold question that he has to answer for the American people and " +
        "that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
        "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(categoryFieldName, "politics", ft));
    doc.add(new Field(booleanFieldName, "true", ft));

@@ -161,8 +142,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {

    doc = new Document();
    text = "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
        "keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
        "Albany's School of Criminal Justice.";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(categoryFieldName, "politics", ft));
    doc.add(new Field(booleanFieldName, "true", ft));

@@ -170,8 +151,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {

    doc = new Document();
    text = "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
        "technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
        "world through the Internet.";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(categoryFieldName, "technology", ft));
    doc.add(new Field(booleanFieldName, "false", ft));

@@ -179,7 +160,7 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {

    doc = new Document();
    text = "So, about all those experts and analysts who've spent the past year or so saying " +
        "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(categoryFieldName, "technology", ft));
    doc.add(new Field(booleanFieldName, "false", ft));

@@ -187,8 +168,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {

    doc = new Document();
    text = "More than 400 million people trust Google with their e-mail, and 50 million store files" +
        " in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
        "generally transfer or store huge volumes of personal data online.";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(categoryFieldName, "technology", ft));
    doc.add(new Field(booleanFieldName, "false", ft));

@@ -200,22 +181,15 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
    indexWriter.addDocument(doc);

    indexWriter.commit();
    return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
  }

  protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
    LeafReader leafReader = null;
    long trainStart = System.currentTimeMillis();
    try {
      populatePerformanceIndex(analyzer);
      leafReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
      classifier.train(leafReader, textFieldName, classFieldName, analyzer);
      long trainEnd = System.currentTimeMillis();
      long trainTime = trainEnd - trainStart;
      assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
    } finally {
      if (leafReader != null)
        leafReader.close();
    }
    populatePerformanceIndex(analyzer);
    long trainEnd = System.currentTimeMillis();
    long trainTime = trainEnd - trainStart;
    assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
  }

  private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
@@ -17,6 +17,8 @@
package org.apache.lucene.classification;

import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;

@@ -29,20 +31,32 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes

  @Test
  public void testBasicUsage() throws Exception {
    // usage with default MLT min docs / term freq
    checkCorrectClassification(new KNearestNeighborClassifier(3), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
    // usage without custom min docs / term freq for MLT
    checkCorrectClassification(new KNearestNeighborClassifier(3, 2, 1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
    LeafReader leafReader = null;
    try {
      MockAnalyzer analyzer = new MockAnalyzer(random());
      leafReader = populateSampleIndex(analyzer);
      checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
      checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, null, 3, 2, 1, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

  @Test
  public void testBasicUsageWithQuery() throws Exception {
    checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
  }

  @Test
  public void testPerformance() throws Exception {
    checkPerformance(new KNearestNeighborClassifier(100), new MockAnalyzer(random()), categoryFieldName);
    LeafReader leafReader = null;
    try {
      MockAnalyzer analyzer = new MockAnalyzer(random());
      leafReader = populateSampleIndex(analyzer);
      TermQuery query = new TermQuery(new Term(textFieldName, "it"));
      checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, query, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

}
@@ -22,14 +22,13 @@ import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.core.KeywordTokenizer;
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.junit.Test;

import java.io.Reader;

/**
 * Testcase for {@link SimpleNaiveBayesClassifier}
 */

@@ -37,18 +36,46 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes

  @Test
  public void testBasicUsage() throws Exception {
    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
    checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
    LeafReader leafReader = null;
    try {
      MockAnalyzer analyzer = new MockAnalyzer(random());
      leafReader = populateSampleIndex(analyzer);
      checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
      checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

  @Test
  public void testBasicUsageWithQuery() throws Exception {
    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
    LeafReader leafReader = null;
    try {
      MockAnalyzer analyzer = new MockAnalyzer(random());
      leafReader = populateSampleIndex(analyzer);
      TermQuery query = new TermQuery(new Term(textFieldName, "it"));
      checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

  @Test
  public void testNGramUsage() throws Exception {
    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName);
    LeafReader leafReader = null;
    try {
      Analyzer analyzer = new NGramAnalyzer();
      leafReader = populateSampleIndex(analyzer);
      checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

  private class NGramAnalyzer extends Analyzer {

@@ -59,9 +86,4 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
    }
  }

  @Test
  public void testPerformance() throws Exception {
    checkPerformance(new SimpleNaiveBayesClassifier(), new MockAnalyzer(random()), categoryFieldName);
  }

}