mirror of https://github.com/apache/lucene.git
LUCENE-7350 - Let classifiers be constructed from IndexReaders
This commit is contained in:
parent 6ef174f527
commit fcf4389d82
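The commit swaps LeafReader for IndexReader throughout the classification module, so classifiers can be built directly on a composite, multi-segment reader; callers no longer have to forceMerge(1) and unwrap a single leaf first. Below is a minimal sketch of the new construction path, not part of the commit itself; the index path and the "category"/"text" field names are placeholders.

import java.nio.file.Paths;

import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.BytesRef;

public class ClassifierFromIndexReaderExample {
  public static void main(String[] args) throws Exception {
    // Any composite reader now works; no forceMerge(1) / getOnlyLeafReader(...) step.
    try (IndexReader reader = DirectoryReader.open(FSDirectory.open(Paths.get("/path/to/index")))) {
      // null query -> train on all indexed docs (see the javadoc in the diff below)
      SimpleNaiveBayesClassifier classifier =
          new SimpleNaiveBayesClassifier(reader, new StandardAnalyzer(), null, "category", "text");
      ClassificationResult<BytesRef> result = classifier.assignClass("unseen text to classify");
      System.out.println(result.getAssignedClass().utf8ToString() + " : " + result.getScore());
    }
  }
}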

BooleanPerceptronClassifier.java
@@ -26,6 +26,7 @@ import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.analysis.TokenStream;
 import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
 import org.apache.lucene.document.Document;
+import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexableField;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.MultiFields;
@@ -67,7 +68,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
   /**
    * Creates a {@link BooleanPerceptronClassifier}
    *
-   * @param leafReader the reader on the index to be used for classification
+   * @param indexReader the reader on the index to be used for classification
    * @param analyzer an {@link Analyzer} used to analyze unseen text
    * @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
    *              if all the indexed docs should be used
@@ -78,9 +79,9 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
    * @throws IOException if the building of the underlying {@link FST} fails and / or {@link TermsEnum} for the text field
    *                     cannot be found
    */
-  public BooleanPerceptronClassifier(LeafReader leafReader, Analyzer analyzer, Query query, Integer batchSize,
+  public BooleanPerceptronClassifier(IndexReader indexReader, Analyzer analyzer, Query query, Integer batchSize,
                                      Double threshold, String classFieldName, String textFieldName) throws IOException {
-    this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
+    this.textTerms = MultiFields.getTerms(indexReader, textFieldName);
 
     if (textTerms == null) {
       throw new IOException("term vectors need to be available for field " + textFieldName);
@@ -91,7 +92,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
 
     if (threshold == null || threshold == 0d) {
       // automatic assign a threshold
-      long sumDocFreq = leafReader.getSumDocFreq(textFieldName);
+      long sumDocFreq = indexReader.getSumDocFreq(textFieldName);
       if (sumDocFreq != -1) {
         this.threshold = (double) sumDocFreq / 2d;
       } else {
@@ -113,7 +114,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
     }
     updateFST(weights);
 
-    IndexSearcher indexSearcher = new IndexSearcher(leafReader);
+    IndexSearcher indexSearcher = new IndexSearcher(indexReader);
 
     int batchCount = 0;
 
@@ -140,7 +141,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
       Boolean correctClass = Boolean.valueOf(classField.stringValue());
       long modifier = correctClass.compareTo(assignedClass);
       if (modifier != 0) {
-        updateWeights(leafReader, scoreDoc.doc, assignedClass,
+        updateWeights(indexReader, scoreDoc.doc, assignedClass,
             weights, modifier, batchCount % batchSize == 0);
       }
       batchCount++;
@@ -149,13 +150,13 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
     weights.clear(); // free memory while waiting for GC
   }
 
-  private void updateWeights(LeafReader leafReader,
+  private void updateWeights(IndexReader indexReader,
                              int docId, Boolean assignedClass, SortedMap<String, Double> weights,
                              double modifier, boolean updateFST) throws IOException {
     TermsEnum cte = textTerms.iterator();
 
     // get the doc term vectors
-    Terms terms = leafReader.getTermVector(docId, textFieldName);
+    Terms terms = indexReader.getTermVector(docId, textFieldName);
 
     if (terms == null) {
       throw new IOException("term vectors must be stored for field "
@@ -201,7 +202,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
   @Override
   public ClassificationResult<Boolean> assignClass(String text)
       throws IOException {
-    Long output = 0l;
+    Long output = 0L;
     try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
       CharTermAttribute charTermAttribute = tokenStream
           .addAttribute(CharTermAttribute.class);
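The substitution above repeats through the rest of the commit: wherever the old code called a LeafReader-only API such as leafReader.terms(field), the new code goes through the top-level reader, typically via MultiFields.getTerms(indexReader, field), which exposes a merged terms view across all segments. A small illustrative sketch follows; the helper class, method, and field names are ours, not from the commit.

import java.io.IOException;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Terms;

final class MergedTermsSketch {
  // Looks up a field's Terms across every segment of a composite reader;
  // MultiFields.getTerms returns null when the field does not exist.
  static long sumDocFreq(IndexReader reader, String field) throws IOException {
    Terms terms = MultiFields.getTerms(reader, field);
    return terms == null ? -1 : terms.getSumDocFreq();
  }
}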

CachingNaiveBayesClassifier.java
@@ -212,7 +212,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
     // build the cache for the word
     Map<String, Long> frequencyMap = new HashMap<>();
     for (String textFieldName : textFieldNames) {
-      TermsEnum termsEnum = leafReader.terms(textFieldName).iterator();
+      TermsEnum termsEnum = MultiFields.getTerms(indexReader, textFieldName).iterator();
       while (termsEnum.next() != null) {
         BytesRef term = termsEnum.term();
         String termText = term.utf8ToString();
@@ -229,7 +229,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
     }
 
     // fill the class list
-    Terms terms = MultiFields.getTerms(leafReader, classFieldName);
+    Terms terms = MultiFields.getTerms(indexReader, classFieldName);
     TermsEnum termsEnum = terms.iterator();
     while ((termsEnum.next()) != null) {
       cclasses.add(BytesRef.deepCopyOf(termsEnum.term()));
@@ -238,11 +238,11 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
     for (BytesRef cclass : cclasses) {
       double avgNumberOfUniqueTerms = 0;
       for (String textFieldName : textFieldNames) {
-        terms = MultiFields.getTerms(leafReader, textFieldName);
+        terms = MultiFields.getTerms(indexReader, textFieldName);
         long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
         avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount();
       }
-      int docsWithC = leafReader.docFreq(new Term(classFieldName, cclass));
+      int docsWithC = indexReader.docFreq(new Term(classFieldName, cclass));
       classTermFreq.put(cclass, avgNumberOfUniqueTerms * docsWithC);
     }
   }

KNearestNeighborClassifier.java
@@ -26,6 +26,7 @@ import java.util.List;
 import java.util.Map;
 
 import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.IndexableField;
 import org.apache.lucene.index.Term;
@@ -82,7 +83,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
   /**
    * Creates a {@link KNearestNeighborClassifier}.
    *
-   * @param leafReader the reader on the index to be used for classification
+   * @param indexReader the reader on the index to be used for classification
    * @param analyzer an {@link Analyzer} used to analyze unseen text
    * @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
    *                   (defaults to {@link org.apache.lucene.search.similarities.ClassicSimilarity})
@@ -94,14 +95,14 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
    * @param classFieldName the name of the field used as the output for the classifier
    * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
    */
-  public KNearestNeighborClassifier(LeafReader leafReader, Similarity similarity, Analyzer analyzer, Query query, int k, int minDocsFreq,
+  public KNearestNeighborClassifier(IndexReader indexReader, Similarity similarity, Analyzer analyzer, Query query, int k, int minDocsFreq,
                                     int minTermFreq, String classFieldName, String... textFieldNames) {
     this.textFieldNames = textFieldNames;
     this.classFieldName = classFieldName;
-    this.mlt = new MoreLikeThis(leafReader);
+    this.mlt = new MoreLikeThis(indexReader);
     this.mlt.setAnalyzer(analyzer);
     this.mlt.setFieldNames(textFieldNames);
-    this.indexSearcher = new IndexSearcher(leafReader);
+    this.indexSearcher = new IndexSearcher(indexReader);
     if (similarity != null) {
       this.indexSearcher.setSimilarity(similarity);
     } else {

SimpleNaiveBayesClassifier.java
@@ -26,7 +26,7 @@ import java.util.List;
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.analysis.TokenStream;
 import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
-import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.MultiFields;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.index.Terms;
@@ -48,10 +48,10 @@ import org.apache.lucene.util.BytesRef;
 public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
 
   /**
-   * {@link org.apache.lucene.index.LeafReader} used to access the {@link org.apache.lucene.classification.Classifier}'s
+   * {@link org.apache.lucene.index.IndexReader} used to access the {@link org.apache.lucene.classification.Classifier}'s
    * index
    */
-  protected final LeafReader leafReader;
+  protected final IndexReader indexReader;
 
   /**
    * names of the fields to be used as input text
@@ -81,7 +81,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
   /**
    * Creates a new NaiveBayes classifier.
    *
-   * @param leafReader the reader on the index to be used for classification
+   * @param indexReader the reader on the index to be used for classification
    * @param analyzer an {@link Analyzer} used to analyze unseen text
    * @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
    *              if all the indexed docs should be used
@@ -89,9 +89,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
    *                       as the returned class will be a token indexed for this field
    * @param textFieldNames the name of the fields used as the inputs for the classifier, NO boosting supported per field
    */
-  public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
-    this.leafReader = leafReader;
-    this.indexSearcher = new IndexSearcher(this.leafReader);
+  public SimpleNaiveBayesClassifier(IndexReader indexReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
+    this.indexReader = indexReader;
+    this.indexSearcher = new IndexSearcher(this.indexReader);
     this.textFieldNames = textFieldNames;
     this.classFieldName = classFieldName;
     this.analyzer = analyzer;
@@ -144,7 +144,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
   protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
     List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
 
-    Terms classes = MultiFields.getTerms(leafReader, classFieldName);
+    Terms classes = MultiFields.getTerms(indexReader, classFieldName);
     if (classes != null) {
       TermsEnum classesEnum = classes.iterator();
       BytesRef next;
@@ -169,7 +169,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
    * @throws IOException if accessing to term vectors or search fails
    */
   protected int countDocsWithClass() throws IOException {
-    Terms terms = MultiFields.getTerms(this.leafReader, this.classFieldName);
+    Terms terms = MultiFields.getTerms(this.indexReader, this.classFieldName);
     int docCount;
     if (terms == null || terms.getDocCount() == -1) { // in case codec doesn't support getDocCount
       TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
@@ -240,11 +240,11 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
   private double getTextTermFreqForClass(Term term) throws IOException {
     double avgNumberOfUniqueTerms = 0;
     for (String textFieldName : textFieldNames) {
-      Terms terms = MultiFields.getTerms(leafReader, textFieldName);
+      Terms terms = MultiFields.getTerms(indexReader, textFieldName);
       long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
       avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
     }
-    int docsWithC = leafReader.docFreq(term);
+    int docsWithC = indexReader.docFreq(term);
     return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
   }
 
@@ -277,7 +277,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
   }
 
   private int docCount(Term term) throws IOException {
-    return leafReader.docFreq(term);
+    return indexReader.docFreq(term);
   }
 
   /**

KNearestNeighborDocumentClassifier.java
@@ -27,6 +27,7 @@ import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.classification.ClassificationResult;
 import org.apache.lucene.classification.KNearestNeighborClassifier;
 import org.apache.lucene.document.Document;
+import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.BooleanClause;
@@ -54,7 +55,7 @@ public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifi
   /**
    * Creates a {@link KNearestNeighborClassifier}.
    *
-   * @param leafReader the reader on the index to be used for classification
+   * @param indexReader the reader on the index to be used for classification
    * @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
    *                   (defaults to {@link org.apache.lucene.search.similarities.ClassicSimilarity})
    * @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null}
@@ -66,9 +67,9 @@ public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifi
    * @param field2analyzer map with key a field name and the related {org.apache.lucene.analysis.Analyzer}
    * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
    */
-  public KNearestNeighborDocumentClassifier(LeafReader leafReader, Similarity similarity, Query query, int k, int minDocsFreq,
+  public KNearestNeighborDocumentClassifier(IndexReader indexReader, Similarity similarity, Query query, int k, int minDocsFreq,
                                             int minTermFreq, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
-    super(leafReader,similarity, null, query, k, minDocsFreq, minTermFreq, classFieldName, textFieldNames);
+    super(indexReader, similarity, null, query, k, minDocsFreq, minTermFreq, classFieldName, textFieldNames);
     this.field2analyzer = field2analyzer;
   }
 

SimpleNaiveBayesDocumentClassifier.java
@@ -32,6 +32,7 @@ import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
 import org.apache.lucene.classification.ClassificationResult;
 import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
 import org.apache.lucene.document.Document;
+import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexableField;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.MultiFields;
@@ -59,15 +60,15 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
   /**
    * Creates a new NaiveBayes classifier.
    *
-   * @param leafReader the reader on the index to be used for classification
+   * @param indexReader the reader on the index to be used for classification
    * @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null}
    *              if all the indexed docs should be used
    * @param classFieldName the name of the field used as the output for the classifier NOTE: must not be havely analyzed
    *                       as the returned class will be a token indexed for this field
    * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
    */
-  public SimpleNaiveBayesDocumentClassifier(LeafReader leafReader, Query query, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
-    super(leafReader, null, query, classFieldName, textFieldNames);
+  public SimpleNaiveBayesDocumentClassifier(IndexReader indexReader, Query query, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
+    super(indexReader, null, query, classFieldName, textFieldNames);
     this.field2analyzer = field2analyzer;
   }
 
@@ -112,7 +113,7 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
     List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
     Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>();
     Map<String, Float> fieldName2boost = new LinkedHashMap<>();
-    Terms classes = MultiFields.getTerms(leafReader, classFieldName);
+    Terms classes = MultiFields.getTerms(indexReader, classFieldName);
     TermsEnum classesEnum = classes.iterator();
     BytesRef c;
 
@@ -225,10 +226,10 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
    */
   private double getTextTermFreqForClass(Term term, String fieldName) throws IOException {
     double avgNumberOfUniqueTerms;
-    Terms terms = MultiFields.getTerms(leafReader, fieldName);
+    Terms terms = MultiFields.getTerms(indexReader, fieldName);
     long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
     avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
-    int docsWithC = leafReader.docFreq(term);
+    int docsWithC = indexReader.docFreq(term);
     return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
   }
 
@@ -261,6 +262,6 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
   }
 
   private int docCount(Term term) throws IOException {
-    return leafReader.docFreq(term);
+    return indexReader.docFreq(term);
   }
 }

ConfusionMatrixGenerator.java
@@ -31,7 +31,7 @@ import java.util.concurrent.TimeoutException;
 import org.apache.lucene.classification.ClassificationResult;
 import org.apache.lucene.classification.Classifier;
 import org.apache.lucene.document.Document;
-import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TermRangeQuery;
@@ -50,9 +50,9 @@ public class ConfusionMatrixGenerator {
 
   /**
    * get the {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} of a given {@link Classifier},
-   * generated on the given {@link LeafReader}, class and text fields.
+   * generated on the given {@link IndexReader}, class and text fields.
    *
-   * @param reader the {@link LeafReader} containing the index used for creating the {@link Classifier}
+   * @param reader the {@link IndexReader} containing the index used for creating the {@link Classifier}
    * @param classifier the {@link Classifier} whose confusion matrix has to be generated
    * @param classFieldName the name of the Lucene field used as the classifier's output
    * @param textFieldName the nome the Lucene field used as the classifier's input
@@ -61,7 +61,7 @@ public class ConfusionMatrixGenerator {
    * @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix}
    * @throws IOException if problems occurr while reading the index or using the classifier
    */
-  public static <T> ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier<T> classifier, String classFieldName,
+  public static <T> ConfusionMatrix getConfusionMatrix(IndexReader reader, Classifier<T> classifier, String classFieldName,
                                                        String textFieldName, long timeoutMilliseconds) throws IOException {
 
     ExecutorService executorService = Executors.newFixedThreadPool(1, new NamedThreadFactory("confusion-matrix-gen-"));

DatasetSplitter.java
@@ -18,15 +18,17 @@ package org.apache.lucene.classification.utils;
 
 
 import java.io.IOException;
 
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.FieldType;
 import org.apache.lucene.document.TextField;
+import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.IndexableField;
-import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.SortedDocValues;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
@@ -69,7 +71,7 @@ public class DatasetSplitter {
    * @param fieldNames names of fields that need to be put in the new indexes or <code>null</code> if all should be used
    * @throws IOException if any writing operation fails on any of the indexes
    */
-  public void split(LeafReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex,
+  public void split(IndexReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex,
                     Analyzer analyzer, boolean termVectors, String classFieldName, String... fieldNames) throws IOException {
 
     // create IWs for train / test / cv IDXs
@@ -78,13 +80,15 @@ public class DatasetSplitter {
     IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer));
 
     // get the exact no. of existing classes
-    SortedDocValues classValues = originalIndex.getSortedDocValues(classFieldName);
-    if (classValues == null) {
-      throw new IllegalStateException("the classFieldName \"" + classFieldName + "\" must index sorted doc values");
+    int noOfClasses = 0;
+    for (LeafReaderContext leave : originalIndex.leaves()) {
+      SortedDocValues classValues = leave.reader().getSortedDocValues(classFieldName);
+      if (classValues == null) {
+        throw new IllegalStateException("the classFieldName \"" + classFieldName + "\" must index sorted doc values");
+      }
+      noOfClasses += classValues.getValueCount();
     }
 
-    int noOfClasses = classValues.getValueCount();
-
     try {
 
       IndexSearcher indexSearcher = new IndexSearcher(originalIndex);
@@ -150,7 +154,7 @@ public class DatasetSplitter {
     }
   }
 
-  private Document createNewDoc(LeafReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames) throws IOException {
+  private Document createNewDoc(IndexReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames) throws IOException {
     Document doc = new Document();
     Document document = originalIndex.document(scoreDoc.doc);
     if (fieldNames != null && fieldNames.length > 0) {
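Since SortedDocValues is a per-segment API, the class-counting code in the hunk above now loops over originalIndex.leaves() and sums getValueCount() per leaf instead of reading doc values off a single LeafReader. A hypothetical invocation of the new split() signature, assuming the existing two-ratio DatasetSplitter constructor; all paths and field names are placeholders.

import java.nio.file.Paths;

import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.classification.utils.DatasetSplitter;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;

public class DatasetSplitterExample {
  public static void main(String[] args) throws Exception {
    try (IndexReader original = DirectoryReader.open(FSDirectory.open(Paths.get("/path/to/index")));
         Directory train = FSDirectory.open(Paths.get("/tmp/train"));
         Directory test = FSDirectory.open(Paths.get("/tmp/test"));
         Directory cv = FSDirectory.open(Paths.get("/tmp/cv"))) {
      DatasetSplitter splitter = new DatasetSplitter(0.1, 0.1); // 10% test, 10% cross-validation
      // "category" must index sorted doc values, or split() throws IllegalStateException (see above)
      splitter.split(original, train, test, cv, new StandardAnalyzer(), true, "category", "text", "title");
    }
  }
}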

DocumentClassificationTestBase.java
@@ -27,8 +27,8 @@ import org.apache.lucene.classification.ClassificationResult;
 import org.apache.lucene.classification.ClassificationTestBase;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
+import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexWriterConfig;
-import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.RandomIndexWriter;
 import org.apache.lucene.util.BytesRef;
 import org.junit.Before;
@@ -47,7 +47,7 @@ public abstract class DocumentClassificationTestBase<T> extends ClassificationTe
 
   protected Analyzer analyzer;
   protected Map<String, Analyzer> field2analyzer;
-  protected LeafReader leafReader;
+  protected IndexReader indexReader;
 
   @Before
   public void init() throws IOException {
@@ -56,7 +56,7 @@ public abstract class DocumentClassificationTestBase<T> extends ClassificationTe
     field2analyzer.put(textFieldName, analyzer);
     field2analyzer.put(titleFieldName, analyzer);
     field2analyzer.put(authorFieldName, analyzer);
-    leafReader = populateDocumentClassificationIndex(analyzer);
+    indexReader = populateDocumentClassificationIndex(analyzer);
   }
 
   protected double checkCorrectDocumentClassification(DocumentClassifier<T> classifier, Document inputDoc, T expectedResult) throws Exception {
@@ -68,7 +68,7 @@ public abstract class DocumentClassificationTestBase<T> extends ClassificationTe
     return score;
   }
 
-  protected LeafReader populateDocumentClassificationIndex(Analyzer analyzer) throws IOException {
+  protected IndexReader populateDocumentClassificationIndex(Analyzer analyzer) throws IOException {
     indexWriter.close();
     indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
     indexWriter.commit();
@@ -201,8 +201,7 @@ public abstract class DocumentClassificationTestBase<T> extends ClassificationTe
     indexWriter.addDocument(doc);
 
     indexWriter.commit();
-    indexWriter.forceMerge(1);
-    return getOnlyLeafReader(indexWriter.getReader());
+    return indexWriter.getReader();
   }
 
   protected Document getVideoGameDocument() {

KNearestNeighborDocumentClassifierTest.java
@@ -33,15 +33,15 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
     try {
       Document videoGameDocument = getVideoGameDocument();
       Document batmanDocument = getBatmanDocument();
-      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
-      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
+      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
+      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
       // considering only the text we have wrong classification because the text was ambiguos on purpose
-      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
-      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
+      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
+      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
 
     } finally {
-      if (leafReader != null) {
-        leafReader.close();
+      if (indexReader != null) {
+        indexReader.close();
       }
     }
   }
@@ -51,18 +51,18 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
     try {
       Document videoGameDocument = getVideoGameDocument();
       Document batmanDocument = getBatmanDocument();
-      double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
+      double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
       assertEquals(1.0,score1,0);
-      double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
+      double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
       assertEquals(1.0,score2,0);
       // considering only the text we have wrong classification because the text was ambiguos on purpose
-      double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
+      double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
       assertEquals(1.0,score3,0);
-      double score4 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
+      double score4 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
       assertEquals(1.0,score4,0);
     } finally {
-      if (leafReader != null) {
-        leafReader.close();
+      if (indexReader != null) {
+        indexReader.close();
       }
     }
   }
@@ -70,12 +70,12 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
   @Test
   public void testBoostedDocumentClassification() throws Exception {
     try {
-      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName + "^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
+      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName + "^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
       // considering without boost wrong classification will appear
-      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_RESULT);
+      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_RESULT);
     } finally {
-      if (leafReader != null) {
-        leafReader.close();
+      if (indexReader != null) {
+        indexReader.close();
       }
     }
   }
@@ -84,11 +84,11 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
   public void testBasicDocumentClassificationWithQuery() throws Exception {
     try {
       TermQuery query = new TermQuery(new Term(authorFieldName, "ign"));
-      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_RESULT);
-      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null,query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), VIDEOGAME_RESULT);
+      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_RESULT);
+      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null,query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), VIDEOGAME_RESULT);
     } finally {
-      if (leafReader != null) {
-        leafReader.close();
+      if (indexReader != null) {
+        indexReader.close();
       }
     }
   }

SimpleNaiveBayesDocumentClassifierTest.java
@@ -28,14 +28,14 @@ public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificati
   @Test
   public void testBasicDocumentClassification() throws Exception {
     try {
-      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
-      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
+      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
+      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
 
-      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
-      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
+      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
+      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
     } finally {
-      if (leafReader != null) {
-        leafReader.close();
+      if (indexReader != null) {
+        indexReader.close();
       }
     }
   }
@@ -43,18 +43,18 @@ public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificati
   @Test
   public void testBasicDocumentClassificationScore() throws Exception {
     try {
-      double score1 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
+      double score1 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
       assertEquals(0.88,score1,0.01);
-      double score2 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
+      double score2 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
       assertEquals(0.89,score2,0.01);
       //taking in consideration only the text
-      double score3 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
+      double score3 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
       assertEquals(0.55,score3,0.01);
-      double score4 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
+      double score4 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
       assertEquals(0.52,score4,0.01);
     } finally {
-      if (leafReader != null) {
-        leafReader.close();
+      if (indexReader != null) {
+        indexReader.close();
       }
     }
   }
@@ -62,12 +62,12 @@ public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificati
   @Test
   public void testBoostedDocumentClassification() throws Exception {
     try {
-      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName+"^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
+      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName+"^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
       // considering without boost wrong classification will appear
-      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_ANALYZED_RESULT);
+      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_ANALYZED_RESULT);
     } finally {
-      if (leafReader != null) {
-        leafReader.close();
+      if (indexReader != null) {
+        indexReader.close();
       }
     }
   }

ClassificationUpdateProcessor.java
@@ -10,6 +10,7 @@ import org.apache.lucene.classification.document.DocumentClassifier;
 import org.apache.lucene.classification.document.KNearestNeighborDocumentClassifier;
 import org.apache.lucene.classification.document.SimpleNaiveBayesDocumentClassifier;
 import org.apache.lucene.document.Document;
+import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.util.BytesRef;
 import org.apache.solr.common.SolrInputDocument;
@@ -60,7 +61,7 @@ class ClassificationUpdateProcessor
    * @param schema schema
    */
   public ClassificationUpdateProcessor(String[] inputFieldNames, String classFieldName, int minDf, int minTf, int k, String algorithm,
-                                       UpdateRequestProcessor next, LeafReader indexReader, IndexSchema schema) {
+                                       UpdateRequestProcessor next, IndexReader indexReader, IndexSchema schema) {
     super(next);
     this.classFieldName = classFieldName;
     Map<String, Analyzer> field2analyzer = new HashMap<String, Analyzer>();

ClassificationUpdateProcessorFactory.java
@@ -1,5 +1,6 @@
 package org.apache.solr.update.processor;
 
+import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReader;
 import org.apache.solr.common.SolrException;
 import org.apache.solr.common.params.SolrParams;
@@ -109,8 +110,8 @@ public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessor
   @Override
   public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
     IndexSchema schema = req.getSchema();
-    LeafReader leafReader = req.getSearcher().getLeafReader();
-    return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, leafReader, schema);
+    IndexReader indexReader = req.getSearcher().getIndexReader();
+    return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, indexReader, schema);
   }
 
   /**