mirror of https://github.com/apache/lucene.git
LUCENE-5699 - patch from Gergő Törcsvári for normalized score and return lists in classification
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1619053 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
7223a40b6d
commit
a934cc7106
|
@ -42,6 +42,7 @@ import org.apache.lucene.util.fst.PositiveIntOutputs;
|
|||
import org.apache.lucene.util.fst.Util;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.SortedMap;
|
||||
import java.util.TreeMap;
|
||||
|
@ -131,9 +132,7 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
this.textTerms = MultiFields.getTerms(atomicReader, textFieldName);
|
||||
|
||||
if (textTerms == null) {
|
||||
throw new IOException(new StringBuilder(
|
||||
"term vectors need to be available for field ").append(textFieldName)
|
||||
.toString());
|
||||
throw new IOException("term vectors need to be available for field " + textFieldName);
|
||||
}
|
||||
|
||||
this.analyzer = analyzer;
|
||||
|
@ -246,4 +245,22 @@ public class BooleanPerceptronClassifier implements Classifier<Boolean> {
|
|||
fst = fstBuilder.finish();
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text)
|
||||
throws IOException {
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text, int max)
|
||||
throws IOException {
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
|
@ -20,10 +20,10 @@ package org.apache.lucene.classification;
|
|||
* The result of a call to {@link Classifier#assignClass(String)} holding an assigned class of type <code>T</code> and a score.
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public class ClassificationResult<T> {
|
||||
public class ClassificationResult<T> implements Comparable<ClassificationResult<T>>{
|
||||
|
||||
private final T assignedClass;
|
||||
private final double score;
|
||||
private double score;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
|
@ -50,4 +50,18 @@ public class ClassificationResult<T> {
|
|||
public double getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
/**
|
||||
* set the score value
|
||||
* @param score the score for the assignedClass as a <code>double</code>
|
||||
*/
|
||||
public void setScore(double score) {
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public int compareTo(ClassificationResult<T> o) {
|
||||
return this.getScore() < o.getScore() ? 1 : this.getScore() > o.getScore() ? -1 : 0;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,33 +16,57 @@
|
|||
*/
|
||||
package org.apache.lucene.classification;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.index.AtomicReader;
|
||||
import org.apache.lucene.search.Query;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
/**
|
||||
* A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which assign classes of type
|
||||
* <code>T</code>
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public interface Classifier<T> {
|
||||
|
||||
/**
|
||||
* Assign a class (with score) to the given text String
|
||||
*
|
||||
* @param text a String containing text to be classified
|
||||
* @return a {@link ClassificationResult} holding assigned class of type <code>T</code> and score
|
||||
* @throws IOException If there is a low-level I/O error.
|
||||
*/
|
||||
public ClassificationResult<T> assignClass(String text) throws IOException;
|
||||
|
||||
/**
|
||||
* Get all the classes (sorted by score, descending) assigned to the given text String.
|
||||
*
|
||||
* @param text a String containing text to be classified
|
||||
* @return the whole list of {@link ClassificationResult}, the classes and scores. Returns <code>null</code> if the classifier can't make lists.
|
||||
* @throws IOException If there is a low-level I/O error.
|
||||
*/
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException;
|
||||
|
||||
/**
|
||||
* Get the first <code>max</code> classes (sorted by score, descending) assigned to the given text String.
|
||||
*
|
||||
* @param text a String containing text to be classified
|
||||
* @param max the number of return list elements
|
||||
* @return the whole list of {@link ClassificationResult}, the classes and scores. Cut for "max" number of elements. Returns <code>null</code> if the classifier can't make lists.
|
||||
* @throws IOException If there is a low-level I/O error.
|
||||
*/
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException;
|
||||
|
||||
/**
|
||||
* Train the classifier using the underlying Lucene index
|
||||
* @param atomicReader the reader to use to access the Lucene index
|
||||
* @param textFieldName the name of the field used to compare documents
|
||||
*
|
||||
* @param atomicReader the reader to use to access the Lucene index
|
||||
* @param textFieldName the name of the field used to compare documents
|
||||
* @param classFieldName the name of the field containing the class assigned to documents
|
||||
* @param analyzer the analyzer used to tokenize / filter the unseen text
|
||||
* @param analyzer the analyzer used to tokenize / filter the unseen text
|
||||
* @throws IOException If there is a low-level I/O error.
|
||||
*/
|
||||
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
|
||||
|
@ -50,11 +74,12 @@ public interface Classifier<T> {
|
|||
|
||||
/**
|
||||
* Train the classifier using the underlying Lucene index
|
||||
* @param atomicReader the reader to use to access the Lucene index
|
||||
* @param textFieldName the name of the field used to compare documents
|
||||
*
|
||||
* @param atomicReader the reader to use to access the Lucene index
|
||||
* @param textFieldName the name of the field used to compare documents
|
||||
* @param classFieldName the name of the field containing the class assigned to documents
|
||||
* @param analyzer the analyzer used to tokenize / filter the unseen text
|
||||
* @param query the query to filter which documents use for training
|
||||
* @param analyzer the analyzer used to tokenize / filter the unseen text
|
||||
* @param query the query to filter which documents use for training
|
||||
* @throws IOException If there is a low-level I/O error.
|
||||
*/
|
||||
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
|
||||
|
@ -62,11 +87,12 @@ public interface Classifier<T> {
|
|||
|
||||
/**
|
||||
* Train the classifier using the underlying Lucene index
|
||||
* @param atomicReader the reader to use to access the Lucene index
|
||||
*
|
||||
* @param atomicReader the reader to use to access the Lucene index
|
||||
* @param textFieldNames the names of the fields to be used to compare documents
|
||||
* @param classFieldName the name of the field containing the class assigned to documents
|
||||
* @param analyzer the analyzer used to tokenize / filter the unseen text
|
||||
* @param query the query to filter which documents use for training
|
||||
* @param analyzer the analyzer used to tokenize / filter the unseen text
|
||||
* @param query the query to filter which documents use for training
|
||||
* @throws IOException If there is a low-level I/O error.
|
||||
*/
|
||||
public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
|
||||
|
|
|
@ -31,7 +31,10 @@ import org.apache.lucene.util.BytesRef;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.io.StringReader;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
|
@ -79,6 +82,42 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
*/
|
||||
@Override
|
||||
public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
|
||||
TopDocs topDocs=knnSearcher(text);
|
||||
List<ClassificationResult<BytesRef>> doclist=buildListFromTopDocs(topDocs);
|
||||
ClassificationResult<BytesRef> retval=null;
|
||||
double maxscore=-Double.MAX_VALUE;
|
||||
for(ClassificationResult<BytesRef> element:doclist){
|
||||
if(element.getScore()>maxscore){
|
||||
retval=element;
|
||||
maxscore=element.getScore();
|
||||
}
|
||||
}
|
||||
return retval;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
|
||||
TopDocs topDocs=knnSearcher(text);
|
||||
List<ClassificationResult<BytesRef>> doclist=buildListFromTopDocs(topDocs);
|
||||
Collections.sort(doclist);
|
||||
return doclist;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
|
||||
TopDocs topDocs=knnSearcher(text);
|
||||
List<ClassificationResult<BytesRef>> doclist=buildListFromTopDocs(topDocs);
|
||||
Collections.sort(doclist);
|
||||
return doclist.subList(0, max);
|
||||
}
|
||||
|
||||
private TopDocs knnSearcher(String text) throws IOException{
|
||||
if (mlt == null) {
|
||||
throw new IOException("You must first call Classifier#train");
|
||||
}
|
||||
|
@ -91,33 +130,36 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
|
|||
if (query != null) {
|
||||
mltQuery.add(query, BooleanClause.Occur.MUST);
|
||||
}
|
||||
TopDocs topDocs = indexSearcher.search(mltQuery, k);
|
||||
return selectClassFromNeighbors(topDocs);
|
||||
return indexSearcher.search(mltQuery, k);
|
||||
}
|
||||
|
||||
private ClassificationResult<BytesRef> selectClassFromNeighbors(TopDocs topDocs) throws IOException {
|
||||
// TODO : improve the nearest neighbor selection
|
||||
|
||||
private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
|
||||
Map<BytesRef, Integer> classCounts = new HashMap<>();
|
||||
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
||||
BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
|
||||
Integer count = classCounts.get(cl);
|
||||
if (count != null) {
|
||||
classCounts.put(cl, count + 1);
|
||||
} else {
|
||||
classCounts.put(cl, 1);
|
||||
}
|
||||
BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
|
||||
Integer count = classCounts.get(cl);
|
||||
if (count != null) {
|
||||
classCounts.put(cl, count + 1);
|
||||
} else {
|
||||
classCounts.put(cl, 1);
|
||||
}
|
||||
}
|
||||
double max = 0;
|
||||
BytesRef assignedClass = new BytesRef();
|
||||
List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
|
||||
int sumdoc=0;
|
||||
for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
|
||||
Integer count = entry.getValue();
|
||||
if (count > max) {
|
||||
max = count;
|
||||
assignedClass = entry.getKey().clone();
|
||||
Integer count = entry.getValue();
|
||||
returnList.add(new ClassificationResult<>(entry.getKey().clone(), count / (double) k));
|
||||
sumdoc+=count;
|
||||
|
||||
}
|
||||
|
||||
//correction
|
||||
if(sumdoc<k){
|
||||
for(ClassificationResult<BytesRef> cr:returnList){
|
||||
cr.setScore(cr.getScore()*(double)k/(double)sumdoc);
|
||||
}
|
||||
}
|
||||
double score = max / (double) k;
|
||||
return new ClassificationResult<>(assignedClass, score);
|
||||
return returnList;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -16,6 +16,13 @@
|
|||
*/
|
||||
package org.apache.lucene.classification;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.TokenStream;
|
||||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||
|
@ -33,10 +40,6 @@ import org.apache.lucene.search.TotalHitCountCollector;
|
|||
import org.apache.lucene.search.WildcardQuery;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collection;
|
||||
import java.util.LinkedList;
|
||||
|
||||
/**
|
||||
* A simplistic Lucene based NaiveBayes classifier, see <code>http://en.wikipedia.org/wiki/Naive_Bayes_classifier</code>
|
||||
*
|
||||
|
@ -44,13 +47,12 @@ import java.util.LinkedList;
|
|||
*/
|
||||
public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
||||
|
||||
private AtomicReader atomicReader;
|
||||
private String[] textFieldNames;
|
||||
private String classFieldName;
|
||||
private int docsWithClassSize;
|
||||
private Analyzer analyzer;
|
||||
private IndexSearcher indexSearcher;
|
||||
private Query query;
|
||||
protected AtomicReader atomicReader;
|
||||
protected String[] textFieldNames;
|
||||
protected String classFieldName;
|
||||
protected Analyzer analyzer;
|
||||
protected IndexSearcher indexSearcher;
|
||||
protected Query query;
|
||||
|
||||
/**
|
||||
* Creates a new NaiveBayes classifier.
|
||||
|
@ -89,10 +91,88 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
this.classFieldName = classFieldName;
|
||||
this.analyzer = analyzer;
|
||||
this.query = query;
|
||||
this.docsWithClassSize = countDocsWithClass();
|
||||
}
|
||||
|
||||
private int countDocsWithClass() throws IOException {
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(inputDocument);
|
||||
ClassificationResult<BytesRef> retval = null;
|
||||
double maxscore = -Double.MAX_VALUE;
|
||||
for (ClassificationResult<BytesRef> element : doclist) {
|
||||
if (element.getScore() > maxscore) {
|
||||
retval = element;
|
||||
maxscore = element.getScore();
|
||||
}
|
||||
}
|
||||
return retval;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text);
|
||||
Collections.sort(doclist);
|
||||
return doclist;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
|
||||
List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text);
|
||||
Collections.sort(doclist);
|
||||
return doclist.subList(0, max);
|
||||
}
|
||||
|
||||
private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
|
||||
if (atomicReader == null) {
|
||||
throw new IOException("You must first call Classifier#train");
|
||||
}
|
||||
List<ClassificationResult<BytesRef>> dataList = new ArrayList<>();
|
||||
|
||||
Terms terms = MultiFields.getTerms(atomicReader, classFieldName);
|
||||
TermsEnum termsEnum = terms.iterator(null);
|
||||
BytesRef next;
|
||||
String[] tokenizedDoc = tokenizeDoc(inputDocument);
|
||||
int docsWithClassSize = countDocsWithClass();
|
||||
while ((next = termsEnum.next()) != null) {
|
||||
double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedDoc, next, docsWithClassSize);
|
||||
dataList.add(new ClassificationResult<>(BytesRef.deepCopyOf(next), clVal));
|
||||
}
|
||||
|
||||
// normalization; the values transforms to a 0-1 range
|
||||
ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
|
||||
if (!dataList.isEmpty()) {
|
||||
Collections.sort(dataList);
|
||||
// this is a negative number closest to 0 = a
|
||||
double smax = dataList.get(0).getScore();
|
||||
|
||||
double sumLog = 0;
|
||||
// log(sum(exp(x_n-a)))
|
||||
for (ClassificationResult<BytesRef> cr : dataList) {
|
||||
// getScore-smax <=0 (both negative, smax is the smallest abs()
|
||||
sumLog += Math.exp(cr.getScore() - smax);
|
||||
}
|
||||
// loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
|
||||
double loga = smax;
|
||||
loga += Math.log(sumLog);
|
||||
|
||||
// 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
|
||||
for (ClassificationResult<BytesRef> cr : dataList) {
|
||||
returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - loga)));
|
||||
}
|
||||
}
|
||||
|
||||
return returnList;
|
||||
}
|
||||
|
||||
protected int countDocsWithClass() throws IOException {
|
||||
int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
|
||||
if (docCount == -1) { // in case codec doesn't support getDocCount
|
||||
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
|
||||
|
@ -108,7 +188,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
return docCount;
|
||||
}
|
||||
|
||||
private String[] tokenizeDoc(String doc) throws IOException {
|
||||
protected String[] tokenizeDoc(String doc) throws IOException {
|
||||
Collection<String> result = new LinkedList<>();
|
||||
for (String textFieldName : textFieldNames) {
|
||||
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) {
|
||||
|
@ -123,34 +203,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
return result.toArray(new String[result.size()]);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
|
||||
if (atomicReader == null) {
|
||||
throw new IOException("You must first call Classifier#train");
|
||||
}
|
||||
double max = - Double.MAX_VALUE;
|
||||
BytesRef foundClass = new BytesRef();
|
||||
|
||||
Terms terms = MultiFields.getTerms(atomicReader, classFieldName);
|
||||
TermsEnum termsEnum = terms.iterator(null);
|
||||
BytesRef next;
|
||||
String[] tokenizedDoc = tokenizeDoc(inputDocument);
|
||||
while ((next = termsEnum.next()) != null) {
|
||||
double clVal = calculateLogPrior(next) + calculateLogLikelihood(tokenizedDoc, next);
|
||||
if (clVal > max) {
|
||||
max = clVal;
|
||||
foundClass = BytesRef.deepCopyOf(next);
|
||||
}
|
||||
}
|
||||
double score = 10 / Math.abs(max);
|
||||
return new ClassificationResult<>(foundClass, score);
|
||||
}
|
||||
|
||||
|
||||
private double calculateLogLikelihood(String[] tokenizedDoc, BytesRef c) throws IOException {
|
||||
private double calculateLogLikelihood(String[] tokenizedDoc, BytesRef c, int docsWithClassSize) throws IOException {
|
||||
// for each word
|
||||
double result = 0d;
|
||||
for (String word : tokenizedDoc) {
|
||||
|
@ -187,7 +240,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
BooleanQuery booleanQuery = new BooleanQuery();
|
||||
BooleanQuery subQuery = new BooleanQuery();
|
||||
for (String textFieldName : textFieldNames) {
|
||||
subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
|
||||
subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
|
||||
}
|
||||
booleanQuery.add(new BooleanClause(subQuery, BooleanClause.Occur.MUST));
|
||||
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
|
||||
|
@ -199,7 +252,7 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
|
|||
return totalHitCountCollector.getTotalHits();
|
||||
}
|
||||
|
||||
private double calculateLogPrior(BytesRef currentClass) throws IOException {
|
||||
private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException {
|
||||
return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue