mirror of https://github.com/apache/lucene.git
LUCENE-6045 - refactor Classifier API to work better with multithreading
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1676997 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
7736e49b3e
commit
92842e7c34
|
@ -58,76 +58,14 @@ import org.apache.lucene.util.fst.Util;
|
||||||
*/
|
*/
|
||||||
public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
|
|
||||||
private Double threshold;
|
private final Double threshold;
|
||||||
private final Integer batchSize;
|
private final Terms textTerms;
|
||||||
private Terms textTerms;
|
private final Analyzer analyzer;
|
||||||
private Analyzer analyzer;
|
private final String textFieldName;
|
||||||
private String textFieldName;
|
|
||||||
private FST<Long> fst;
|
private FST<Long> fst;
|
||||||
|
|
||||||
/**
|
public BooleanPerceptronClassifier(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer,
|
||||||
* Create a {@link BooleanPerceptronClassifier}
|
Query query, Integer batchSize, Double threshold) throws IOException {
|
||||||
*
|
|
||||||
* @param threshold the binary threshold for perceptron output evaluation
|
|
||||||
*/
|
|
||||||
public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
|
|
||||||
this.threshold = threshold;
|
|
||||||
this.batchSize = batchSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Default constructor, no batch updates of FST, perceptron threshold is
|
|
||||||
* calculated via underlying index metrics during
|
|
||||||
* {@link #train(org.apache.lucene.index.LeafReader, String, String, org.apache.lucene.analysis.Analyzer)
|
|
||||||
* training}
|
|
||||||
*/
|
|
||||||
public BooleanPerceptronClassifier() {
|
|
||||||
batchSize = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* {@inheritDoc}
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public ClassificationResult<Boolean> assignClass(String text)
|
|
||||||
throws IOException {
|
|
||||||
if (textTerms == null) {
|
|
||||||
throw new IOException("You must first call Classifier#train");
|
|
||||||
}
|
|
||||||
Long output = 0l;
|
|
||||||
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
|
|
||||||
CharTermAttribute charTermAttribute = tokenStream
|
|
||||||
.addAttribute(CharTermAttribute.class);
|
|
||||||
tokenStream.reset();
|
|
||||||
while (tokenStream.incrementToken()) {
|
|
||||||
String s = charTermAttribute.toString();
|
|
||||||
Long d = Util.get(fst, new BytesRef(s));
|
|
||||||
if (d != null) {
|
|
||||||
output += d;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tokenStream.end();
|
|
||||||
}
|
|
||||||
|
|
||||||
double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
|
|
||||||
return new ClassificationResult<>(output >= threshold, score);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* {@inheritDoc}
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void train(LeafReader leafReader, String textFieldName,
|
|
||||||
String classFieldName, Analyzer analyzer) throws IOException {
|
|
||||||
train(leafReader, textFieldName, classFieldName, analyzer, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* {@inheritDoc}
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void train(LeafReader leafReader, String textFieldName,
|
|
||||||
String classFieldName, Analyzer analyzer, Query query) throws IOException {
|
|
||||||
this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
|
this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
|
||||||
|
|
||||||
if (textTerms == null) {
|
if (textTerms == null) {
|
||||||
|
@ -144,9 +82,11 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
this.threshold = (double) sumDocFreq / 2d;
|
this.threshold = (double) sumDocFreq / 2d;
|
||||||
} else {
|
} else {
|
||||||
throw new IOException(
|
throw new IOException(
|
||||||
"threshold cannot be assigned since term vectors for field "
|
"threshold cannot be assigned since term vectors for field "
|
||||||
+ textFieldName + " do not exist");
|
+ textFieldName + " do not exist");
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
this.threshold = threshold;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO : remove this map as soon as we have a writable FST
|
// TODO : remove this map as soon as we have a writable FST
|
||||||
|
@ -170,7 +110,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
}
|
}
|
||||||
// run the search and use stored field values
|
// run the search and use stored field values
|
||||||
for (ScoreDoc scoreDoc : indexSearcher.search(q,
|
for (ScoreDoc scoreDoc : indexSearcher.search(q,
|
||||||
Integer.MAX_VALUE).scoreDocs) {
|
Integer.MAX_VALUE).scoreDocs) {
|
||||||
StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
|
StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
|
||||||
|
|
||||||
StorableField textField = doc.getField(textFieldName);
|
StorableField textField = doc.getField(textFieldName);
|
||||||
|
@ -187,7 +127,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
long modifier = correctClass.compareTo(assignedClass);
|
long modifier = correctClass.compareTo(assignedClass);
|
||||||
if (modifier != 0) {
|
if (modifier != 0) {
|
||||||
updateWeights(leafReader, scoreDoc.doc, assignedClass,
|
updateWeights(leafReader, scoreDoc.doc, assignedClass,
|
||||||
weights, modifier, batchCount % batchSize == 0);
|
weights, modifier, batchCount % batchSize == 0);
|
||||||
}
|
}
|
||||||
batchCount++;
|
batchCount++;
|
||||||
}
|
}
|
||||||
|
@ -195,11 +135,6 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
weights.clear(); // free memory while waiting for GC
|
weights.clear(); // free memory while waiting for GC
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
|
|
||||||
throw new IOException("training with multiple fields not supported by boolean perceptron classifier");
|
|
||||||
}
|
|
||||||
|
|
||||||
private void updateWeights(LeafReader leafReader,
|
private void updateWeights(LeafReader leafReader,
|
||||||
int docId, Boolean assignedClass, SortedMap<String, Double> weights,
|
int docId, Boolean assignedClass, SortedMap<String, Double> weights,
|
||||||
double modifier, boolean updateFST) throws IOException {
|
double modifier, boolean updateFST) throws IOException {
|
||||||
|
@ -210,7 +145,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
|
|
||||||
if (terms == null) {
|
if (terms == null) {
|
||||||
throw new IOException("term vectors must be stored for field "
|
throw new IOException("term vectors must be stored for field "
|
||||||
+ textFieldName);
|
+ textFieldName);
|
||||||
}
|
}
|
||||||
|
|
||||||
TermsEnum termsEnum = terms.iterator();
|
TermsEnum termsEnum = terms.iterator();
|
||||||
|
@ -240,17 +175,46 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
for (Map.Entry<String, Double> entry : weights.entrySet()) {
|
for (Map.Entry<String, Double> entry : weights.entrySet()) {
|
||||||
scratchBytes.copyChars(entry.getKey());
|
scratchBytes.copyChars(entry.getKey());
|
||||||
fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry
|
fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry
|
||||||
.getValue().longValue());
|
.getValue().longValue());
|
||||||
}
|
}
|
||||||
fst = fstBuilder.finish();
|
fst = fstBuilder.finish();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public ClassificationResult<Boolean> assignClass(String text)
|
||||||
|
throws IOException {
|
||||||
|
if (textTerms == null) {
|
||||||
|
throw new IOException("You must first call Classifier#train");
|
||||||
|
}
|
||||||
|
Long output = 0l;
|
||||||
|
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
|
||||||
|
CharTermAttribute charTermAttribute = tokenStream
|
||||||
|
.addAttribute(CharTermAttribute.class);
|
||||||
|
tokenStream.reset();
|
||||||
|
while (tokenStream.incrementToken()) {
|
||||||
|
String s = charTermAttribute.toString();
|
||||||
|
Long d = Util.get(fst, new BytesRef(s));
|
||||||
|
if (d != null) {
|
||||||
|
output += d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tokenStream.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
|
||||||
|
return new ClassificationResult<>(output >= threshold, score);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public List<ClassificationResult<Boolean>> getClasses(String text)
|
public List<ClassificationResult<Boolean>> getClasses(String text)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
throw new RuntimeException("not implemented");
|
throw new RuntimeException("not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -259,7 +223,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public List<ClassificationResult<Boolean>> getClasses(String text, int max)
|
public List<ClassificationResult<Boolean>> getClasses(String text, int max)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
throw new RuntimeException("not implemented");
|
throw new RuntimeException("not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,50 +49,30 @@ import org.apache.lucene.util.BytesRef;
|
||||||
*/
|
*/
|
||||||
public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
|
public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
|
||||||
//for caching classes this will be the classification class list
|
//for caching classes this will be the classification class list
|
||||||
private ArrayList<BytesRef> cclasses = new ArrayList<>();
|
private final ArrayList<BytesRef> cclasses = new ArrayList<>();
|
||||||
// it's a term-inmap style map, where the inmap contains class-hit pairs to the
|
// it's a term-inmap style map, where the inmap contains class-hit pairs to the
|
||||||
// upper term
|
// upper term
|
||||||
private Map<String, Map<BytesRef, Integer>> termCClassHitCache = new HashMap<>();
|
private final Map<String, Map<BytesRef, Integer>> termCClassHitCache = new HashMap<>();
|
||||||
// the term frequency in classes
|
// the term frequency in classes
|
||||||
private Map<BytesRef, Double> classTermFreq = new HashMap<>();
|
private final Map<BytesRef, Double> classTermFreq = new HashMap<>();
|
||||||
private boolean justCachedTerms;
|
private boolean justCachedTerms;
|
||||||
private int docsWithClassSize;
|
private int docsWithClassSize;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new NaiveBayes classifier with internal caching. Note that you must
|
* Creates a new NaiveBayes classifier with internal caching. If you want less memory usage you could
|
||||||
* call {@link #train(org.apache.lucene.index.LeafReader, String, String, Analyzer) train()} before
|
|
||||||
* you can classify any documents. If you want less memory usage you could
|
|
||||||
* call {@link #reInitCache(int, boolean) reInitCache()}.
|
* call {@link #reInitCache(int, boolean) reInitCache()}.
|
||||||
*/
|
*/
|
||||||
public CachingNaiveBayesClassifier() {
|
public CachingNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
|
||||||
}
|
super(leafReader, analyzer, query, classFieldName, textFieldNames);
|
||||||
|
|
||||||
/**
|
|
||||||
* {@inheritDoc}
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
|
|
||||||
train(leafReader, textFieldName, classFieldName, analyzer, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* {@inheritDoc}
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
|
|
||||||
train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* {@inheritDoc}
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
|
|
||||||
super.train(leafReader, textFieldNames, classFieldName, analyzer, query);
|
|
||||||
// building the cache
|
// building the cache
|
||||||
reInitCache(0, true);
|
try {
|
||||||
|
reInitCache(0, true);
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
|
private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
|
||||||
if (leafReader == null) {
|
if (leafReader == null) {
|
||||||
throw new IOException("You must first call Classifier#train");
|
throw new IOException("You must first call Classifier#train");
|
||||||
|
|
|
@ -18,17 +18,19 @@ package org.apache.lucene.classification;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The result of a call to {@link Classifier#assignClass(String)} holding an assigned class of type <code>T</code> and a score.
|
* The result of a call to {@link Classifier#assignClass(String)} holding an assigned class of type <code>T</code> and a score.
|
||||||
|
*
|
||||||
* @lucene.experimental
|
* @lucene.experimental
|
||||||
*/
|
*/
|
||||||
public class ClassificationResult<T> implements Comparable<ClassificationResult<T>>{
|
public class ClassificationResult<T> implements Comparable<ClassificationResult<T>> {
|
||||||
|
|
||||||
private final T assignedClass;
|
private final T assignedClass;
|
||||||
private double score;
|
private double score;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructor
|
* Constructor
|
||||||
|
*
|
||||||
* @param assignedClass the class <code>T</code> assigned by a {@link Classifier}
|
* @param assignedClass the class <code>T</code> assigned by a {@link Classifier}
|
||||||
* @param score the score for the assignedClass as a <code>double</code>
|
* @param score the score for the assignedClass as a <code>double</code>
|
||||||
*/
|
*/
|
||||||
public ClassificationResult(T assignedClass, double score) {
|
public ClassificationResult(T assignedClass, double score) {
|
||||||
this.assignedClass = assignedClass;
|
this.assignedClass = assignedClass;
|
||||||
|
@ -37,6 +39,7 @@ public class ClassificationResult<T> implements Comparable<ClassificationResult<
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* retrieve the result class
|
* retrieve the result class
|
||||||
|
*
|
||||||
* @return a <code>T</code> representing an assigned class
|
* @return a <code>T</code> representing an assigned class
|
||||||
*/
|
*/
|
||||||
public T getAssignedClass() {
|
public T getAssignedClass() {
|
||||||
|
@ -45,14 +48,16 @@ public class ClassificationResult<T> implements Comparable<ClassificationResult<
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* retrieve the result score
|
* retrieve the result score
|
||||||
|
*
|
||||||
* @return a <code>double</code> representing a result score
|
* @return a <code>double</code> representing a result score
|
||||||
*/
|
*/
|
||||||
public double getScore() {
|
public double getScore() {
|
||||||
return score;
|
return score;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* set the score value
|
* set the score value
|
||||||
|
*
|
||||||
* @param score the score for the assignedClass as a <code>double</code>
|
* @param score the score for the assignedClass as a <code>double</code>
|
||||||
*/
|
*/
|
||||||
public void setScore(double score) {
|
public void setScore(double score) {
|
||||||
|
|
|
@ -22,7 +22,6 @@ import java.util.List;
|
||||||
import org.apache.lucene.analysis.Analyzer;
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
import org.apache.lucene.index.LeafReader;
|
import org.apache.lucene.index.LeafReader;
|
||||||
import org.apache.lucene.search.Query;
|
import org.apache.lucene.search.Query;
|
||||||
import org.apache.lucene.util.BytesRef;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which assigns classes of type
|
* A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which assigns classes of type
|
||||||
|
@ -39,7 +38,7 @@ public interface Classifier<T> {
|
||||||
* @return a {@link ClassificationResult} holding assigned class of type <code>T</code> and score
|
* @return a {@link ClassificationResult} holding assigned class of type <code>T</code> and score
|
||||||
* @throws IOException If there is a low-level I/O error.
|
* @throws IOException If there is a low-level I/O error.
|
||||||
*/
|
*/
|
||||||
public ClassificationResult<T> assignClass(String text) throws IOException;
|
ClassificationResult<T> assignClass(String text) throws IOException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get all the classes (sorted by score, descending) assigned to the given text String.
|
* Get all the classes (sorted by score, descending) assigned to the given text String.
|
||||||
|
@ -48,7 +47,7 @@ public interface Classifier<T> {
|
||||||
* @return the whole list of {@link ClassificationResult}, the classes and scores. Returns <code>null</code> if the classifier can't make lists.
|
* @return the whole list of {@link ClassificationResult}, the classes and scores. Returns <code>null</code> if the classifier can't make lists.
|
||||||
* @throws IOException If there is a low-level I/O error.
|
* @throws IOException If there is a low-level I/O error.
|
||||||
*/
|
*/
|
||||||
public List<ClassificationResult<T>> getClasses(String text) throws IOException;
|
List<ClassificationResult<T>> getClasses(String text) throws IOException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the first <code>max</code> classes (sorted by score, descending) assigned to the given text String.
|
* Get the first <code>max</code> classes (sorted by score, descending) assigned to the given text String.
|
||||||
|
@ -58,44 +57,6 @@ public interface Classifier<T> {
|
||||||
* @return the whole list of {@link ClassificationResult}, the classes and scores. Cut for "max" number of elements. Returns <code>null</code> if the classifier can't make lists.
|
* @return the whole list of {@link ClassificationResult}, the classes and scores. Cut for "max" number of elements. Returns <code>null</code> if the classifier can't make lists.
|
||||||
* @throws IOException If there is a low-level I/O error.
|
* @throws IOException If there is a low-level I/O error.
|
||||||
*/
|
*/
|
||||||
public List<ClassificationResult<T>> getClasses(String text, int max) throws IOException;
|
List<ClassificationResult<T>> getClasses(String text, int max) throws IOException;
|
||||||
|
|
||||||
/**
|
|
||||||
* Train the classifier using the underlying Lucene index
|
|
||||||
*
|
|
||||||
* @param leafReader the reader to use to access the Lucene index
|
|
||||||
* @param textFieldName the name of the field used to compare documents
|
|
||||||
* @param classFieldName the name of the field containing the class assigned to documents
|
|
||||||
* @param analyzer the analyzer used to tokenize / filter the unseen text
|
|
||||||
* @throws IOException If there is a low-level I/O error.
|
|
||||||
*/
|
|
||||||
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer)
|
|
||||||
throws IOException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Train the classifier using the underlying Lucene index
|
|
||||||
*
|
|
||||||
* @param leafReader the reader to use to access the Lucene index
|
|
||||||
* @param textFieldName the name of the field used to compare documents
|
|
||||||
* @param classFieldName the name of the field containing the class assigned to documents
|
|
||||||
* @param analyzer the analyzer used to tokenize / filter the unseen text
|
|
||||||
* @param query the query to filter which documents use for training
|
|
||||||
* @throws IOException If there is a low-level I/O error.
|
|
||||||
*/
|
|
||||||
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
|
|
||||||
throws IOException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Train the classifier using the underlying Lucene index
|
|
||||||
*
|
|
||||||
* @param leafReader the reader to use to access the Lucene index
|
|
||||||
* @param textFieldNames the names of the fields to be used to compare documents
|
|
||||||
* @param classFieldName the name of the field containing the class assigned to documents
|
|
||||||
* @param analyzer the analyzer used to tokenize / filter the unseen text
|
|
||||||
* @param query the query to filter which documents use for training
|
|
||||||
* @throws IOException If there is a low-level I/O error.
|
|
||||||
*/
|
|
||||||
public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
|
|
||||||
throws IOException;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@ import java.util.Map;
|
||||||
|
|
||||||
import org.apache.lucene.analysis.Analyzer;
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
import org.apache.lucene.index.LeafReader;
|
import org.apache.lucene.index.LeafReader;
|
||||||
|
import org.apache.lucene.index.StorableField;
|
||||||
import org.apache.lucene.index.Term;
|
import org.apache.lucene.index.Term;
|
||||||
import org.apache.lucene.queries.mlt.MoreLikeThis;
|
import org.apache.lucene.queries.mlt.MoreLikeThis;
|
||||||
import org.apache.lucene.search.BooleanClause;
|
import org.apache.lucene.search.BooleanClause;
|
||||||
|
@ -45,37 +46,31 @@ import org.apache.lucene.util.BytesRef;
|
||||||
*/
|
*/
|
||||||
public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
||||||
|
|
||||||
private MoreLikeThis mlt;
|
private final MoreLikeThis mlt;
|
||||||
private String[] textFieldNames;
|
private final String[] textFieldNames;
|
||||||
private String classFieldName;
|
private final String classFieldName;
|
||||||
private IndexSearcher indexSearcher;
|
private final IndexSearcher indexSearcher;
|
||||||
private final int k;
|
private final int k;
|
||||||
private Query query;
|
private final Query query;
|
||||||
|
|
||||||
private int minDocsFreq;
|
public KNearestNeighborClassifier(LeafReader leafReader, Analyzer analyzer, Query query, int k, int minDocsFreq,
|
||||||
private int minTermFreq;
|
int minTermFreq, String classFieldName, String... textFieldNames) {
|
||||||
|
this.textFieldNames = textFieldNames;
|
||||||
/**
|
this.classFieldName = classFieldName;
|
||||||
* Create a {@link Classifier} using kNN algorithm
|
this.mlt = new MoreLikeThis(leafReader);
|
||||||
*
|
this.mlt.setAnalyzer(analyzer);
|
||||||
* @param k the number of neighbors to analyze as an <code>int</code>
|
this.mlt.setFieldNames(textFieldNames);
|
||||||
*/
|
this.indexSearcher = new IndexSearcher(leafReader);
|
||||||
public KNearestNeighborClassifier(int k) {
|
if (minDocsFreq > 0) {
|
||||||
|
mlt.setMinDocFreq(minDocsFreq);
|
||||||
|
}
|
||||||
|
if (minTermFreq > 0) {
|
||||||
|
mlt.setMinTermFreq(minTermFreq);
|
||||||
|
}
|
||||||
|
this.query = query;
|
||||||
this.k = k;
|
this.k = k;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a {@link Classifier} using kNN algorithm
|
|
||||||
*
|
|
||||||
* @param k the number of neighbors to analyze as an <code>int</code>
|
|
||||||
* @param minDocsFreq the minimum number of docs frequency for MLT to be set with {@link MoreLikeThis#setMinDocFreq(int)}
|
|
||||||
* @param minTermFreq the minimum number of term frequency for MLT to be set with {@link MoreLikeThis#setMinTermFreq(int)}
|
|
||||||
*/
|
|
||||||
public KNearestNeighborClassifier(int k, int minDocsFreq, int minTermFreq) {
|
|
||||||
this.k = k;
|
|
||||||
this.minDocsFreq = minDocsFreq;
|
|
||||||
this.minTermFreq = minTermFreq;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
|
@ -136,12 +131,15 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
||||||
private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
|
private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
|
||||||
Map<BytesRef, Integer> classCounts = new HashMap<>();
|
Map<BytesRef, Integer> classCounts = new HashMap<>();
|
||||||
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
||||||
BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
|
StorableField storableField = indexSearcher.doc(scoreDoc.doc).getField(classFieldName);
|
||||||
Integer count = classCounts.get(cl);
|
if (storableField != null) {
|
||||||
if (count != null) {
|
BytesRef cl = new BytesRef(storableField.stringValue());
|
||||||
classCounts.put(cl, count + 1);
|
Integer count = classCounts.get(cl);
|
||||||
} else {
|
if (count != null) {
|
||||||
classCounts.put(cl, 1);
|
classCounts.put(cl, count + 1);
|
||||||
|
} else {
|
||||||
|
classCounts.put(cl, 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
|
List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
|
||||||
|
@ -161,39 +159,4 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
||||||
return returnList;
|
return returnList;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* {@inheritDoc}
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
|
|
||||||
train(leafReader, textFieldName, classFieldName, analyzer, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* {@inheritDoc}
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
|
|
||||||
train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* {@inheritDoc}
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
|
|
||||||
this.textFieldNames = textFieldNames;
|
|
||||||
this.classFieldName = classFieldName;
|
|
||||||
mlt = new MoreLikeThis(leafReader);
|
|
||||||
mlt.setAnalyzer(analyzer);
|
|
||||||
mlt.setFieldNames(textFieldNames);
|
|
||||||
indexSearcher = new IndexSearcher(leafReader);
|
|
||||||
if (minDocsFreq > 0) {
|
|
||||||
mlt.setMinDocFreq(minDocsFreq);
|
|
||||||
}
|
|
||||||
if (minTermFreq > 0) {
|
|
||||||
mlt.setMinTermFreq(minTermFreq);
|
|
||||||
}
|
|
||||||
this.query = query;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,64 +51,38 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
* {@link org.apache.lucene.index.LeafReader} used to access the {@link org.apache.lucene.classification.Classifier}'s
|
* {@link org.apache.lucene.index.LeafReader} used to access the {@link org.apache.lucene.classification.Classifier}'s
|
||||||
* index
|
* index
|
||||||
*/
|
*/
|
||||||
protected LeafReader leafReader;
|
protected final LeafReader leafReader;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* names of the fields to be used as input text
|
* names of the fields to be used as input text
|
||||||
*/
|
*/
|
||||||
protected String[] textFieldNames;
|
protected final String[] textFieldNames;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* name of the field to be used as a class / category output
|
* name of the field to be used as a class / category output
|
||||||
*/
|
*/
|
||||||
protected String classFieldName;
|
protected final String classFieldName;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing unseen input text
|
* {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing unseen input text
|
||||||
*/
|
*/
|
||||||
protected Analyzer analyzer;
|
protected final Analyzer analyzer;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@link org.apache.lucene.search.IndexSearcher} to run searches on the index for retrieving frequencies
|
* {@link org.apache.lucene.search.IndexSearcher} to run searches on the index for retrieving frequencies
|
||||||
*/
|
*/
|
||||||
protected IndexSearcher indexSearcher;
|
protected final IndexSearcher indexSearcher;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@link org.apache.lucene.search.Query} used to optionally filter the document set to be used to classify
|
* {@link org.apache.lucene.search.Query} used to optionally filter the document set to be used to classify
|
||||||
*/
|
*/
|
||||||
protected Query query;
|
protected final Query query;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new NaiveBayes classifier.
|
* Creates a new NaiveBayes classifier.
|
||||||
* Note that you must call {@link #train(org.apache.lucene.index.LeafReader, String, String, Analyzer) train()} before you can
|
|
||||||
* classify any documents.
|
* classify any documents.
|
||||||
*/
|
*/
|
||||||
public SimpleNaiveBayesClassifier() {
|
public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* {@inheritDoc}
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
|
|
||||||
train(leafReader, textFieldName, classFieldName, analyzer, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* {@inheritDoc}
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
|
|
||||||
throws IOException {
|
|
||||||
train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* {@inheritDoc}
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
|
|
||||||
throws IOException {
|
|
||||||
this.leafReader = leafReader;
|
this.leafReader = leafReader;
|
||||||
this.indexSearcher = new IndexSearcher(this.leafReader);
|
this.indexSearcher = new IndexSearcher(this.leafReader);
|
||||||
this.textFieldNames = textFieldNames;
|
this.textFieldNames = textFieldNames;
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
/**
|
/**
|
||||||
* Uses already seen data (the indexed documents) to classify new documents.
|
* Uses already seen data (the indexed documents) to classify new documents.
|
||||||
* <p>
|
* <p>
|
||||||
* Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
|
* Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
|
||||||
* Neighbor classifier and a Perceptron based classifier.
|
* Neighbor classifier and a Perceptron based classifier.
|
||||||
*/
|
*/
|
||||||
package org.apache.lucene.classification;
|
package org.apache.lucene.classification;
|
||||||
|
|
|
@ -33,7 +33,8 @@ public class DocToDoubleVectorUtils {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* create a sparse <code>Double</code> vector given doc and field term vectors using local frequency of the terms in the doc
|
* create a sparse <code>Double</code> vector given doc and field term vectors using local frequency of the terms in the doc
|
||||||
* @param docTerms term vectors for a given document
|
*
|
||||||
|
* @param docTerms term vectors for a given document
|
||||||
* @param fieldTerms field term vectors
|
* @param fieldTerms field term vectors
|
||||||
* @return a sparse vector of <code>Double</code>s as an array
|
* @return a sparse vector of <code>Double</code>s as an array
|
||||||
* @throws IOException in case accessing the underlying index fails
|
* @throws IOException in case accessing the underlying index fails
|
||||||
|
@ -54,8 +55,7 @@ public class DocToDoubleVectorUtils {
|
||||||
if (seekStatus.equals(TermsEnum.SeekStatus.FOUND)) {
|
if (seekStatus.equals(TermsEnum.SeekStatus.FOUND)) {
|
||||||
long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
|
long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
|
||||||
freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
|
freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
|
||||||
}
|
} else {
|
||||||
else {
|
|
||||||
freqVector[i] = 0d;
|
freqVector[i] = 0d;
|
||||||
}
|
}
|
||||||
i++;
|
i++;
|
||||||
|
@ -66,6 +66,7 @@ public class DocToDoubleVectorUtils {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* create a dense <code>Double</code> vector given doc and field term vectors using local frequency of the terms in the doc
|
* create a dense <code>Double</code> vector given doc and field term vectors using local frequency of the terms in the doc
|
||||||
|
*
|
||||||
* @param docTerms term vectors for a given document
|
* @param docTerms term vectors for a given document
|
||||||
* @return a dense vector of <code>Double</code>s as an array
|
* @return a dense vector of <code>Double</code>s as an array
|
||||||
* @throws IOException in case accessing the underlying index fails
|
* @throws IOException in case accessing the underlying index fails
|
||||||
|
@ -73,16 +74,16 @@ public class DocToDoubleVectorUtils {
|
||||||
public static Double[] toDenseLocalFreqDoubleArray(Terms docTerms) throws IOException {
|
public static Double[] toDenseLocalFreqDoubleArray(Terms docTerms) throws IOException {
|
||||||
Double[] freqVector = null;
|
Double[] freqVector = null;
|
||||||
if (docTerms != null) {
|
if (docTerms != null) {
|
||||||
freqVector = new Double[(int) docTerms.size()];
|
freqVector = new Double[(int) docTerms.size()];
|
||||||
int i = 0;
|
int i = 0;
|
||||||
TermsEnum docTermsEnum = docTerms.iterator();
|
TermsEnum docTermsEnum = docTerms.iterator();
|
||||||
|
|
||||||
while (docTermsEnum.next() != null) {
|
while (docTermsEnum.next() != null) {
|
||||||
long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
|
long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
|
||||||
freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
|
freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return freqVector;
|
return freqVector;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
package org.apache.lucene.classification;
|
package org.apache.lucene.classification;
|
||||||
|
|
||||||
import org.apache.lucene.analysis.MockAnalyzer;
|
import org.apache.lucene.analysis.MockAnalyzer;
|
||||||
|
import org.apache.lucene.index.LeafReader;
|
||||||
|
import org.apache.lucene.index.SlowCompositeReaderWrapper;
|
||||||
import org.apache.lucene.index.Term;
|
import org.apache.lucene.index.Term;
|
||||||
import org.apache.lucene.search.TermQuery;
|
import org.apache.lucene.search.TermQuery;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
@ -28,22 +30,45 @@ public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Bool
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicUsage() throws Exception {
|
public void testBasicUsage() throws Exception {
|
||||||
checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
|
LeafReader leafReader = null;
|
||||||
|
try {
|
||||||
|
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||||
|
leafReader = populateSampleIndex(analyzer);
|
||||||
|
checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, null, 1, null), TECHNOLOGY_INPUT, false);
|
||||||
|
} finally {
|
||||||
|
if (leafReader != null) {
|
||||||
|
leafReader.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testExplicitThreshold() throws Exception {
|
public void testExplicitThreshold() throws Exception {
|
||||||
checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
|
LeafReader leafReader = null;
|
||||||
|
try {
|
||||||
|
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||||
|
leafReader = populateSampleIndex(analyzer);
|
||||||
|
checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, null, 1, 100d), TECHNOLOGY_INPUT, false);
|
||||||
|
} finally {
|
||||||
|
if (leafReader != null) {
|
||||||
|
leafReader.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicUsageWithQuery() throws Exception {
|
public void testBasicUsageWithQuery() throws Exception {
|
||||||
checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName, new TermQuery(new Term(textFieldName, "it")));
|
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
|
||||||
}
|
LeafReader leafReader = null;
|
||||||
|
try {
|
||||||
@Test
|
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||||
public void testPerformance() throws Exception {
|
leafReader = populateSampleIndex(analyzer);
|
||||||
checkPerformance(new BooleanPerceptronClassifier(), new MockAnalyzer(random()), booleanFieldName);
|
checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, query, 1, null), TECHNOLOGY_INPUT, false);
|
||||||
|
} finally {
|
||||||
|
if (leafReader != null) {
|
||||||
|
leafReader.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,8 @@ import org.apache.lucene.analysis.Tokenizer;
|
||||||
import org.apache.lucene.analysis.core.KeywordTokenizer;
|
import org.apache.lucene.analysis.core.KeywordTokenizer;
|
||||||
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
|
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
|
||||||
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
|
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
|
||||||
|
import org.apache.lucene.index.LeafReader;
|
||||||
|
import org.apache.lucene.index.SlowCompositeReaderWrapper;
|
||||||
import org.apache.lucene.index.Term;
|
import org.apache.lucene.index.Term;
|
||||||
import org.apache.lucene.search.TermQuery;
|
import org.apache.lucene.search.TermQuery;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
@ -35,18 +37,46 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicUsage() throws Exception {
|
public void testBasicUsage() throws Exception {
|
||||||
checkCorrectClassification(new CachingNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
|
LeafReader leafReader = null;
|
||||||
checkCorrectClassification(new CachingNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
|
try {
|
||||||
|
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||||
|
leafReader = populateSampleIndex(analyzer);
|
||||||
|
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||||
|
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
|
||||||
|
} finally {
|
||||||
|
if (leafReader != null) {
|
||||||
|
leafReader.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicUsageWithQuery() throws Exception {
|
public void testBasicUsageWithQuery() throws Exception {
|
||||||
checkCorrectClassification(new CachingNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
|
LeafReader leafReader = null;
|
||||||
|
try {
|
||||||
|
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||||
|
leafReader = populateSampleIndex(analyzer);
|
||||||
|
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
|
||||||
|
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||||
|
} finally {
|
||||||
|
if (leafReader != null) {
|
||||||
|
leafReader.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNGramUsage() throws Exception {
|
public void testNGramUsage() throws Exception {
|
||||||
checkCorrectClassification(new CachingNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName);
|
LeafReader leafReader = null;
|
||||||
|
try {
|
||||||
|
NGramAnalyzer analyzer = new NGramAnalyzer();
|
||||||
|
leafReader = populateSampleIndex(analyzer);
|
||||||
|
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||||
|
} finally {
|
||||||
|
if (leafReader != null) {
|
||||||
|
leafReader.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private class NGramAnalyzer extends Analyzer {
|
private class NGramAnalyzer extends Analyzer {
|
||||||
|
@ -57,9 +87,4 @@ public class CachingNaiveBayesClassifierTest extends ClassificationTestBase<Byte
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testPerformance() throws Exception {
|
|
||||||
checkPerformance(new CachingNaiveBayesClassifier(), new MockAnalyzer(random()), categoryFieldName);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,14 +41,14 @@ import org.junit.Before;
|
||||||
*/
|
*/
|
||||||
public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
||||||
public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. " +
|
public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. " +
|
||||||
"If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
|
"If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
|
||||||
public static final BytesRef POLITICS_RESULT = new BytesRef("politics");
|
public static final BytesRef POLITICS_RESULT = new BytesRef("politics");
|
||||||
|
|
||||||
public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
|
public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
|
||||||
" Truth is, Amazon may know more.";
|
" Truth is, Amazon may know more.";
|
||||||
public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
|
public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
|
||||||
|
|
||||||
private RandomIndexWriter indexWriter;
|
protected RandomIndexWriter indexWriter;
|
||||||
private Directory dir;
|
private Directory dir;
|
||||||
private FieldType ft;
|
private FieldType ft;
|
||||||
|
|
||||||
|
@ -79,53 +79,34 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
||||||
dir.close();
|
dir.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
|
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult) throws Exception {
|
||||||
checkCorrectClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
|
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
|
||||||
|
assertNotNull(classificationResult.getAssignedClass());
|
||||||
|
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
|
||||||
|
double score = classificationResult.getScore();
|
||||||
|
assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
|
|
||||||
LeafReader leafReader = null;
|
|
||||||
try {
|
|
||||||
populateSampleIndex(analyzer);
|
|
||||||
leafReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
|
|
||||||
classifier.train(leafReader, textFieldName, classFieldName, analyzer, query);
|
|
||||||
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
|
|
||||||
assertNotNull(classificationResult.getAssignedClass());
|
|
||||||
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
|
|
||||||
double score = classificationResult.getScore();
|
|
||||||
assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
|
|
||||||
} finally {
|
|
||||||
if (leafReader != null)
|
|
||||||
leafReader.close();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
|
protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
|
||||||
checkOnlineClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
|
checkOnlineClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
|
protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
|
||||||
LeafReader leafReader = null;
|
populateSampleIndex(analyzer);
|
||||||
try {
|
|
||||||
populateSampleIndex(analyzer);
|
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
|
||||||
leafReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
|
assertNotNull(classificationResult.getAssignedClass());
|
||||||
classifier.train(leafReader, textFieldName, classFieldName, analyzer, query);
|
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
|
||||||
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
|
double score = classificationResult.getScore();
|
||||||
assertNotNull(classificationResult.getAssignedClass());
|
assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0);
|
||||||
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
|
updateSampleIndex();
|
||||||
double score = classificationResult.getScore();
|
ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
|
||||||
assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0);
|
assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
|
||||||
updateSampleIndex();
|
assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore()));
|
||||||
ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
|
|
||||||
assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
|
|
||||||
assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore()));
|
|
||||||
|
|
||||||
} finally {
|
|
||||||
if (leafReader != null)
|
|
||||||
leafReader.close();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void populateSampleIndex(Analyzer analyzer) throws IOException {
|
protected LeafReader populateSampleIndex(Analyzer analyzer) throws IOException {
|
||||||
indexWriter.close();
|
indexWriter.close();
|
||||||
indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
|
indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
|
||||||
indexWriter.commit();
|
indexWriter.commit();
|
||||||
|
@ -134,8 +115,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
||||||
|
|
||||||
Document doc = new Document();
|
Document doc = new Document();
|
||||||
text = "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
|
text = "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
|
||||||
"who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
|
"who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
|
||||||
"the Unknown Soldier in Warsaw Tuesday.";
|
"the Unknown Soldier in Warsaw Tuesday.";
|
||||||
doc.add(new Field(textFieldName, text, ft));
|
doc.add(new Field(textFieldName, text, ft));
|
||||||
doc.add(new Field(categoryFieldName, "politics", ft));
|
doc.add(new Field(categoryFieldName, "politics", ft));
|
||||||
doc.add(new Field(booleanFieldName, "true", ft));
|
doc.add(new Field(booleanFieldName, "true", ft));
|
||||||
|
@ -144,7 +125,7 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
text = "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
|
text = "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
|
||||||
" States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.";
|
" States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.";
|
||||||
doc.add(new Field(textFieldName, text, ft));
|
doc.add(new Field(textFieldName, text, ft));
|
||||||
doc.add(new Field(categoryFieldName, "politics", ft));
|
doc.add(new Field(categoryFieldName, "politics", ft));
|
||||||
doc.add(new Field(booleanFieldName, "true", ft));
|
doc.add(new Field(booleanFieldName, "true", ft));
|
||||||
|
@ -152,8 +133,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
text = "And there's a threshold question that he has to answer for the American people and " +
|
text = "And there's a threshold question that he has to answer for the American people and " +
|
||||||
"that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
|
"that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
|
||||||
"know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"";
|
"know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"";
|
||||||
doc.add(new Field(textFieldName, text, ft));
|
doc.add(new Field(textFieldName, text, ft));
|
||||||
doc.add(new Field(categoryFieldName, "politics", ft));
|
doc.add(new Field(categoryFieldName, "politics", ft));
|
||||||
doc.add(new Field(booleanFieldName, "true", ft));
|
doc.add(new Field(booleanFieldName, "true", ft));
|
||||||
|
@ -161,8 +142,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
text = "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
|
text = "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
|
||||||
"keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
|
"keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
|
||||||
"Albany's School of Criminal Justice.";
|
"Albany's School of Criminal Justice.";
|
||||||
doc.add(new Field(textFieldName, text, ft));
|
doc.add(new Field(textFieldName, text, ft));
|
||||||
doc.add(new Field(categoryFieldName, "politics", ft));
|
doc.add(new Field(categoryFieldName, "politics", ft));
|
||||||
doc.add(new Field(booleanFieldName, "true", ft));
|
doc.add(new Field(booleanFieldName, "true", ft));
|
||||||
|
@ -170,8 +151,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
text = "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
|
text = "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
|
||||||
"technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
|
"technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
|
||||||
"world through the Internet.";
|
"world through the Internet.";
|
||||||
doc.add(new Field(textFieldName, text, ft));
|
doc.add(new Field(textFieldName, text, ft));
|
||||||
doc.add(new Field(categoryFieldName, "technology", ft));
|
doc.add(new Field(categoryFieldName, "technology", ft));
|
||||||
doc.add(new Field(booleanFieldName, "false", ft));
|
doc.add(new Field(booleanFieldName, "false", ft));
|
||||||
|
@ -179,7 +160,7 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
text = "So, about all those experts and analysts who've spent the past year or so saying " +
|
text = "So, about all those experts and analysts who've spent the past year or so saying " +
|
||||||
"Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.";
|
"Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.";
|
||||||
doc.add(new Field(textFieldName, text, ft));
|
doc.add(new Field(textFieldName, text, ft));
|
||||||
doc.add(new Field(categoryFieldName, "technology", ft));
|
doc.add(new Field(categoryFieldName, "technology", ft));
|
||||||
doc.add(new Field(booleanFieldName, "false", ft));
|
doc.add(new Field(booleanFieldName, "false", ft));
|
||||||
|
@ -187,8 +168,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
||||||
|
|
||||||
doc = new Document();
|
doc = new Document();
|
||||||
text = "More than 400 million people trust Google with their e-mail, and 50 million store files" +
|
text = "More than 400 million people trust Google with their e-mail, and 50 million store files" +
|
||||||
" in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
|
" in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
|
||||||
"generally transfer or store huge volumes of personal data online.";
|
"generally transfer or store huge volumes of personal data online.";
|
||||||
doc.add(new Field(textFieldName, text, ft));
|
doc.add(new Field(textFieldName, text, ft));
|
||||||
doc.add(new Field(categoryFieldName, "technology", ft));
|
doc.add(new Field(categoryFieldName, "technology", ft));
|
||||||
doc.add(new Field(booleanFieldName, "false", ft));
|
doc.add(new Field(booleanFieldName, "false", ft));
|
||||||
|
@ -200,22 +181,15 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
|
||||||
indexWriter.addDocument(doc);
|
indexWriter.addDocument(doc);
|
||||||
|
|
||||||
indexWriter.commit();
|
indexWriter.commit();
|
||||||
|
return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
|
protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
|
||||||
LeafReader leafReader = null;
|
|
||||||
long trainStart = System.currentTimeMillis();
|
long trainStart = System.currentTimeMillis();
|
||||||
try {
|
populatePerformanceIndex(analyzer);
|
||||||
populatePerformanceIndex(analyzer);
|
long trainEnd = System.currentTimeMillis();
|
||||||
leafReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
|
long trainTime = trainEnd - trainStart;
|
||||||
classifier.train(leafReader, textFieldName, classFieldName, analyzer);
|
assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
|
||||||
long trainEnd = System.currentTimeMillis();
|
|
||||||
long trainTime = trainEnd - trainStart;
|
|
||||||
assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
|
|
||||||
} finally {
|
|
||||||
if (leafReader != null)
|
|
||||||
leafReader.close();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
|
private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
package org.apache.lucene.classification;
|
package org.apache.lucene.classification;
|
||||||
|
|
||||||
import org.apache.lucene.analysis.MockAnalyzer;
|
import org.apache.lucene.analysis.MockAnalyzer;
|
||||||
|
import org.apache.lucene.index.LeafReader;
|
||||||
|
import org.apache.lucene.index.SlowCompositeReaderWrapper;
|
||||||
import org.apache.lucene.index.Term;
|
import org.apache.lucene.index.Term;
|
||||||
import org.apache.lucene.search.TermQuery;
|
import org.apache.lucene.search.TermQuery;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
@ -29,20 +31,32 @@ public class KNearestNeighborClassifierTest extends ClassificationTestBase<Bytes
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicUsage() throws Exception {
|
public void testBasicUsage() throws Exception {
|
||||||
// usage with default MLT min docs / term freq
|
LeafReader leafReader = null;
|
||||||
checkCorrectClassification(new KNearestNeighborClassifier(3), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
|
try {
|
||||||
// usage without custom min docs / term freq for MLT
|
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||||
checkCorrectClassification(new KNearestNeighborClassifier(3, 2, 1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
|
leafReader = populateSampleIndex(analyzer);
|
||||||
|
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||||
|
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, null, 3, 2, 1, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||||
|
} finally {
|
||||||
|
if (leafReader != null) {
|
||||||
|
leafReader.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicUsageWithQuery() throws Exception {
|
public void testBasicUsageWithQuery() throws Exception {
|
||||||
checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
|
LeafReader leafReader = null;
|
||||||
}
|
try {
|
||||||
|
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||||
@Test
|
leafReader = populateSampleIndex(analyzer);
|
||||||
public void testPerformance() throws Exception {
|
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
|
||||||
checkPerformance(new KNearestNeighborClassifier(100), new MockAnalyzer(random()), categoryFieldName);
|
checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, query, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||||
|
} finally {
|
||||||
|
if (leafReader != null) {
|
||||||
|
leafReader.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,14 +22,13 @@ import org.apache.lucene.analysis.Tokenizer;
|
||||||
import org.apache.lucene.analysis.core.KeywordTokenizer;
|
import org.apache.lucene.analysis.core.KeywordTokenizer;
|
||||||
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
|
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
|
||||||
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
|
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
|
||||||
|
import org.apache.lucene.index.LeafReader;
|
||||||
|
import org.apache.lucene.index.SlowCompositeReaderWrapper;
|
||||||
import org.apache.lucene.index.Term;
|
import org.apache.lucene.index.Term;
|
||||||
import org.apache.lucene.search.TermQuery;
|
import org.apache.lucene.search.TermQuery;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.LuceneTestCase;
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import java.io.Reader;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Testcase for {@link SimpleNaiveBayesClassifier}
|
* Testcase for {@link SimpleNaiveBayesClassifier}
|
||||||
*/
|
*/
|
||||||
|
@ -37,18 +36,46 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicUsage() throws Exception {
|
public void testBasicUsage() throws Exception {
|
||||||
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
|
LeafReader leafReader = null;
|
||||||
checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
|
try {
|
||||||
|
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||||
|
leafReader = populateSampleIndex(analyzer);
|
||||||
|
checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||||
|
checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
|
||||||
|
} finally {
|
||||||
|
if (leafReader != null) {
|
||||||
|
leafReader.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicUsageWithQuery() throws Exception {
|
public void testBasicUsageWithQuery() throws Exception {
|
||||||
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
|
LeafReader leafReader = null;
|
||||||
|
try {
|
||||||
|
MockAnalyzer analyzer = new MockAnalyzer(random());
|
||||||
|
leafReader = populateSampleIndex(analyzer);
|
||||||
|
TermQuery query = new TermQuery(new Term(textFieldName, "it"));
|
||||||
|
checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||||
|
} finally {
|
||||||
|
if (leafReader != null) {
|
||||||
|
leafReader.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNGramUsage() throws Exception {
|
public void testNGramUsage() throws Exception {
|
||||||
checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName);
|
LeafReader leafReader = null;
|
||||||
|
try {
|
||||||
|
Analyzer analyzer = new NGramAnalyzer();
|
||||||
|
leafReader = populateSampleIndex(analyzer);
|
||||||
|
checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
|
||||||
|
} finally {
|
||||||
|
if (leafReader != null) {
|
||||||
|
leafReader.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private class NGramAnalyzer extends Analyzer {
|
private class NGramAnalyzer extends Analyzer {
|
||||||
|
@ -59,9 +86,4 @@ public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase<Bytes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testPerformance() throws Exception {
|
|
||||||
checkPerformance(new SimpleNaiveBayesClassifier(), new MockAnalyzer(random()), categoryFieldName);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue