LUCENE-5311 - added support for training using multiple content fields for knn and naive bayes

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1538205 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2013-11-02 15:41:49 +00:00
parent 24279e5c6a
commit b4a343d6ba
6 changed files with 82 additions and 23 deletions

View File

@ -191,6 +191,11 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
weights.clear(); // free memory while waiting for GC weights.clear(); // free memory while waiting for GC
} }
/**
 * {@inheritDoc}
 * <p>
 * Training over multiple content fields is not supported by the boolean
 * perceptron classifier; this implementation always throws.
 *
 * @throws IOException always, to signal that multi-field training is unsupported
 */
@Override
public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
throw new IOException("training with multiple fields not supported by boolean perceptron classifier");
}
private TermsEnum updateWeights(AtomicReader atomicReader, TermsEnum reuse, private TermsEnum updateWeights(AtomicReader atomicReader, TermsEnum reuse,
int docId, Boolean assignedClass, SortedMap<String,Double> weights, int docId, Boolean assignedClass, SortedMap<String,Double> weights,
double modifier, boolean updateFST) throws IOException { double modifier, boolean updateFST) throws IOException {

View File

@ -22,8 +22,8 @@ package org.apache.lucene.classification;
*/ */
public class ClassificationResult<T> { public class ClassificationResult<T> {
private T assignedClass; private final T assignedClass;
private double score; private final double score;
/** /**
* Constructor * Constructor

View File

@ -60,4 +60,16 @@ public interface Classifier<T> {
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
throws IOException; throws IOException;
/**
 * Train the classifier using the underlying Lucene index
 * @param atomicReader the reader to use to access the Lucene index
 * @param textFieldNames the names of the fields to be used to compare documents
 * @param classFieldName the name of the field containing the class assigned to documents
 * @param analyzer the analyzer used to tokenize / filter the unseen text
 * @param query the query to filter which documents should be used for training;
 *              may be {@code null} to train on all indexed documents
 * @throws IOException If there is a low-level I/O error.
 */
public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
throws IOException;
} }

View File

@ -41,10 +41,10 @@ import java.util.Map;
public class KNearestNeighborClassifier implements Classifier<BytesRef> { public class KNearestNeighborClassifier implements Classifier<BytesRef> {
private MoreLikeThis mlt; private MoreLikeThis mlt;
private String textFieldName; private String[] textFieldNames;
private String classFieldName; private String classFieldName;
private IndexSearcher indexSearcher; private IndexSearcher indexSearcher;
private int k; private final int k;
private Query query; private Query query;
/** /**
@ -65,14 +65,17 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
throw new IOException("You must first call Classifier#train"); throw new IOException("You must first call Classifier#train");
} }
Query q; Query q;
BooleanQuery mltQuery = new BooleanQuery();
for (String textFieldName : textFieldNames) {
mltQuery.add(new BooleanClause(mlt.like(new StringReader(text), textFieldName), BooleanClause.Occur.SHOULD));
}
if (query != null) { if (query != null) {
Query mltQuery = mlt.like(new StringReader(text), textFieldName);
BooleanQuery bq = new BooleanQuery(); BooleanQuery bq = new BooleanQuery();
bq.add(query, BooleanClause.Occur.MUST); bq.add(query, BooleanClause.Occur.MUST);
bq.add(mltQuery, BooleanClause.Occur.MUST); bq.add(mltQuery, BooleanClause.Occur.MUST);
q = bq; q = bq;
} else { } else {
q = mlt.like(new StringReader(text), textFieldName); q = mltQuery;
} }
TopDocs topDocs = indexSearcher.search(q, k); TopDocs topDocs = indexSearcher.search(q, k);
return selectClassFromNeighbors(topDocs); return selectClassFromNeighbors(topDocs);
@ -116,7 +119,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
*/ */
@Override @Override
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException { public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
this.textFieldName = textFieldName; this.textFieldNames = new String[]{textFieldName};
this.classFieldName = classFieldName; this.classFieldName = classFieldName;
mlt = new MoreLikeThis(atomicReader); mlt = new MoreLikeThis(atomicReader);
mlt.setAnalyzer(analyzer); mlt.setAnalyzer(analyzer);
@ -124,4 +127,18 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
indexSearcher = new IndexSearcher(atomicReader); indexSearcher = new IndexSearcher(atomicReader);
this.query = query; this.query = query;
} }
/**
 * {@inheritDoc}
 * <p>
 * Trains this k-NN classifier over several content fields: the field names,
 * class field and filter query are remembered for later classification calls,
 * and a {@link MoreLikeThis} instance is (re)built over the given reader.
 */
@Override
public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
// remember the training configuration used by subsequent classify calls
this.textFieldNames = textFieldNames;
this.classFieldName = classFieldName;
this.query = query;
// set up the MoreLikeThis helper so generated queries span all content fields
MoreLikeThis moreLikeThis = new MoreLikeThis(atomicReader);
moreLikeThis.setAnalyzer(analyzer);
moreLikeThis.setFieldNames(textFieldNames);
this.mlt = moreLikeThis;
// searcher over the same reader used for training
this.indexSearcher = new IndexSearcher(atomicReader);
}
} }

View File

@ -45,7 +45,7 @@ import java.util.LinkedList;
public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> { public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
private AtomicReader atomicReader; private AtomicReader atomicReader;
private String textFieldName; private String[] textFieldNames;
private String classFieldName; private String classFieldName;
private int docsWithClassSize; private int docsWithClassSize;
private Analyzer analyzer; private Analyzer analyzer;
@ -68,18 +68,36 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
throws IOException { throws IOException {
this.atomicReader = atomicReader; this.atomicReader = atomicReader;
this.indexSearcher = new IndexSearcher(this.atomicReader); this.indexSearcher = new IndexSearcher(this.atomicReader);
this.textFieldName = textFieldName; this.textFieldNames = new String[]{textFieldName};
this.classFieldName = classFieldName; this.classFieldName = classFieldName;
this.analyzer = analyzer; this.analyzer = analyzer;
this.docsWithClassSize = countDocsWithClass(); this.docsWithClassSize = countDocsWithClass();
this.query = query; this.query = query;
} }
/**
* {@inheritDoc}
*/
@Override @Override
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException { public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
train(atomicReader, textFieldName, classFieldName, analyzer, null); train(atomicReader, textFieldName, classFieldName, analyzer, null);
} }
/**
 * {@inheritDoc}
 * <p>
 * Trains this naive Bayes classifier over several content fields, recording
 * the reader, fields, analyzer and filter query for later classification and
 * caching the number of documents that carry a class value.
 */
@Override
public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
throws IOException {
// reader and class field must be assigned before countDocsWithClass(),
// which reads both of them
this.atomicReader = atomicReader;
this.classFieldName = classFieldName;
this.textFieldNames = textFieldNames;
this.analyzer = analyzer;
this.query = query;
this.indexSearcher = new IndexSearcher(this.atomicReader);
this.docsWithClassSize = countDocsWithClass();
}
private int countDocsWithClass() throws IOException { private int countDocsWithClass() throws IOException {
int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount(); int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
if (docCount == -1) { // in case codec doesn't support getDocCount if (docCount == -1) { // in case codec doesn't support getDocCount
@ -103,6 +121,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
private String[] tokenizeDoc(String doc) throws IOException { private String[] tokenizeDoc(String doc) throws IOException {
Collection<String> result = new LinkedList<String>(); Collection<String> result = new LinkedList<String>();
for (String textFieldName : textFieldNames) {
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) { try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) {
CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class); CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
tokenStream.reset(); tokenStream.reset();
@ -111,6 +130,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
} }
tokenStream.end(); tokenStream.end();
} }
}
return result.toArray(new String[result.size()]); return result.toArray(new String[result.size()]);
} }
@ -164,16 +184,23 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
} }
private double getTextTermFreqForClass(BytesRef c) throws IOException { private double getTextTermFreqForClass(BytesRef c) throws IOException {
double avgNumberOfUniqueTerms = 0;
for (String textFieldName : textFieldNames) {
Terms terms = MultiFields.getTerms(atomicReader, textFieldName); Terms terms = MultiFields.getTerms(atomicReader, textFieldName);
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
double avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
}
int docsWithC = atomicReader.docFreq(new Term(classFieldName, c)); int docsWithC = atomicReader.docFreq(new Term(classFieldName, c));
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c
} }
private int getWordFreqForClass(String word, BytesRef c) throws IOException { private int getWordFreqForClass(String word, BytesRef c) throws IOException {
BooleanQuery booleanQuery = new BooleanQuery(); BooleanQuery booleanQuery = new BooleanQuery();
booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST)); BooleanQuery subQuery = new BooleanQuery();
for (String textFieldName : textFieldNames) {
subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
}
booleanQuery.add(new BooleanClause(subQuery, BooleanClause.Occur.MUST));
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST)); booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
if (query != null) { if (query != null) {
booleanQuery.add(query, BooleanClause.Occur.MUST); booleanQuery.add(query, BooleanClause.Occur.MUST);

View File

@ -40,8 +40,8 @@ import java.io.IOException;
*/ */
public class DatasetSplitter { public class DatasetSplitter {
private double crossValidationRatio; private final double crossValidationRatio;
private double testRatio; private final double testRatio;
/** /**
* Create a {@link DatasetSplitter} by giving test and cross validation IDXs sizes * Create a {@link DatasetSplitter} by giving test and cross validation IDXs sizes
@ -68,8 +68,6 @@ public class DatasetSplitter {
public void split(AtomicReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex, public void split(AtomicReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex,
Analyzer analyzer, String... fieldNames) throws IOException { Analyzer analyzer, String... fieldNames) throws IOException {
// TODO : check that the passed fields are stored in the original index
// create IWs for train / test / cv IDXs // create IWs for train / test / cv IDXs
IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer)); IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer));
IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer)); IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer));