mirror of https://github.com/apache/lucene.git
LUCENE-7350 - Let classifiers be constructed from IndexReaders
This commit is contained in:
parent
6ef174f527
commit
fcf4389d82
|
@ -26,6 +26,7 @@ import org.apache.lucene.analysis.Analyzer;
|
|||
import org.apache.lucene.analysis.TokenStream;
|
||||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexableField;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.MultiFields;
|
||||
|
@ -67,7 +68,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
/**
|
||||
* Creates a {@link BooleanPerceptronClassifier}
|
||||
*
|
||||
* @param leafReader the reader on the index to be used for classification
|
||||
* @param indexReader the reader on the index to be used for classification
|
||||
* @param analyzer an {@link Analyzer} used to analyze unseen text
|
||||
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
|
||||
* if all the indexed docs should be used
|
||||
|
@ -78,9 +79,9 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
* @throws IOException if the building of the underlying {@link FST} fails and / or {@link TermsEnum} for the text field
|
||||
* cannot be found
|
||||
*/
|
||||
public BooleanPerceptronClassifier(LeafReader leafReader, Analyzer analyzer, Query query, Integer batchSize,
|
||||
public BooleanPerceptronClassifier(IndexReader indexReader, Analyzer analyzer, Query query, Integer batchSize,
|
||||
Double threshold, String classFieldName, String textFieldName) throws IOException {
|
||||
this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
|
||||
this.textTerms = MultiFields.getTerms(indexReader, textFieldName);
|
||||
|
||||
if (textTerms == null) {
|
||||
throw new IOException("term vectors need to be available for field " + textFieldName);
|
||||
|
@ -91,7 +92,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
|
||||
if (threshold == null || threshold == 0d) {
|
||||
// automatic assign a threshold
|
||||
long sumDocFreq = leafReader.getSumDocFreq(textFieldName);
|
||||
long sumDocFreq = indexReader.getSumDocFreq(textFieldName);
|
||||
if (sumDocFreq != -1) {
|
||||
this.threshold = (double) sumDocFreq / 2d;
|
||||
} else {
|
||||
|
@ -113,7 +114,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
}
|
||||
updateFST(weights);
|
||||
|
||||
IndexSearcher indexSearcher = new IndexSearcher(leafReader);
|
||||
IndexSearcher indexSearcher = new IndexSearcher(indexReader);
|
||||
|
||||
int batchCount = 0;
|
||||
|
||||
|
@ -140,7 +141,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
Boolean correctClass = Boolean.valueOf(classField.stringValue());
|
||||
long modifier = correctClass.compareTo(assignedClass);
|
||||
if (modifier != 0) {
|
||||
updateWeights(leafReader, scoreDoc.doc, assignedClass,
|
||||
updateWeights(indexReader, scoreDoc.doc, assignedClass,
|
||||
weights, modifier, batchCount % batchSize == 0);
|
||||
}
|
||||
batchCount++;
|
||||
|
@ -149,13 +150,13 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
weights.clear(); // free memory while waiting for GC
|
||||
}
|
||||
|
||||
private void updateWeights(LeafReader leafReader,
|
||||
private void updateWeights(IndexReader indexReader,
|
||||
int docId, Boolean assignedClass, SortedMap<String, Double> weights,
|
||||
double modifier, boolean updateFST) throws IOException {
|
||||
TermsEnum cte = textTerms.iterator();
|
||||
|
||||
// get the doc term vectors
|
||||
Terms terms = leafReader.getTermVector(docId, textFieldName);
|
||||
Terms terms = indexReader.getTermVector(docId, textFieldName);
|
||||
|
||||
if (terms == null) {
|
||||
throw new IOException("term vectors must be stored for field "
|
||||
|
@ -201,7 +202,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
@Override
|
||||
public ClassificationResult<Boolean> assignClass(String text)
|
||||
throws IOException {
|
||||
Long output = 0l;
|
||||
Long output = 0L;
|
||||
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
|
||||
CharTermAttribute charTermAttribute = tokenStream
|
||||
.addAttribute(CharTermAttribute.class);
|
||||
|
|
|
@ -212,7 +212,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
|
|||
// build the cache for the word
|
||||
Map<String, Long> frequencyMap = new HashMap<>();
|
||||
for (String textFieldName : textFieldNames) {
|
||||
TermsEnum termsEnum = leafReader.terms(textFieldName).iterator();
|
||||
TermsEnum termsEnum = MultiFields.getTerms(indexReader, textFieldName).iterator();
|
||||
while (termsEnum.next() != null) {
|
||||
BytesRef term = termsEnum.term();
|
||||
String termText = term.utf8ToString();
|
||||
|
@ -229,7 +229,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
|
|||
}
|
||||
|
||||
// fill the class list
|
||||
Terms terms = MultiFields.getTerms(leafReader, classFieldName);
|
||||
Terms terms = MultiFields.getTerms(indexReader, classFieldName);
|
||||
TermsEnum termsEnum = terms.iterator();
|
||||
while ((termsEnum.next()) != null) {
|
||||
cclasses.add(BytesRef.deepCopyOf(termsEnum.term()));
|
||||
|
@ -238,11 +238,11 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
|
|||
for (BytesRef cclass : cclasses) {
|
||||
double avgNumberOfUniqueTerms = 0;
|
||||
for (String textFieldName : textFieldNames) {
|
||||
terms = MultiFields.getTerms(leafReader, textFieldName);
|
||||
terms = MultiFields.getTerms(indexReader, textFieldName);
|
||||
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
|
||||
avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount();
|
||||
}
|
||||
int docsWithC = leafReader.docFreq(new Term(classFieldName, cclass));
|
||||
int docsWithC = indexReader.docFreq(new Term(classFieldName, cclass));
|
||||
classTermFreq.put(cclass, avgNumberOfUniqueTerms * docsWithC);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.IndexableField;
|
||||
import org.apache.lucene.index.Term;
|
||||
|
@ -82,7 +83,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
/**
|
||||
* Creates a {@link KNearestNeighborClassifier}.
|
||||
*
|
||||
* @param leafReader the reader on the index to be used for classification
|
||||
* @param indexReader the reader on the index to be used for classification
|
||||
* @param analyzer an {@link Analyzer} used to analyze unseen text
|
||||
* @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
|
||||
* (defaults to {@link org.apache.lucene.search.similarities.ClassicSimilarity})
|
||||
|
@ -94,14 +95,14 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
* @param classFieldName the name of the field used as the output for the classifier
|
||||
* @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
|
||||
*/
|
||||
public KNearestNeighborClassifier(LeafReader leafReader, Similarity similarity, Analyzer analyzer, Query query, int k, int minDocsFreq,
|
||||
public KNearestNeighborClassifier(IndexReader indexReader, Similarity similarity, Analyzer analyzer, Query query, int k, int minDocsFreq,
|
||||
int minTermFreq, String classFieldName, String... textFieldNames) {
|
||||
this.textFieldNames = textFieldNames;
|
||||
this.classFieldName = classFieldName;
|
||||
this.mlt = new MoreLikeThis(leafReader);
|
||||
this.mlt = new MoreLikeThis(indexReader);
|
||||
this.mlt.setAnalyzer(analyzer);
|
||||
this.mlt.setFieldNames(textFieldNames);
|
||||
this.indexSearcher = new IndexSearcher(leafReader);
|
||||
this.indexSearcher = new IndexSearcher(indexReader);
|
||||
if (similarity != null) {
|
||||
this.indexSearcher.setSimilarity(similarity);
|
||||
} else {
|
||||
|
|
|
@ -26,7 +26,7 @@ import java.util.List;
|
|||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.TokenStream;
|
||||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.MultiFields;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.Terms;
|
||||
|
@ -48,10 +48,10 @@ import org.apache.lucene.util.BytesRef;
|
|||
public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||
|
||||
/**
|
||||
* {@link org.apache.lucene.index.LeafReader} used to access the {@link org.apache.lucene.classification.Classifier}'s
|
||||
* {@link org.apache.lucene.index.IndexReader} used to access the {@link org.apache.lucene.classification.Classifier}'s
|
||||
* index
|
||||
*/
|
||||
protected final LeafReader leafReader;
|
||||
protected final IndexReader indexReader;
|
||||
|
||||
/**
|
||||
* names of the fields to be used as input text
|
||||
|
@ -81,7 +81,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
/**
|
||||
* Creates a new NaiveBayes classifier.
|
||||
*
|
||||
* @param leafReader the reader on the index to be used for classification
|
||||
* @param indexReader the reader on the index to be used for classification
|
||||
* @param analyzer an {@link Analyzer} used to analyze unseen text
|
||||
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
|
||||
* if all the indexed docs should be used
|
||||
|
@ -89,9 +89,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
* as the returned class will be a token indexed for this field
|
||||
* @param textFieldNames the name of the fields used as the inputs for the classifier, NO boosting supported per field
|
||||
*/
|
||||
public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
|
||||
this.leafReader = leafReader;
|
||||
this.indexSearcher = new IndexSearcher(this.leafReader);
|
||||
public SimpleNaiveBayesClassifier(IndexReader indexReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
|
||||
this.indexReader = indexReader;
|
||||
this.indexSearcher = new IndexSearcher(this.indexReader);
|
||||
this.textFieldNames = textFieldNames;
|
||||
this.classFieldName = classFieldName;
|
||||
this.analyzer = analyzer;
|
||||
|
@ -144,7 +144,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
|
||||
|
||||
Terms classes = MultiFields.getTerms(leafReader, classFieldName);
|
||||
Terms classes = MultiFields.getTerms(indexReader, classFieldName);
|
||||
if (classes != null) {
|
||||
TermsEnum classesEnum = classes.iterator();
|
||||
BytesRef next;
|
||||
|
@ -169,7 +169,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
* @throws IOException if accessing to term vectors or search fails
|
||||
*/
|
||||
protected int countDocsWithClass() throws IOException {
|
||||
Terms terms = MultiFields.getTerms(this.leafReader, this.classFieldName);
|
||||
Terms terms = MultiFields.getTerms(this.indexReader, this.classFieldName);
|
||||
int docCount;
|
||||
if (terms == null || terms.getDocCount() == -1) { // in case codec doesn't support getDocCount
|
||||
TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
|
||||
|
@ -240,11 +240,11 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
private double getTextTermFreqForClass(Term term) throws IOException {
|
||||
double avgNumberOfUniqueTerms = 0;
|
||||
for (String textFieldName : textFieldNames) {
|
||||
Terms terms = MultiFields.getTerms(leafReader, textFieldName);
|
||||
Terms terms = MultiFields.getTerms(indexReader, textFieldName);
|
||||
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
|
||||
avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
|
||||
}
|
||||
int docsWithC = leafReader.docFreq(term);
|
||||
int docsWithC = indexReader.docFreq(term);
|
||||
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
|
||||
}
|
||||
|
||||
|
@ -277,7 +277,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
}
|
||||
|
||||
private int docCount(Term term) throws IOException {
|
||||
return leafReader.docFreq(term);
|
||||
return indexReader.docFreq(term);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.apache.lucene.analysis.Analyzer;
|
|||
import org.apache.lucene.classification.ClassificationResult;
|
||||
import org.apache.lucene.classification.KNearestNeighborClassifier;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.search.BooleanClause;
|
||||
|
@ -54,7 +55,7 @@ public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifi
|
|||
/**
|
||||
* Creates a {@link KNearestNeighborClassifier}.
|
||||
*
|
||||
* @param leafReader the reader on the index to be used for classification
|
||||
* @param indexReader the reader on the index to be used for classification
|
||||
* @param similarity the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
|
||||
* (defaults to {@link org.apache.lucene.search.similarities.ClassicSimilarity})
|
||||
* @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null}
|
||||
|
@ -66,9 +67,9 @@ public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifi
|
|||
* @param field2analyzer map with key a field name and the related {org.apache.lucene.analysis.Analyzer}
|
||||
* @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
|
||||
*/
|
||||
public KNearestNeighborDocumentClassifier(LeafReader leafReader, Similarity similarity, Query query, int k, int minDocsFreq,
|
||||
public KNearestNeighborDocumentClassifier(IndexReader indexReader, Similarity similarity, Query query, int k, int minDocsFreq,
|
||||
int minTermFreq, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
|
||||
super(leafReader,similarity, null, query, k, minDocsFreq, minTermFreq, classFieldName, textFieldNames);
|
||||
super(indexReader, similarity, null, query, k, minDocsFreq, minTermFreq, classFieldName, textFieldNames);
|
||||
this.field2analyzer = field2analyzer;
|
||||
}
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
|||
import org.apache.lucene.classification.ClassificationResult;
|
||||
import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexableField;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.MultiFields;
|
||||
|
@ -59,15 +60,15 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
|
|||
/**
|
||||
* Creates a new NaiveBayes classifier.
|
||||
*
|
||||
* @param leafReader the reader on the index to be used for classification
|
||||
* @param indexReader the reader on the index to be used for classification
|
||||
* @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null}
|
||||
* if all the indexed docs should be used
|
||||
* @param classFieldName the name of the field used as the output for the classifier NOTE: must not be havely analyzed
|
||||
* as the returned class will be a token indexed for this field
|
||||
* @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
|
||||
*/
|
||||
public SimpleNaiveBayesDocumentClassifier(LeafReader leafReader, Query query, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
|
||||
super(leafReader, null, query, classFieldName, textFieldNames);
|
||||
public SimpleNaiveBayesDocumentClassifier(IndexReader indexReader, Query query, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
|
||||
super(indexReader, null, query, classFieldName, textFieldNames);
|
||||
this.field2analyzer = field2analyzer;
|
||||
}
|
||||
|
||||
|
@ -112,7 +113,7 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
|
|||
List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
|
||||
Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>();
|
||||
Map<String, Float> fieldName2boost = new LinkedHashMap<>();
|
||||
Terms classes = MultiFields.getTerms(leafReader, classFieldName);
|
||||
Terms classes = MultiFields.getTerms(indexReader, classFieldName);
|
||||
TermsEnum classesEnum = classes.iterator();
|
||||
BytesRef c;
|
||||
|
||||
|
@ -225,10 +226,10 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
|
|||
*/
|
||||
private double getTextTermFreqForClass(Term term, String fieldName) throws IOException {
|
||||
double avgNumberOfUniqueTerms;
|
||||
Terms terms = MultiFields.getTerms(leafReader, fieldName);
|
||||
Terms terms = MultiFields.getTerms(indexReader, fieldName);
|
||||
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
|
||||
avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
|
||||
int docsWithC = leafReader.docFreq(term);
|
||||
int docsWithC = indexReader.docFreq(term);
|
||||
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
|
||||
}
|
||||
|
||||
|
@ -261,6 +262,6 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
|
|||
}
|
||||
|
||||
private int docCount(Term term) throws IOException {
|
||||
return leafReader.docFreq(term);
|
||||
return indexReader.docFreq(term);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ import java.util.concurrent.TimeoutException;
|
|||
import org.apache.lucene.classification.ClassificationResult;
|
||||
import org.apache.lucene.classification.Classifier;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TermRangeQuery;
|
||||
|
@ -50,9 +50,9 @@ public class ConfusionMatrixGenerator {
|
|||
|
||||
/**
|
||||
* get the {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} of a given {@link Classifier},
|
||||
* generated on the given {@link LeafReader}, class and text fields.
|
||||
* generated on the given {@link IndexReader}, class and text fields.
|
||||
*
|
||||
* @param reader the {@link LeafReader} containing the index used for creating the {@link Classifier}
|
||||
* @param reader the {@link IndexReader} containing the index used for creating the {@link Classifier}
|
||||
* @param classifier the {@link Classifier} whose confusion matrix has to be generated
|
||||
* @param classFieldName the name of the Lucene field used as the classifier's output
|
||||
* @param textFieldName the nome the Lucene field used as the classifier's input
|
||||
|
@ -61,7 +61,7 @@ public class ConfusionMatrixGenerator {
|
|||
* @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix}
|
||||
* @throws IOException if problems occurr while reading the index or using the classifier
|
||||
*/
|
||||
public static <T> ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier<T> classifier, String classFieldName,
|
||||
public static <T> ConfusionMatrix getConfusionMatrix(IndexReader reader, Classifier<T> classifier, String classFieldName,
|
||||
String textFieldName, long timeoutMilliseconds) throws IOException {
|
||||
|
||||
ExecutorService executorService = Executors.newFixedThreadPool(1, new NamedThreadFactory("confusion-matrix-gen-"));
|
||||
|
|
|
@ -18,15 +18,17 @@ package org.apache.lucene.classification.utils;
|
|||
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.FieldType;
|
||||
import org.apache.lucene.document.TextField;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.IndexableField;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.SortedDocValues;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||
|
@ -69,7 +71,7 @@ public class DatasetSplitter {
|
|||
* @param fieldNames names of fields that need to be put in the new indexes or <code>null</code> if all should be used
|
||||
* @throws IOException if any writing operation fails on any of the indexes
|
||||
*/
|
||||
public void split(LeafReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex,
|
||||
public void split(IndexReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex,
|
||||
Analyzer analyzer, boolean termVectors, String classFieldName, String... fieldNames) throws IOException {
|
||||
|
||||
// create IWs for train / test / cv IDXs
|
||||
|
@ -78,12 +80,14 @@ public class DatasetSplitter {
|
|||
IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer));
|
||||
|
||||
// get the exact no. of existing classes
|
||||
SortedDocValues classValues = originalIndex.getSortedDocValues(classFieldName);
|
||||
int noOfClasses = 0;
|
||||
for (LeafReaderContext leave : originalIndex.leaves()) {
|
||||
SortedDocValues classValues = leave.reader().getSortedDocValues(classFieldName);
|
||||
if (classValues == null) {
|
||||
throw new IllegalStateException("the classFieldName \"" + classFieldName + "\" must index sorted doc values");
|
||||
}
|
||||
|
||||
int noOfClasses = classValues.getValueCount();
|
||||
noOfClasses += classValues.getValueCount();
|
||||
}
|
||||
|
||||
try {
|
||||
|
||||
|
@ -150,7 +154,7 @@ public class DatasetSplitter {
|
|||
}
|
||||
}
|
||||
|
||||
private Document createNewDoc(LeafReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames) throws IOException {
|
||||
private Document createNewDoc(IndexReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames) throws IOException {
|
||||
Document doc = new Document();
|
||||
Document document = originalIndex.document(scoreDoc.doc);
|
||||
if (fieldNames != null && fieldNames.length > 0) {
|
||||
|
|
|
@ -27,8 +27,8 @@ import org.apache.lucene.classification.ClassificationResult;
|
|||
import org.apache.lucene.classification.ClassificationTestBase;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.RandomIndexWriter;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.junit.Before;
|
||||
|
@ -47,7 +47,7 @@ public abstract class DocumentClassificationTestBase<T> extends ClassificationTe
|
|||
|
||||
protected Analyzer analyzer;
|
||||
protected Map<String, Analyzer> field2analyzer;
|
||||
protected LeafReader leafReader;
|
||||
protected IndexReader indexReader;
|
||||
|
||||
@Before
|
||||
public void init() throws IOException {
|
||||
|
@ -56,7 +56,7 @@ public abstract class DocumentClassificationTestBase<T> extends ClassificationTe
|
|||
field2analyzer.put(textFieldName, analyzer);
|
||||
field2analyzer.put(titleFieldName, analyzer);
|
||||
field2analyzer.put(authorFieldName, analyzer);
|
||||
leafReader = populateDocumentClassificationIndex(analyzer);
|
||||
indexReader = populateDocumentClassificationIndex(analyzer);
|
||||
}
|
||||
|
||||
protected double checkCorrectDocumentClassification(DocumentClassifier<T> classifier, Document inputDoc, T expectedResult) throws Exception {
|
||||
|
@ -68,7 +68,7 @@ public abstract class DocumentClassificationTestBase<T> extends ClassificationTe
|
|||
return score;
|
||||
}
|
||||
|
||||
protected LeafReader populateDocumentClassificationIndex(Analyzer analyzer) throws IOException {
|
||||
protected IndexReader populateDocumentClassificationIndex(Analyzer analyzer) throws IOException {
|
||||
indexWriter.close();
|
||||
indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
|
||||
indexWriter.commit();
|
||||
|
@ -201,8 +201,7 @@ public abstract class DocumentClassificationTestBase<T> extends ClassificationTe
|
|||
indexWriter.addDocument(doc);
|
||||
|
||||
indexWriter.commit();
|
||||
indexWriter.forceMerge(1);
|
||||
return getOnlyLeafReader(indexWriter.getReader());
|
||||
return indexWriter.getReader();
|
||||
}
|
||||
|
||||
protected Document getVideoGameDocument() {
|
||||
|
|
|
@ -33,15 +33,15 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
|
|||
try {
|
||||
Document videoGameDocument = getVideoGameDocument();
|
||||
Document batmanDocument = getBatmanDocument();
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
|
||||
// considering only the text we have wrong classification because the text was ambiguos on purpose
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
|
||||
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
if (indexReader != null) {
|
||||
indexReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -51,18 +51,18 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
|
|||
try {
|
||||
Document videoGameDocument = getVideoGameDocument();
|
||||
Document batmanDocument = getBatmanDocument();
|
||||
double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
|
||||
double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
|
||||
assertEquals(1.0,score1,0);
|
||||
double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
|
||||
double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
|
||||
assertEquals(1.0,score2,0);
|
||||
// considering only the text we have wrong classification because the text was ambiguos on purpose
|
||||
double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
|
||||
double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
|
||||
assertEquals(1.0,score3,0);
|
||||
double score4 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
|
||||
double score4 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
|
||||
assertEquals(1.0,score4,0);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
if (indexReader != null) {
|
||||
indexReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -70,12 +70,12 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
|
|||
@Test
|
||||
public void testBoostedDocumentClassification() throws Exception {
|
||||
try {
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName + "^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName + "^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
|
||||
// considering without boost wrong classification will appear
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_RESULT);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
if (indexReader != null) {
|
||||
indexReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -84,11 +84,11 @@ public class KNearestNeighborDocumentClassifierTest extends DocumentClassificati
|
|||
public void testBasicDocumentClassificationWithQuery() throws Exception {
|
||||
try {
|
||||
TermQuery query = new TermQuery(new Term(authorFieldName, "ign"));
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null,query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), VIDEOGAME_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null, query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_RESULT);
|
||||
checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(indexReader,null,query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), VIDEOGAME_RESULT);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
if (indexReader != null) {
|
||||
indexReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,14 +28,14 @@ public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificati
|
|||
@Test
|
||||
public void testBasicDocumentClassification() throws Exception {
|
||||
try {
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
|
||||
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
if (indexReader != null) {
|
||||
indexReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -43,18 +43,18 @@ public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificati
|
|||
@Test
|
||||
public void testBasicDocumentClassificationScore() throws Exception {
|
||||
try {
|
||||
double score1 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
|
||||
double score1 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
|
||||
assertEquals(0.88,score1,0.01);
|
||||
double score2 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
|
||||
double score2 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
|
||||
assertEquals(0.89,score2,0.01);
|
||||
//taking in consideration only the text
|
||||
double score3 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
|
||||
double score3 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
|
||||
assertEquals(0.55,score3,0.01);
|
||||
double score4 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
|
||||
double score4 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
|
||||
assertEquals(0.52,score4,0.01);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
if (indexReader != null) {
|
||||
indexReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -62,12 +62,12 @@ public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificati
|
|||
@Test
|
||||
public void testBoostedDocumentClassification() throws Exception {
|
||||
try {
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName+"^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName+"^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
|
||||
// considering without boost wrong classification will appear
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_ANALYZED_RESULT);
|
||||
checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(indexReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_ANALYZED_RESULT);
|
||||
} finally {
|
||||
if (leafReader != null) {
|
||||
leafReader.close();
|
||||
if (indexReader != null) {
|
||||
indexReader.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import org.apache.lucene.classification.document.DocumentClassifier;
|
|||
import org.apache.lucene.classification.document.KNearestNeighborDocumentClassifier;
|
||||
import org.apache.lucene.classification.document.SimpleNaiveBayesDocumentClassifier;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.solr.common.SolrInputDocument;
|
||||
|
@ -60,7 +61,7 @@ class ClassificationUpdateProcessor
|
|||
* @param schema schema
|
||||
*/
|
||||
public ClassificationUpdateProcessor(String[] inputFieldNames, String classFieldName, int minDf, int minTf, int k, String algorithm,
|
||||
UpdateRequestProcessor next, LeafReader indexReader, IndexSchema schema) {
|
||||
UpdateRequestProcessor next, IndexReader indexReader, IndexSchema schema) {
|
||||
super(next);
|
||||
this.classFieldName = classFieldName;
|
||||
Map<String, Analyzer> field2analyzer = new HashMap<String, Analyzer>();
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package org.apache.solr.update.processor;
|
||||
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.solr.common.SolrException;
|
||||
import org.apache.solr.common.params.SolrParams;
|
||||
|
@ -109,8 +110,8 @@ public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessor
|
|||
@Override
|
||||
public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
|
||||
IndexSchema schema = req.getSchema();
|
||||
LeafReader leafReader = req.getSearcher().getLeafReader();
|
||||
return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, leafReader, schema);
|
||||
IndexReader indexReader = req.getSearcher().getIndexReader();
|
||||
return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, indexReader, schema);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
Loading…
Reference in New Issue