mirror of https://github.com/apache/lucene.git
LUCENE-5311 - added support for training using multiple content fields for knn and naive bayes
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1538205 13f79535-47bb-0310-9956-ffa450edef68
commit b4a343d6ba
parent 24279e5c6a
BooleanPerceptronClassifier.java
@@ -191,6 +191,11 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
     weights.clear(); // free memory while waiting for GC
   }
 
+  @Override
+  public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
+    throw new IOException("training with multiple fields not supported by boolean perceptron classifier");
+  }
+
   private TermsEnum updateWeights(AtomicReader atomicReader, TermsEnum reuse,
       int docId, Boolean assignedClass, SortedMap<String,Double> weights,
       double modifier, boolean updateFST) throws IOException {
ClassificationResult.java
@@ -22,8 +22,8 @@ package org.apache.lucene.classification;
  */
 public class ClassificationResult<T> {
 
-  private T assignedClass;
-  private double score;
+  private final T assignedClass;
+  private final double score;
 
   /**
    * Constructor
Classifier.java
@@ -60,4 +60,16 @@ public interface Classifier<T> {
   public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
       throws IOException;
 
+  /**
+   * Train the classifier using the underlying Lucene index
+   * @param atomicReader the reader to use to access the Lucene index
+   * @param textFieldNames the names of the fields to be used to compare documents
+   * @param classFieldName the name of the field containing the class assigned to documents
+   * @param analyzer the analyzer used to tokenize / filter the unseen text
+   * @param query the query to filter which documents are used for training
+   * @throws IOException If there is a low-level I/O error.
+   */
+  public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
+      throws IOException;
+
 }
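For context, a minimal usage sketch of the new multi-field overload (not part of the commit; the tiny RAMDirectory index, the field names "title" / "body" / "category", and the sample text are invented for illustration, against the trunk API of the time):

import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.store.RAMDirectory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;

public class MultiFieldTrainingExample {
  public static void main(String[] args) throws Exception {
    RAMDirectory dir = new RAMDirectory();
    StandardAnalyzer analyzer = new StandardAnalyzer(Version.LUCENE_50);
    IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig(Version.LUCENE_50, analyzer));
    // two tiny training documents, one per class
    Document d1 = new Document();
    d1.add(new TextField("title", "lucene search engine library", Field.Store.YES));
    d1.add(new TextField("body", "inverted index, queries and scoring", Field.Store.YES));
    d1.add(new StringField("category", "search", Field.Store.YES));
    writer.addDocument(d1);
    Document d2 = new Document();
    d2.add(new TextField("title", "naive bayes text classification", Field.Store.YES));
    d2.add(new TextField("body", "probabilistic classifier for labeled text", Field.Store.YES));
    d2.add(new StringField("category", "classification", Field.Store.YES));
    writer.addDocument(d2);
    writer.close();

    AtomicReader reader = SlowCompositeReaderWrapper.wrap(DirectoryReader.open(dir));
    SimpleNaiveBayesClassifier classifier = new SimpleNaiveBayesClassifier();
    // the new overload: train on both content fields at once
    classifier.train(reader, new String[]{"title", "body"}, "category", analyzer, null);
    ClassificationResult<BytesRef> result = classifier.assignClass("probabilistic classifier");
    System.out.println(result.getAssignedClass().utf8ToString() + " / " + result.getScore());
    reader.close();
  }
}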
KNearestNeighborClassifier.java
@@ -41,10 +41,10 @@ import java.util.Map;
 public class KNearestNeighborClassifier implements Classifier<BytesRef> {
 
   private MoreLikeThis mlt;
-  private String textFieldName;
+  private String[] textFieldNames;
   private String classFieldName;
   private IndexSearcher indexSearcher;
-  private int k;
+  private final int k;
   private Query query;
 
   /**
@@ -65,14 +65,17 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
       throw new IOException("You must first call Classifier#train");
     }
     Query q;
+    BooleanQuery mltQuery = new BooleanQuery();
+    for (String textFieldName : textFieldNames) {
+      mltQuery.add(new BooleanClause(mlt.like(new StringReader(text), textFieldName), BooleanClause.Occur.SHOULD));
+    }
     if (query != null) {
-      Query mltQuery = mlt.like(new StringReader(text), textFieldName);
       BooleanQuery bq = new BooleanQuery();
       bq.add(query, BooleanClause.Occur.MUST);
       bq.add(mltQuery, BooleanClause.Occur.MUST);
       q = bq;
     } else {
-      q = mlt.like(new StringReader(text), textFieldName);
+      q = mltQuery;
     }
     TopDocs topDocs = indexSearcher.search(q, k);
     return selectClassFromNeighbors(topDocs);
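The effect of the change above: one MoreLikeThis query per trained field, OR-ed together, with the optional training filter AND-ed on top. A standalone sketch of that query shape (class and method names here are illustrative, not from the patch):

import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;

class NeighborQueryShape {
  // Combines per-field MLT queries the way assignClass now does.
  static Query neighborQuery(Query[] perFieldMltQueries, Query trainingFilter) {
    BooleanQuery mltQuery = new BooleanQuery();
    for (Query fieldQuery : perFieldMltQueries) {
      // a neighbor may match on any of the trained fields
      mltQuery.add(new BooleanClause(fieldQuery, BooleanClause.Occur.SHOULD));
    }
    if (trainingFilter == null) {
      return mltQuery;
    }
    BooleanQuery bq = new BooleanQuery();
    bq.add(trainingFilter, BooleanClause.Occur.MUST); // restrict the candidate pool
    bq.add(mltQuery, BooleanClause.Occur.MUST);       // and require MLT similarity
    return bq;
  }
}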
@@ -116,7 +119,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
    */
   @Override
   public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
-    this.textFieldName = textFieldName;
+    this.textFieldNames = new String[]{textFieldName};
     this.classFieldName = classFieldName;
     mlt = new MoreLikeThis(atomicReader);
     mlt.setAnalyzer(analyzer);
@@ -124,4 +127,18 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
     indexSearcher = new IndexSearcher(atomicReader);
     this.query = query;
   }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
+    this.textFieldNames = textFieldNames;
+    this.classFieldName = classFieldName;
+    mlt = new MoreLikeThis(atomicReader);
+    mlt.setAnalyzer(analyzer);
+    mlt.setFieldNames(textFieldNames);
+    indexSearcher = new IndexSearcher(atomicReader);
+    this.query = query;
+  }
 }
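Training the kNN classifier over several fields then looks like the sketch below (hypothetical field names again; 'reader' and 'analyzer' are assumed to come from a realistically sized index, since MoreLikeThis's default minTermFreq/minDocFreq thresholds can leave the query empty on a toy index):

import java.io.IOException;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.KNearestNeighborClassifier;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.util.BytesRef;

class KnnMultiFieldExample {
  static ClassificationResult<BytesRef> classify(AtomicReader reader, Analyzer analyzer, String unseenText)
      throws IOException {
    KNearestNeighborClassifier knn = new KNearestNeighborClassifier(3); // vote among 3 nearest neighbors
    knn.train(reader, new String[]{"title", "body"}, "category", analyzer, null);
    return knn.assignClass(unseenText);
  }
}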
SimpleNaiveBayesClassifier.java
@@ -45,7 +45,7 @@ import java.util.LinkedList;
 public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
 
   private AtomicReader atomicReader;
-  private String textFieldName;
+  private String[] textFieldNames;
   private String classFieldName;
   private int docsWithClassSize;
   private Analyzer analyzer;
@@ -68,18 +68,36 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
       throws IOException {
     this.atomicReader = atomicReader;
     this.indexSearcher = new IndexSearcher(this.atomicReader);
-    this.textFieldName = textFieldName;
+    this.textFieldNames = new String[]{textFieldName};
     this.classFieldName = classFieldName;
     this.analyzer = analyzer;
     this.docsWithClassSize = countDocsWithClass();
     this.query = query;
   }
 
   /**
    * {@inheritDoc}
    */
   @Override
   public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
     train(atomicReader, textFieldName, classFieldName, analyzer, null);
   }
 
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
+      throws IOException {
+    this.atomicReader = atomicReader;
+    this.indexSearcher = new IndexSearcher(this.atomicReader);
+    this.textFieldNames = textFieldNames;
+    this.classFieldName = classFieldName;
+    this.analyzer = analyzer;
+    this.docsWithClassSize = countDocsWithClass();
+    this.query = query;
+  }
+
   private int countDocsWithClass() throws IOException {
     int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
     if (docCount == -1) { // in case codec doesn't support getDocCount
@@ -103,13 +121,15 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
 
   private String[] tokenizeDoc(String doc) throws IOException {
     Collection<String> result = new LinkedList<String>();
-    try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) {
-      CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
-      tokenStream.reset();
-      while (tokenStream.incrementToken()) {
-        result.add(charTermAttribute.toString());
+    for (String textFieldName : textFieldNames) {
+      try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) {
+        CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
+        tokenStream.reset();
+        while (tokenStream.incrementToken()) {
+          result.add(charTermAttribute.toString());
+        }
+        tokenStream.end();
       }
-      tokenStream.end();
     }
     return result.toArray(new String[result.size()]);
   }
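Note that the unseen text is now re-tokenized once per trained field, so the analyzer's per-field behavior matters: with a single plain analyzer the same tokens are simply added once per field, while a per-field analyzer yields genuinely different token streams. A sketch of wiring that up (field names and analyzer choices are hypothetical, not from the patch):

import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.core.KeywordAnalyzer;
import org.apache.lucene.analysis.miscellaneous.PerFieldAnalyzerWrapper;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.util.Version;

class PerFieldAnalysisExample {
  static Analyzer buildAnalyzer() {
    // "title" is analyzed as a single keyword token; all other fields get standard analysis,
    // so tokenizeDoc collects different tokens for "title" than for "body".
    Map<String, Analyzer> fieldAnalyzers = new HashMap<String, Analyzer>();
    fieldAnalyzers.put("title", new KeywordAnalyzer());
    return new PerFieldAnalyzerWrapper(new StandardAnalyzer(Version.LUCENE_50), fieldAnalyzers);
  }
}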
@@ -164,16 +184,23 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
   }
 
   private double getTextTermFreqForClass(BytesRef c) throws IOException {
-    Terms terms = MultiFields.getTerms(atomicReader, textFieldName);
-    long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
-    double avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
+    double avgNumberOfUniqueTerms = 0;
+    for (String textFieldName : textFieldNames) {
+      Terms terms = MultiFields.getTerms(atomicReader, textFieldName);
+      long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
+      avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
+    }
     int docsWithC = atomicReader.docFreq(new Term(classFieldName, c));
     return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c
   }
 
   private int getWordFreqForClass(String word, BytesRef c) throws IOException {
     BooleanQuery booleanQuery = new BooleanQuery();
-    booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST));
+    BooleanQuery subQuery = new BooleanQuery();
+    for (String textFieldName : textFieldNames) {
+      subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
+    }
+    booleanQuery.add(new BooleanClause(subQuery, BooleanClause.Occur.MUST));
     booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
     if (query != null) {
       booleanQuery.add(query, BooleanClause.Occur.MUST);
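So getTextTermFreqForClass now sums the per-field averages, and getWordFreqForClass counts documents where the word appears in any trained field (the SHOULD clauses inside a MUST) together with the class term. A worked example of the new estimate, with invented index statistics:

class TermFreqEstimateExample {
  static double estimate() {
    // Hypothetical statistics for two trained fields over the same 100 documents:
    //   "title": sumDocFreq = 500  -> 500 / 100.0  = 5.0  avg unique terms per doc
    //   "body" : sumDocFreq = 3000 -> 3000 / 100.0 = 30.0
    double avgNumberOfUniqueTerms = 500 / 100.0 + 3000 / 100.0; // 35.0
    int docsWithC = 20; // say 20 documents carry class c
    return avgNumberOfUniqueTerms * docsWithC; // 700.0
  }
}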
DatasetSplitter.java
@@ -40,8 +40,8 @@ import java.io.IOException;
  */
 public class DatasetSplitter {
 
-  private double crossValidationRatio;
-  private double testRatio;
+  private final double crossValidationRatio;
+  private final double testRatio;
 
   /**
    * Create a {@link DatasetSplitter} by giving test and cross validation IDXs sizes
@@ -68,8 +68,6 @@ public class DatasetSplitter {
   public void split(AtomicReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex,
       Analyzer analyzer, String... fieldNames) throws IOException {
 
-    // TODO : check that the passed fields are stored in the original index
-
     // create IWs for train / test / cv IDXs
     IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer));
     IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer));