LUCENE-5311 - added support for training using multiple content fields for knn and naive bayes

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1538205 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2013-11-02 15:41:49 +00:00
parent 24279e5c6a
commit b4a343d6ba
6 changed files with 82 additions and 23 deletions

View File

@ -191,6 +191,11 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
weights.clear(); // free memory while waiting for GC
}
/**
 * Multi-field training entry point. This classifier does not implement it:
 * the call is rejected unconditionally and none of the arguments are inspected.
 *
 * @throws IOException always, to signal that training on multiple fields is not supported
 */
@Override
public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
  final String reason = "training with multiple fields not supported by boolean perceptron classifier";
  throw new IOException(reason);
}
private TermsEnum updateWeights(AtomicReader atomicReader, TermsEnum reuse,
int docId, Boolean assignedClass, SortedMap<String,Double> weights,
double modifier, boolean updateFST) throws IOException {

View File

@ -22,8 +22,8 @@ package org.apache.lucene.classification;
*/
public class ClassificationResult<T> {
private T assignedClass;
private double score;
private final T assignedClass;
private final double score;
/**
* Constructor

View File

@ -60,4 +60,16 @@ public interface Classifier<T> {
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
throws IOException;
/**
 * Train the classifier using the underlying Lucene index, taking the text to learn
 * from out of several fields rather than a single one.
 * @param atomicReader the reader to use to access the Lucene index
 * @param textFieldNames the names of the fields to be used to compare documents
 * @param classFieldName the name of the field containing the class assigned to documents
 * @param analyzer the analyzer used to tokenize / filter the unseen text
 * @param query the query used to filter which documents are used for training
 * @throws IOException If there is a low-level I/O error; an implementation may also
 *         throw it to reject training on multiple fields.
 */
public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
throws IOException;
}

View File

@ -41,10 +41,10 @@ import java.util.Map;
public class KNearestNeighborClassifier implements Classifier<BytesRef> {
private MoreLikeThis mlt;
private String textFieldName;
private String[] textFieldNames;
private String classFieldName;
private IndexSearcher indexSearcher;
private int k;
private final int k;
private Query query;
/**
@ -65,14 +65,17 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
throw new IOException("You must first call Classifier#train");
}
Query q;
BooleanQuery mltQuery = new BooleanQuery();
for (String textFieldName : textFieldNames) {
mltQuery.add(new BooleanClause(mlt.like(new StringReader(text), textFieldName), BooleanClause.Occur.SHOULD));
}
if (query != null) {
Query mltQuery = mlt.like(new StringReader(text), textFieldName);
BooleanQuery bq = new BooleanQuery();
bq.add(query, BooleanClause.Occur.MUST);
bq.add(mltQuery, BooleanClause.Occur.MUST);
q = bq;
} else {
q = mlt.like(new StringReader(text), textFieldName);
q = mltQuery;
}
TopDocs topDocs = indexSearcher.search(q, k);
return selectClassFromNeighbors(topDocs);
@ -116,7 +119,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
*/
@Override
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
this.textFieldName = textFieldName;
this.textFieldNames = new String[]{textFieldName};
this.classFieldName = classFieldName;
mlt = new MoreLikeThis(atomicReader);
mlt.setAnalyzer(analyzer);
@ -124,4 +127,18 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
indexSearcher = new IndexSearcher(atomicReader);
this.query = query;
}
/**
 * {@inheritDoc}
 * <p>
 * Stores the field names, class field and filter query on this instance, and
 * builds a {@link MoreLikeThis} (configured with the given analyzer and field
 * names) and an {@link IndexSearcher} over the given reader.
 */
@Override
public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
  // remember which fields hold the text and which one holds the class
  this.textFieldNames = textFieldNames;
  this.classFieldName = classFieldName;

  // fully configure the similarity helper before publishing it on the instance
  final MoreLikeThis moreLikeThis = new MoreLikeThis(atomicReader);
  moreLikeThis.setAnalyzer(analyzer);
  moreLikeThis.setFieldNames(textFieldNames);
  this.mlt = moreLikeThis;

  this.indexSearcher = new IndexSearcher(atomicReader);
  this.query = query;
}
}

