mirror of https://github.com/apache/lucene.git
LUCENE-5311 - added support for training using multiple content fields for knn and naive bayes
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1538205 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
24279e5c6a
commit
b4a343d6ba
|
@ -191,6 +191,11 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
||||||
weights.clear(); // free memory while waiting for GC
|
weights.clear(); // free memory while waiting for GC
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
|
||||||
|
throw new IOException("training with multiple fields not supported by boolean perceptron classifier");
|
||||||
|
}
|
||||||
|
|
||||||
private TermsEnum updateWeights(AtomicReader atomicReader, TermsEnum reuse,
|
private TermsEnum updateWeights(AtomicReader atomicReader, TermsEnum reuse,
|
||||||
int docId, Boolean assignedClass, SortedMap<String,Double> weights,
|
int docId, Boolean assignedClass, SortedMap<String,Double> weights,
|
||||||
double modifier, boolean updateFST) throws IOException {
|
double modifier, boolean updateFST) throws IOException {
|
||||||
|
|
|
@ -22,8 +22,8 @@ package org.apache.lucene.classification;
|
||||||
*/
|
*/
|
||||||
public class ClassificationResult<T> {
|
public class ClassificationResult<T> {
|
||||||
|
|
||||||
private T assignedClass;
|
private final T assignedClass;
|
||||||
private double score;
|
private final double score;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructor
|
* Constructor
|
||||||
|
|
|
@ -60,4 +60,16 @@ public interface Classifier<T> {
|
||||||
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
|
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
|
||||||
throws IOException;
|
throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Train the classifier using the underlying Lucene index
|
||||||
|
* @param atomicReader the reader to use to access the Lucene index
|
||||||
|
* @param textFieldNames the names of the fields to be used to compare documents
|
||||||
|
* @param classFieldName the name of the field containing the class assigned to documents
|
||||||
|
* @param analyzer the analyzer used to tokenize / filter the unseen text
|
||||||
|
* @param query the query used to select which documents are used for training
|
||||||
|
* @throws IOException If there is a low-level I/O error.
|
||||||
|
*/
|
||||||
|
public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
|
||||||
|
throws IOException;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,10 +41,10 @@ import java.util.Map;
|
||||||
public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
||||||
|
|
||||||
private MoreLikeThis mlt;
|
private MoreLikeThis mlt;
|
||||||
private String textFieldName;
|
private String[] textFieldNames;
|
||||||
private String classFieldName;
|
private String classFieldName;
|
||||||
private IndexSearcher indexSearcher;
|
private IndexSearcher indexSearcher;
|
||||||
private int k;
|
private final int k;
|
||||||
private Query query;
|
private Query query;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -65,14 +65,17 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
||||||
throw new IOException("You must first call Classifier#train");
|
throw new IOException("You must first call Classifier#train");
|
||||||
}
|
}
|
||||||
Query q;
|
Query q;
|
||||||
|
BooleanQuery mltQuery = new BooleanQuery();
|
||||||
|
for (String textFieldName : textFieldNames) {
|
||||||
|
mltQuery.add(new BooleanClause(mlt.like(new StringReader(text), textFieldName), BooleanClause.Occur.SHOULD));
|
||||||
|
}
|
||||||
if (query != null) {
|
if (query != null) {
|
||||||
Query mltQuery = mlt.like(new StringReader(text), textFieldName);
|
|
||||||
BooleanQuery bq = new BooleanQuery();
|
BooleanQuery bq = new BooleanQuery();
|
||||||
bq.add(query, BooleanClause.Occur.MUST);
|
bq.add(query, BooleanClause.Occur.MUST);
|
||||||
bq.add(mltQuery, BooleanClause.Occur.MUST);
|
bq.add(mltQuery, BooleanClause.Occur.MUST);
|
||||||
q = bq;
|
q = bq;
|
||||||
} else {
|
} else {
|
||||||
q = mlt.like(new StringReader(text), textFieldName);
|
q = mltQuery;
|
||||||
}
|
}
|
||||||
TopDocs topDocs = indexSearcher.search(q, k);
|
TopDocs topDocs = indexSearcher.search(q, k);
|
||||||
return selectClassFromNeighbors(topDocs);
|
return selectClassFromNeighbors(topDocs);
|
||||||
|
@ -116,7 +119,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
|
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
|
||||||
this.textFieldName = textFieldName;
|
this.textFieldNames = new String[]{textFieldName};
|
||||||
this.classFieldName = classFieldName;
|
this.classFieldName = classFieldName;
|
||||||
mlt = new MoreLikeThis(atomicReader);
|
mlt = new MoreLikeThis(atomicReader);
|
||||||
mlt.setAnalyzer(analyzer);
|
mlt.setAnalyzer(analyzer);
|
||||||
|
@ -124,4 +127,18 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
||||||
indexSearcher = new IndexSearcher(atomicReader);
|
indexSearcher = new IndexSearcher(atomicReader);
|
||||||
this.query = query;
|
this.query = query;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
|
||||||
|
this.textFieldNames = textFieldNames;
|
||||||
|
this.classFieldName = classFieldName;
|
||||||
|
mlt = new MoreLikeThis(atomicReader);
|
||||||
|
mlt.setAnalyzer(analyzer);
|
||||||
|
mlt.setFieldNames(textFieldNames);
|
||||||
|
indexSearcher = new IndexSearcher(atomicReader);
|
||||||
|
this.query = query;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ import java.util.LinkedList;
|
||||||
public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
|
|
||||||
private AtomicReader atomicReader;
|
private AtomicReader atomicReader;
|
||||||
private String textFieldName;
|
private String[] textFieldNames;
|
||||||
private String classFieldName;
|
private String classFieldName;
|
||||||
private int docsWithClassSize;
|
private int docsWithClassSize;
|
||||||
private Analyzer analyzer;
|
private Analyzer analyzer;
|
||||||
|
@ -68,18 +68,36 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
throws IOException {
|
throws IOException {
|
||||||
this.atomicReader = atomicReader;
|
this.atomicReader = atomicReader;
|
||||||
this.indexSearcher = new IndexSearcher(this.atomicReader);
|
this.indexSearcher = new IndexSearcher(this.atomicReader);
|
||||||
this.textFieldName = textFieldName;
|
this.textFieldNames = new String[]{textFieldName};
|
||||||
this.classFieldName = classFieldName;
|
this.classFieldName = classFieldName;
|
||||||
this.analyzer = analyzer;
|
this.analyzer = analyzer;
|
||||||
this.docsWithClassSize = countDocsWithClass();
|
this.docsWithClassSize = countDocsWithClass();
|
||||||
this.query = query;
|
this.query = query;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
|
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
|
||||||
train(atomicReader, textFieldName, classFieldName, analyzer, null);
|
train(atomicReader, textFieldName, classFieldName, analyzer, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
|
||||||
|
throws IOException {
|
||||||
|
this.atomicReader = atomicReader;
|
||||||
|
this.indexSearcher = new IndexSearcher(this.atomicReader);
|
||||||
|
this.textFieldNames = textFieldNames;
|
||||||
|
this.classFieldName = classFieldName;
|
||||||
|
this.analyzer = analyzer;
|
||||||
|
this.docsWithClassSize = countDocsWithClass();
|
||||||
|
this.query = query;
|
||||||
|
}
|
||||||
|
|
||||||
private int countDocsWithClass() throws IOException {
|
private int countDocsWithClass() throws IOException {
|
||||||
int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
|
int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
|
||||||
if (docCount == -1) { // in case codec doesn't support getDocCount
|
if (docCount == -1) { // in case codec doesn't support getDocCount
|
||||||
|
@ -103,13 +121,15 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
|
|
||||||
private String[] tokenizeDoc(String doc) throws IOException {
|
private String[] tokenizeDoc(String doc) throws IOException {
|
||||||
Collection<String> result = new LinkedList<String>();
|
Collection<String> result = new LinkedList<String>();
|
||||||
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) {
|
for (String textFieldName : textFieldNames) {
|
||||||
CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
|
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) {
|
||||||
tokenStream.reset();
|
CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
|
||||||
while (tokenStream.incrementToken()) {
|
tokenStream.reset();
|
||||||
result.add(charTermAttribute.toString());
|
while (tokenStream.incrementToken()) {
|
||||||
|
result.add(charTermAttribute.toString());
|
||||||
|
}
|
||||||
|
tokenStream.end();
|
||||||
}
|
}
|
||||||
tokenStream.end();
|
|
||||||
}
|
}
|
||||||
return result.toArray(new String[result.size()]);
|
return result.toArray(new String[result.size()]);
|
||||||
}
|
}
|
||||||
|
@ -164,16 +184,23 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||||
}
|
}
|
||||||
|
|
||||||
private double getTextTermFreqForClass(BytesRef c) throws IOException {
|
private double getTextTermFreqForClass(BytesRef c) throws IOException {
|
||||||
Terms terms = MultiFields.getTerms(atomicReader, textFieldName);
|
double avgNumberOfUniqueTerms = 0;
|
||||||
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
|
for (String textFieldName : textFieldNames) {
|
||||||
double avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
|
Terms terms = MultiFields.getTerms(atomicReader, textFieldName);
|
||||||
|
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
|
||||||
|
avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
|
||||||
|
}
|
||||||
int docsWithC = atomicReader.docFreq(new Term(classFieldName, c));
|
int docsWithC = atomicReader.docFreq(new Term(classFieldName, c));
|
||||||
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c
|
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c
|
||||||
}
|
}
|
||||||
|
|
||||||
private int getWordFreqForClass(String word, BytesRef c) throws IOException {
|
private int getWordFreqForClass(String word, BytesRef c) throws IOException {
|
||||||
BooleanQuery booleanQuery = new BooleanQuery();
|
BooleanQuery booleanQuery = new BooleanQuery();
|
||||||
booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST));
|
BooleanQuery subQuery = new BooleanQuery();
|
||||||
|
for (String textFieldName : textFieldNames) {
|
||||||
|
subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
|
||||||
|
}
|
||||||
|
booleanQuery.add(new BooleanClause(subQuery, BooleanClause.Occur.MUST));
|
||||||
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
|
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
|
||||||
if (query != null) {
|
if (query != null) {
|
||||||
booleanQuery.add(query, BooleanClause.Occur.MUST);
|
booleanQuery.add(query, BooleanClause.Occur.MUST);
|
||||||
|
|
|
@ -40,8 +40,8 @@ import java.io.IOException;
|
||||||
*/
|
*/
|
||||||
public class DatasetSplitter {
|
public class DatasetSplitter {
|
||||||
|
|
||||||
private double crossValidationRatio;
|
private final double crossValidationRatio;
|
||||||
private double testRatio;
|
private final double testRatio;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a {@link DatasetSplitter} by giving the sizes of the test and cross validation indexes
|
* Create a {@link DatasetSplitter} by giving the sizes of the test and cross validation indexes
|
||||||
|
@ -68,8 +68,6 @@ public class DatasetSplitter {
|
||||||
public void split(AtomicReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex,
|
public void split(AtomicReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex,
|
||||||
Analyzer analyzer, String... fieldNames) throws IOException {
|
Analyzer analyzer, String... fieldNames) throws IOException {
|
||||||
|
|
||||||
// TODO : check that the passed fields are stored in the original index
|
|
||||||
|
|
||||||
// create IWs for train / test / cv IDXs
|
// create IWs for train / test / cv IDXs
|
||||||
IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer));
|
IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer));
|
||||||
IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer));
|
IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer));
|
||||||
|
|
Loading…
Reference in New Issue