LUCENE-7350 - Let classifiers be constructed from IndexReaders

This commit is contained in:
Tommaso Teofili 2016-06-21 13:10:34 +02:00
parent 6ef174f527
commit fcf4389d82
13 changed files with 105 additions and 96 deletions

View File

@ -26,6 +26,7 @@ import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.MultiFields; import org.apache.lucene.index.MultiFields;
@ -67,7 +68,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
/** /**
* Creates a {@link BooleanPerceptronClassifier} * Creates a {@link BooleanPerceptronClassifier}
* *
* @param leafReader the reader on the index to be used for classification * @param indexReader the reader on the index to be used for classification
* @param analyzer an {@link Analyzer} used to analyze unseen text * @param analyzer an {@link Analyzer} used to analyze unseen text
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null} * @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
* if all the indexed docs should be used * if all the indexed docs should be used
@ -78,9 +79,9 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
* @throws IOException if the building of the underlying {@link FST} fails and / or {@link TermsEnum} for the text field * @throws IOException if the building of the underlying {@link FST} fails and / or {@link TermsEnum} for the text field
* cannot be found * cannot be found
*/ */
public BooleanPerceptronClassifier(LeafReader leafReader, Analyzer analyzer, Query query, Integer batchSize, public BooleanPerceptronClassifier(IndexReader indexReader, Analyzer analyzer, Query query, Integer batchSize,
Double threshold, String classFieldName, String textFieldName) throws IOException { Double threshold, String classFieldName, String textFieldName) throws IOException {
this.textTerms = MultiFields.getTerms(leafReader, textFieldName); this.textTerms = MultiFields.getTerms(indexReader, textFieldName);
if (textTerms == null) { if (textTerms == null) {
throw new IOException("term vectors need to be available for field " + textFieldName); throw new IOException("term vectors need to be available for field " + textFieldName);
@ -91,7 +92,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
if (threshold == null || threshold == 0d) { if (threshold == null || threshold == 0d) {
// automatic assign a threshold // automatic assign a threshold
long sumDocFreq = leafReader.getSumDocFreq(textFieldName); long sumDocFreq = indexReader.getSumDocFreq(textFieldName);
if (sumDocFreq != -1) { if (sumDocFreq != -1) {
this.threshold = (double) sumDocFreq / 2d; this.threshold = (double) sumDocFreq / 2d;
} else { } else {
@ -113,7 +114,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
} }
updateFST(weights); updateFST(weights);
IndexSearcher indexSearcher = new IndexSearcher(leafReader); IndexSearcher indexSearcher = new IndexSearcher(indexReader);
int batchCount = 0; int batchCount = 0;
@ -140,7 +141,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
Boolean correctClass = Boolean.valueOf(classField.stringValue()); Boolean correctClass = Boolean.valueOf(classField.stringValue());
long modifier = correctClass.compareTo(assignedClass); long modifier = correctClass.compareTo(assignedClass);
if (modifier != 0) { if (modifier != 0) {
updateWeights(leafReader, scoreDoc.doc, assignedClass, updateWeights(indexReader, scoreDoc.doc, assignedClass,
weights, modifier, batchCount % batchSize == 0); weights, modifier, batchCount % batchSize == 0);
} }
batchCount++; batchCount++;
@ -149,13 +150,13 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
weights.clear(); // free memory while waiting for GC weights.clear(); // free memory while waiting for GC
} }
private void updateWeights(LeafReader leafReader, private void updateWeights(IndexReader indexReader,
int docId, Boolean assignedClass, SortedMap<String, Double> weights, int docId, Boolean assignedClass, SortedMap<String, Double> weights,
double modifier, boolean updateFST) throws IOException { double modifier, boolean updateFST) throws IOException {
TermsEnum cte = textTerms.iterator(); TermsEnum cte = textTerms.iterator();
// get the doc term vectors // get the doc term vectors
Terms terms = leafReader.getTermVector(docId, textFieldName); Terms terms = indexReader.getTermVector(docId, textFieldName);
if (terms == null) { if (terms == null) {
throw new IOException("term vectors must be stored for field " throw new IOException("term vectors must be stored for field "
@ -201,7 +202,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
@Override @Override
public ClassificationResult<Boolean> assignClass(String text) public ClassificationResult<Boolean> assignClass(String text)
throws IOException { throws IOException {
Long output = 0l; Long output = 0L;
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) { try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
CharTermAttribute charTermAttribute = tokenStream CharTermAttribute charTermAttribute = tokenStream
.addAttribute(CharTermAttribute.class); .addAttribute(CharTermAttribute.class);

View File

@ -212,7 +212,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
// build the cache for the word // build the cache for the word
Map<String, Long> frequencyMap = new HashMap<>(); Map<String, Long> frequencyMap = new HashMap<>();
for (String textFieldName : textFieldNames) { for (String textFieldName : textFieldNames) {
TermsEnum termsEnum = leafReader.terms(textFieldName).iterator(); TermsEnum termsEnum = MultiFields.getTerms(indexReader, textFieldName).iterator();
while (termsEnum.next() != null) { while (termsEnum.next() != null) {
BytesRef term = termsEnum.term(); BytesRef term = termsEnum.term();
String termText = term.utf8ToString(); String termText = term.utf8ToString();
@ -229,7 +229,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
} }
// fill the class list // fill the class list
Terms terms = MultiFields.getTerms(leafReader, classFieldName); Terms terms = MultiFields.getTerms(indexReader, classFieldName);
TermsEnum termsEnum = terms.iterator(); TermsEnum termsEnum = terms.iterator();
while ((termsEnum.next()) != null) { while ((termsEnum.next()) != null) {
cclasses.add(BytesRef.deepCopyOf(termsEnum.term())); cclasses.add(BytesRef.deepCopyOf(termsEnum.term()));
@ -238,11 +238,11 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
for (BytesRef cclass : cclasses) { for (BytesRef cclass : cclasses) {
double avgNumberOfUniqueTerms = 0; double avgNumberOfUniqueTerms = 0;
for (String textFieldName : textFieldNames) { for (String textFieldName : textFieldNames) {
terms = MultiFields.getTerms(leafReader, textFieldName); terms = MultiFields.getTerms(indexReader, textFieldName);
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount();
} }
int docsWithC = leafReader.docFreq(new Term(classFieldName, cclass)); int docsWithC = indexReader.docFreq(new Term(classFieldName, cclass));
classTermFreq.put(cclass, avgNumberOfUniqueTerms * docsWithC); classTermFreq.put(cclass, avgNumberOfUniqueTerms * docsWithC);
} }
} }

View File

@ -26,6 +26,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
@ -82,7 +83,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
/** /**
* Creates a {@link KNearestNeighborClassifier}. * Creates a {@link KNearestNeighborClassifier}.
* *
* @param leafReader the reader on the index to be used for classification * @param indexReader the reader on the index to be used for classification
* @param analyzer an {@link Analyzer} used to analyze unseen text * @param analyzer an {@link Analyzer} used to analyze unseen text
* @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null} * @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
* (defaults to {@link org.apache.lucene.search.similarities.ClassicSimilarity}) * (defaults to {@link org.apache.lucene.search.similarities.ClassicSimilarity})
@ -94,14 +95,14 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
* @param classFieldName the name of the field used as the output for the classifier * @param classFieldName the name of the field used as the output for the classifier
* @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10 * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
*/ */
public KNearestNeighborClassifier(LeafReader leafReader, Similarity similarity, Analyzer analyzer, Query query, int k, int minDocsFreq, public KNearestNeighborClassifier(IndexReader indexReader, Similarity similarity, Analyzer analyzer, Query query, int k, int minDocsFreq,
int minTermFreq, String classFieldName, String... textFieldNames) { int minTermFreq, String classFieldName, String... textFieldNames) {
this.textFieldNames = textFieldNames; this.textFieldNames = textFieldNames;
this.classFieldName = classFieldName; this.classFieldName = classFieldName;
this.mlt = new MoreLikeThis(leafReader); this.mlt = new MoreLikeThis(indexReader);
this.mlt.setAnalyzer(analyzer); this.mlt.setAnalyzer(analyzer);
this.mlt.setFieldNames(textFieldNames); this.mlt.setFieldNames(textFieldNames);
this.indexSearcher = new IndexSearcher(leafReader); this.indexSearcher = new IndexSearcher(indexReader);
if (similarity != null) { if (similarity != null) {
this.indexSearcher.setSimilarity(similarity); this.indexSearcher.setSimilarity(similarity);
} else { } else {

View File

@ -26,7 +26,7 @@ import java.util.List;
import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiFields; import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms; import org.apache.lucene.index.Terms;
@ -48,10 +48,10 @@ import org.apache.lucene.util.BytesRef;
public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> { public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
/** /**
* {@link org.apache.lucene.index.LeafReader} used to access the {@link org.apache.lucene.classification.Classifier}'s * {@link org.apache.lucene.index.IndexReader} used to access the {@link org.apache.lucene.classification.Classifier}'s
* index * index
*/ */
protected final LeafReader leafReader; protected final IndexReader indexReader;
/** /**
* names of the fields to be used as input text * names of the fields to be used as input text
@ -81,7 +81,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
/** /**
* Creates a new NaiveBayes classifier. * Creates a new NaiveBayes classifier.
* *
* @param leafReader the reader on the index to be used for classification * @param indexReader the reader on the index to be used for classification
* @param analyzer an {@link Analyzer} used to analyze unseen text * @param analyzer an {@link Analyzer} used to analyze unseen text
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null} * @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
* if all the indexed docs should be used * if all the indexed docs should be used
@ -89,9 +89,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
* as the returned class will be a token indexed for this field * as the returned class will be a token indexed for this field
* @param textFieldNames the name of the fields used as the inputs for the classifier, NO boosting supported per field * @param textFieldNames the name of the fields used as the inputs for the classifier, NO boosting supported per field
*/ */
public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) { public SimpleNaiveBayesClassifier(IndexReader indexReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
this.leafReader = leafReader; this.indexReader = indexReader;
this.indexSearcher = new IndexSearcher(this.leafReader); this.indexSearcher = new IndexSearcher(this.indexReader);
this.textFieldNames = textFieldNames; this.textFieldNames = textFieldNames;
this.classFieldName = classFieldName; this.classFieldName = classFieldName;
this.analyzer = analyzer; this.analyzer = analyzer;
@ -144,7 +144,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException { protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>(); List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
Terms classes = MultiFields.getTerms(leafReader, classFieldName); Terms classes = MultiFields.getTerms(indexReader, classFieldName);
if (classes != null) { if (classes != null) {
TermsEnum classesEnum = classes.iterator(); TermsEnum classesEnum = classes.iterator();
BytesRef next; BytesRef next;
@ -169,7 +169,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
* @throws IOException if accessing to term vectors or search fails * @throws IOException if accessing to term vectors or search fails
*/ */
protected int countDocsWithClass() throws IOException { protected int countDocsWithClass() throws IOException {
Terms terms = MultiFields.getTerms(this.leafReader, this.classFieldName); Terms terms = MultiFields.getTerms(this.indexReader, this.classFieldName);
int docCount; int docCount;
if (terms == null || terms.getDocCount() == -1) { // in case codec doesn't support getDocCount if (terms == null || terms.getDocCount() == -1) { // in case codec doesn't support getDocCount
TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector(); TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
@ -240,11 +240,11 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
private double getTextTermFreqForClass(Term term) throws IOException { private double getTextTermFreqForClass(Term term) throws IOException {
double avgNumberOfUniqueTerms = 0; double avgNumberOfUniqueTerms = 0;
for (String textFieldName : textFieldNames) { for (String textFieldName : textFieldNames) {
Terms terms = MultiFields.getTerms(leafReader, textFieldName); Terms terms = MultiFields.getTerms(indexReader, textFieldName);
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
} }
int docsWithC = leafReader.docFreq(term); int docsWithC = indexReader.docFreq(term);
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
} }
@ -277,7 +277,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
} }
private int docCount(Term term) throws IOException { private int docCount(Term term) throws IOException {
return leafReader.docFreq(term); return indexReader.docFreq(term);
} }
/** /**

View File

@ -27,6 +27,7 @@ import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.classification.ClassificationResult; import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.KNearestNeighborClassifier; import org.apache.lucene.classification.KNearestNeighborClassifier;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanClause;
@ -54,7 +55,7 @@ public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifi
/** /**
* Creates a {@link KNearestNeighborClassifier}. * Creates a {@link KNearestNeighborClassifier}.
* *
* @param leafReader the reader on the index to be used for classification * @param indexReader the reader on the index to be used for classification
* @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null} * @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
* (defaults to {@link org.apache.lucene.search.similarities.ClassicSimilarity}) * (defaults to {@link org.apache.lucene.search.similarities.ClassicSimilarity})
* @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null} * @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null}
@ -66,9 +67,9 @@ public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifi
* @param field2analyzer map with key a field name and the related {org.apache.lucene.analysis.Analyzer} * @param field2analyzer map with key a field name and the related {org.apache.lucene.analysis.Analyzer}
* @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10 * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
*/ */
public KNearestNeighborDocumentClassifier(LeafReader leafReader, Similarity similarity, Query query, int k, int minDocsFreq, public KNearestNeighborDocumentClassifier(IndexReader indexReader, Similarity similarity, Query query, int k, int minDocsFreq,
int minTermFreq, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) { int minTermFreq, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
super(leafReader,similarity, null, query, k, minDocsFreq, minTermFreq, classFieldName, textFieldNames); super(indexReader, similarity, null, query, k, minDocsFreq, minTermFreq, classFieldName, textFieldNames);
this.field2analyzer = field2analyzer; this.field2analyzer = field2analyzer;
} }

View File

@ -32,6 +32,7 @@ import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.classification.ClassificationResult; import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.SimpleNaiveBayesClassifier; import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.MultiFields; import org.apache.lucene.index.MultiFields;
@ -59,15 +60,15 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
/** /**
* Creates a new NaiveBayes classifier. * Creates a new NaiveBayes classifier.
* *
* @param leafReader the reader on the index to be used for classification * @param indexReader the reader on the index to be used for classification
* @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null} * @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null}
* if all the indexed docs should be used * if all the indexed docs should be used
* @param classFieldName the name of the field used as the output for the classifier NOTE: must not be havely analyzed * @param classFieldName the name of the field used as the output for the classifier NOTE: must not be havely analyzed
* as the returned class will be a token indexed for this field * as the returned class will be a token indexed for this field
* @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10 * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
*/ */
public SimpleNaiveBayesDocumentClassifier(LeafReader leafReader, Query query, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) { public SimpleNaiveBayesDocumentClassifier(IndexReader indexReader, Query query, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
super(leafReader, null, query, classFieldName, textFieldNames); super(indexReader, null, query, classFieldName, textFieldNames);
this.field2analyzer = field2analyzer; this.field2analyzer = field2analyzer;
} }
@ -112,7 +113,7 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>(); List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>(); Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>();
Map<String, Float> fieldName2boost = new LinkedHashMap<>(); Map<String, Float> fieldName2boost = new LinkedHashMap<>();
Terms classes = MultiFields.getTerms(leafReader, classFieldName); Terms classes = MultiFields.getTerms(indexReader, classFieldName);
TermsEnum classesEnum = classes.iterator(); TermsEnum classesEnum = classes.iterator();
BytesRef c; BytesRef c;
@ -225,10 +226,10 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
*/ */
private double getTextTermFreqForClass(Term term, String fieldName) throws IOException { private double getTextTermFreqForClass(Term term, String fieldName) throws IOException {
double avgNumberOfUniqueTerms; double avgNumberOfUniqueTerms;
Terms terms = MultiFields.getTerms(leafReader, fieldName); Terms terms = MultiFields.getTerms(indexReader, fieldName);
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
int docsWithC = leafReader.docFreq(term); int docsWithC = indexReader.docFreq(term);
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
} }
@ -261,6 +262,6 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
} }
private int docCount(Term term) throws IOException { private int docCount(Term term) throws IOException {
return leafReader.docFreq(term); return indexReader.docFreq(term);
} }
} }

View File

@ -31,7 +31,7 @@ import java.util.concurrent.TimeoutException;
import org.apache.lucene.classification.ClassificationResult; import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.Classifier; import org.apache.lucene.classification.Classifier;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermRangeQuery; import org.apache.lucene.search.TermRangeQuery;
@ -50,9 +50,9 @@ public class ConfusionMatrixGenerator {
/** /**
* get the {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} of a given {@link Classifier}, * get the {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} of a given {@link Classifier},
* generated on the given {@link LeafReader}, class and text fields. * generated on the given {@link IndexReader}, class and text fields.
* *
* @param reader the {@link LeafReader} containing the index used for creating the {@link Classifier} * @param reader the {@link IndexReader} containing the index used for creating the {@link Classifier}
* @param classifier the {@link Classifier} whose confusion matrix has to be generated * @param classifier the {@link Classifier} whose confusion matrix has to be generated
* @param classFieldName the name of the Lucene field used as the classifier's output * @param classFieldName the name of the Lucene field used as the classifier's output
* @param textFieldName the nome the Lucene field used as the classifier's input * @param textFieldName the nome the Lucene field used as the classifier's input
@ -61,7 +61,7 @@ public class ConfusionMatrixGenerator {
* @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} * @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix}
* @throws IOException if problems occurr while reading the index or using the classifier * @throws IOException if problems occurr while reading the index or using the classifier
*/ */
public static <T> ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier<T> classifier, String classFieldName, public static <T> ConfusionMatrix getConfusionMatrix(IndexReader reader, Classifier<T> classifier, String classFieldName,
String textFieldName, long timeoutMilliseconds) throws IOException { String textFieldName, long timeoutMilliseconds) throws IOException {
ExecutorService executorService = Executors.newFixedThreadPool(1, new NamedThreadFactory("confusion-matrix-gen-")); ExecutorService executorService = Executors.newFixedThreadPool(1, new NamedThreadFactory("confusion-matrix-gen-"));

View File

@ -18,15 +18,17 @@ package org.apache.lucene.classification.utils;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType; import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField; import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchAllDocsQuery;
@ -69,7 +71,7 @@ public class DatasetSplitter {
* @param fieldNames names of fields that need to be put in the new indexes or <code>null</code> if all should be used * @param fieldNames names of fields that need to be put in the new indexes or <code>null</code> if all should be used
* @throws IOException if any writing operation fails on any of the indexes * @throws IOException if any writing operation fails on any of the indexes
*/ */
public void split(LeafReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex, public void split(IndexReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex,
Analyzer analyzer, boolean termVectors, String classFieldName, String... fieldNames) throws IOException { Analyzer analyzer, boolean termVectors, String classFieldName, String... fieldNames) throws IOException {
// create IWs for train / test / cv IDXs // create IWs for train / test / cv IDXs
@ -78,13 +80,15 @@ public class DatasetSplitter {
IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer)); IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer));
// get the exact no. of existing classes // get the exact no. of existing classes
SortedDocValues classValues = originalIndex.getSortedDocValues(classFieldName); int noOfClasses = 0;
if (classValues == null) { for (LeafReaderContext leave : originalIndex.leaves()) {
throw new IllegalStateException("the classFieldName \"" + classFieldName + "\" must index sorted doc values"); SortedDocValues classValues = leave.reader().getSortedDocValues(classFieldName);
if (classValues == null) {
throw new IllegalStateException("the classFieldName \"" + classFieldName + "\" must index sorted doc values");
}
noOfClasses += classValues.getValueCount();
} }
int noOfClasses = classValues.getValueCount();
try { try {
IndexSearcher indexSearcher = new IndexSearcher(originalIndex); IndexSearcher indexSearcher = new IndexSearcher(originalIndex);
@ -150,7 +154,7 @@ public class DatasetSplitter {
} }
} }
private Document createNewDoc(LeafReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames) throws IOException { private Document createNewDoc(IndexReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames) throws IOException {
Document doc = new Document(); Document doc = new Document();
Document document = originalIndex.document(scoreDoc.doc); Document document = originalIndex.document(scoreDoc.doc);
if (fieldNames != null && fieldNames.length > 0) { if (fieldNames != null && fieldNames.length > 0) {

View File

@ -27,8 +27,8 @@ import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.ClassificationTestBase; import org.apache.lucene.classification.ClassificationTestBase;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import org.junit.Before; import org.junit.Before;
@ -47,7 +47,7 @@ public abstract class DocumentClassificationTestBase<T> extends ClassificationTe
protected Analyzer analyzer; protected Analyzer analyzer;
protected Map<String, Analyzer> field2analyzer; protected Map<String, Analyzer> field2analyzer;
protected LeafReader leafReader; protected IndexReader indexReader;
@Before @Before
public void init() throws IOException { public void init() throws IOException {
@ -56,7 +56,7 @@ public abstract class DocumentClassificationTestBase<T> extends ClassificationTe
field2analyzer.put(textFieldName, analyzer); field2analyzer.put(textFieldName, analyzer);
field2analyzer.put(titleFieldName, analyzer); field2analyzer.put(titleFieldName, analyzer);
field2analyzer.put(authorFieldName, analyzer); field2analyzer.put(authorFieldName, analyzer);
leafReader = populateDocumentClassificationIndex(analyzer); indexReader = populateDocumentClassificationIndex(analyzer);
} }
protected double checkCorrectDocumentClassification(DocumentClassifier<T> classifier, Document inputDoc, T expectedResult) throws Exception { protected double checkCorrectDocumentClassification(DocumentClassifier<T> classifier, Document inputDoc, T expectedResult) throws Exception {
@ -68,7 +68,7 @@ public abstract class DocumentClassificationTestBase<T> extends ClassificationTe
return score; return score;
} }
protected LeafReader populateDocumentClassificationIndex(Analyzer analyzer) throws IOException { protected IndexReader populateDocumentClassificationIndex(Analyzer analyzer) throws IOException {
indexWriter.close(); indexWriter.close();
indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE)); indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
indexWriter.commit(); indexWriter.commit();
@ -201,8 +201,7 @@ public abstract class DocumentClassificationTestBase<T> extends ClassificationTe
indexWriter.addDocument(doc); indexWriter.addDocument(doc);
indexWriter.commit(); indexWriter.commit();
indexWriter.forceMerge(1); return indexWriter.getReader();
return getOnlyLeafReader(indexWriter.getReader());
} }
protected Document getVideoGameDocument() { protected Document getVideoGameDocument() {

View File

@ -33,15 +33,15 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
try { try {
Document videoGameDocument = getVideoGameDocument(); Document videoGameDocument = getVideoGameDocument();
Document batmanDocument = getBatmanDocument(); Document batmanDocument = getBatmanDocument();
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT); checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT); checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
// considering only the text we have wrong classification because the text was ambiguos on purpose // considering only the text we have wrong classification because the text was ambiguos on purpose
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT); checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT); checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
} finally { } finally {
if (leafReader != null) { if (indexReader != null) {
leafReader.close(); indexReader.close();
} }
} }
} }
@ -51,18 +51,18 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
try { try {
Document videoGameDocument = getVideoGameDocument(); Document videoGameDocument = getVideoGameDocument();
Document batmanDocument = getBatmanDocument(); Document batmanDocument = getBatmanDocument();
double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT); double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
assertEquals(1.0,score1,0); assertEquals(1.0,score1,0);
double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT); double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
assertEquals(1.0,score2,0); assertEquals(1.0,score2,0);
// considering only the text we have wrong classification because the text was ambiguos on purpose // considering only the text we have wrong classification because the text was ambiguos on purpose
double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT); double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
assertEquals(1.0,score3,0); assertEquals(1.0,score3,0);
double score4 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT); double score4 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
assertEquals(1.0,score4,0); assertEquals(1.0,score4,0);
} finally { } finally {
if (leafReader != null) { if (indexReader != null) {
leafReader.close(); indexReader.close();
} }
} }
} }
@ -70,12 +70,12 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
@Test @Test
public void testBoostedDocumentClassification() throws Exception { public void testBoostedDocumentClassification() throws Exception {
try { try {
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName + "^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT); checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName + "^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
// considering without boost wrong classification will appear // considering without boost wrong classification will appear
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_RESULT); checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_RESULT);
} finally { } finally {
if (leafReader != null) { if (indexReader != null) {
leafReader.close(); indexReader.close();
} }
} }
} }
@ -84,11 +84,11 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
public void testBasicDocumentClassificationWithQuery() throws Exception { public void testBasicDocumentClassificationWithQuery() throws Exception {
try { try {
TermQuery query = new TermQuery(new Term(authorFieldName, "ign")); TermQuery query = new TermQuery(new Term(authorFieldName, "ign"));
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_RESULT); checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_RESULT);
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null,query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), VIDEOGAME_RESULT); checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null,query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), VIDEOGAME_RESULT);
} finally { } finally {
if (leafReader != null) { if (indexReader != null) {
leafReader.close(); indexReader.close();
} }
} }
} }

View File

@ -28,14 +28,14 @@ public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificati
@Test @Test
public void testBasicDocumentClassification() throws Exception { public void testBasicDocumentClassification() throws Exception {
try { try {
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT); checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT); checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT); checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT); checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
} finally { } finally {
if (leafReader != null) { if (indexReader != null) {
leafReader.close(); indexReader.close();
} }
} }
} }
@ -43,18 +43,18 @@ public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificati
@Test @Test
public void testBasicDocumentClassificationScore() throws Exception { public void testBasicDocumentClassificationScore() throws Exception {
try { try {
double score1 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT); double score1 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
assertEquals(0.88,score1,0.01); assertEquals(0.88,score1,0.01);
double score2 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT); double score2 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
assertEquals(0.89,score2,0.01); assertEquals(0.89,score2,0.01);
//taking in consideration only the text //taking in consideration only the text
double score3 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT); double score3 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
assertEquals(0.55,score3,0.01); assertEquals(0.55,score3,0.01);
double score4 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT); double score4 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
assertEquals(0.52,score4,0.01); assertEquals(0.52,score4,0.01);
} finally { } finally {
if (leafReader != null) { if (indexReader != null) {
leafReader.close(); indexReader.close();
} }
} }
} }
@ -62,12 +62,12 @@ public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificati
@Test @Test
public void testBoostedDocumentClassification() throws Exception { public void testBoostedDocumentClassification() throws Exception {
try { try {
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName+"^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT); checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName+"^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
// considering without boost wrong classification will appear // considering without boost wrong classification will appear
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_ANALYZED_RESULT); checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_ANALYZED_RESULT);
} finally { } finally {
if (leafReader != null) { if (indexReader != null) {
leafReader.close(); indexReader.close();
} }
} }
} }

View File

@ -10,6 +10,7 @@ import org.apache.lucene.classification.document.DocumentClassifier;
import org.apache.lucene.classification.document.KNearestNeighborDocumentClassifier; import org.apache.lucene.classification.document.KNearestNeighborDocumentClassifier;
import org.apache.lucene.classification.document.SimpleNaiveBayesDocumentClassifier; import org.apache.lucene.classification.document.SimpleNaiveBayesDocumentClassifier;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReader;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import org.apache.solr.common.SolrInputDocument; import org.apache.solr.common.SolrInputDocument;
@ -60,7 +61,7 @@ class ClassificationUpdateProcessor
* @param schema schema * @param schema schema
*/ */
public ClassificationUpdateProcessor(String[] inputFieldNames, String classFieldName, int minDf, int minTf, int k, String algorithm, public ClassificationUpdateProcessor(String[] inputFieldNames, String classFieldName, int minDf, int minTf, int k, String algorithm,
UpdateRequestProcessor next, LeafReader indexReader, IndexSchema schema) { UpdateRequestProcessor next, IndexReader indexReader, IndexSchema schema) {
super(next); super(next);
this.classFieldName = classFieldName; this.classFieldName = classFieldName;
Map<String, Analyzer> field2analyzer = new HashMap<String, Analyzer>(); Map<String, Analyzer> field2analyzer = new HashMap<String, Analyzer>();

View File

@ -1,5 +1,6 @@
package org.apache.solr.update.processor; package org.apache.solr.update.processor;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReader;
import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.SolrParams; import org.apache.solr.common.params.SolrParams;
@ -109,8 +110,8 @@ public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessor
@Override @Override
public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) { public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
IndexSchema schema = req.getSchema(); IndexSchema schema = req.getSchema();
LeafReader leafReader = req.getSearcher().getLeafReader(); IndexReader indexReader = req.getSearcher().getIndexReader();
return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, leafReader, schema); return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, indexReader, schema);
} }
/** /**