View File

@ -45,7 +45,7 @@ import java.util.LinkedList;
public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
private AtomicReader atomicReader;
private String textFieldName;
private String[] textFieldNames;
private String classFieldName;
private int docsWithClassSize;
private Analyzer analyzer;
@ -68,18 +68,36 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
throws IOException {
this.atomicReader = atomicReader;
this.indexSearcher = new IndexSearcher(this.atomicReader);
this.textFieldName = textFieldName;
this.textFieldNames = new String[]{textFieldName};
this.classFieldName = classFieldName;
this.analyzer = analyzer;
this.docsWithClassSize = countDocsWithClass();
this.query = query;
}
/**
 * {@inheritDoc}
 * <p>
 * Convenience variant: delegates to the single-field overload that accepts a
 * query, passing {@code null} as the query.
 */
@Override
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
  final Query noFilter = null; // no filter query supplied by the caller
  train(atomicReader, textFieldName, classFieldName, analyzer, noFilter);
}
/**
 * {@inheritDoc}
 * <p>
 * Stores the reader, the text field names, the class field name, the analyzer
 * and the filter query on this instance, creates an {@link IndexSearcher} over
 * the reader and caches the number of documents that carry a class value.
 */
@Override
public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
throws IOException {
this.atomicReader = atomicReader;
this.indexSearcher = new IndexSearcher(this.atomicReader);
this.textFieldNames = textFieldNames;
this.classFieldName = classFieldName;
this.analyzer = analyzer;
// countDocsWithClass() reads this.atomicReader and this.classFieldName,
// so both must already be assigned when it is called
this.docsWithClassSize = countDocsWithClass();
this.query = query;
}
private int countDocsWithClass() throws IOException {
int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
if (docCount == -1) { // in case codec doesn't support getDocCount
@ -103,6 +121,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
private String[] tokenizeDoc(String doc) throws IOException {
Collection<String> result = new LinkedList<String>();
for (String textFieldName : textFieldNames) {
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) {
CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
tokenStream.reset();
@ -111,6 +130,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
}
tokenStream.end();
}
}
return result.toArray(new String[result.size()]);
}
@ -164,16 +184,23 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
}
/**
 * Estimates the number of text terms attached to class {@code c}: the average
 * number of unique terms per document is summed over every configured text
 * field, then multiplied by the number of documents whose class field
 * contains {@code c}.
 *
 * @param c the class value to estimate the term count for
 * @return summed per-field average of unique terms per document, times the
 *         number of documents having class {@code c}
 * @throws IOException if the index cannot be read
 */
private double getTextTermFreqForClass(BytesRef c) throws IOException {
  double avgNumberOfUniqueTerms = 0;
  for (String textFieldName : textFieldNames) {
    Terms terms = MultiFields.getTerms(atomicReader, textFieldName);
    if (terms == null) {
      // no terms available for this field (e.g. it is absent from the index):
      // it contributes nothing to the sum, instead of failing with an NPE
      continue;
    }
    long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
    avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
  }
  int docsWithC = atomicReader.docFreq(new Term(classFieldName, c));
  return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
}
private int getWordFreqForClass(String word, BytesRef c) throws IOException {
BooleanQuery booleanQuery = new BooleanQuery();
booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST));
BooleanQuery subQuery = new BooleanQuery();
for (String textFieldName : textFieldNames) {
subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
}
booleanQuery.add(new BooleanClause(subQuery, BooleanClause.Occur.MUST));
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
if (query != null) {
booleanQuery.add(query, BooleanClause.Occur.MUST);

View File

@ -40,8 +40,8 @@ import java.io.IOException;
*/
public class DatasetSplitter {
private double crossValidationRatio;
private double testRatio;
private final double crossValidationRatio;
private final double testRatio;
/**
* Create a {@link DatasetSplitter} by giving test and cross validation IDXs sizes
@ -68,8 +68,6 @@ public class DatasetSplitter {
public void split(AtomicReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex,
Analyzer analyzer, String... fieldNames) throws IOException {
// TODO : check that the passed fields are stored in the original index
// create IWs for train / test / cv IDXs
IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer));
IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer));