mirror of https://github.com/apache/lucene.git

LUCENE-6631 - added document classification api and impls

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1709522 13f79535-47bb-0310-9956-ffa450edef68

commit 71cea88773 (parent 0fe5ab3b9b)
@@ -2,7 +2,6 @@ package org.apache.lucene.classification;
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -81,38 +80,17 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
 
 
   protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
-    String[] tokenizedDoc = tokenizeDoc(inputDocument);
+    String[] tokenizedText = tokenize(inputDocument);
 
-    List<ClassificationResult<BytesRef>> dataList = calculateLogLikelihood(tokenizedDoc);
+    List<ClassificationResult<BytesRef>> assignedClasses = calculateLogLikelihood(tokenizedText);
 
     // normalization
     // The values transforms to a 0-1 range
-    ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
-    if (!dataList.isEmpty()) {
-      Collections.sort(dataList);
-      // this is a negative number closest to 0 = a
-      double smax = dataList.get(0).getScore();
-
-      double sumLog = 0;
-      // log(sum(exp(x_n-a)))
-      for (ClassificationResult<BytesRef> cr : dataList) {
-        // getScore-smax <=0 (both negative, smax is the smallest abs()
-        sumLog += Math.exp(cr.getScore() - smax);
-      }
-      // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
-      double loga = smax;
-      loga += Math.log(sumLog);
-
-      // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
-      for (ClassificationResult<BytesRef> cr : dataList) {
-        returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - loga)));
-      }
-    }
-
-    return returnList;
+    ArrayList<ClassificationResult<BytesRef>> asignedClassesNorm = super.normClassificationResults(assignedClasses);
+    return asignedClassesNorm;
   }
 
-  private List<ClassificationResult<BytesRef>> calculateLogLikelihood(String[] tokenizedDoc) throws IOException {
+  private List<ClassificationResult<BytesRef>> calculateLogLikelihood(String[] tokenizedText) throws IOException {
     // initialize the return List
     ArrayList<ClassificationResult<BytesRef>> ret = new ArrayList<>();
     for (BytesRef cclass : cclasses) {
@@ -120,7 +98,7 @@ public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
       ret.add(cr);
     }
     // for each word
-    for (String word : tokenizedDoc) {
+    for (String word : tokenizedText) {
       // search with text:word for all class:c
       Map<BytesRef, Integer> hitsInClasses = getWordFreqForClassess(word);
       // for each class
@@ -48,12 +48,12 @@ import org.apache.lucene.util.BytesRef;
  */
 public class KNearestNeighborClassifier implements Classifier<BytesRef> {
 
-  private final MoreLikeThis mlt;
-  private final String[] textFieldNames;
-  private final String classFieldName;
-  private final IndexSearcher indexSearcher;
-  private final int k;
-  private final Query query;
+  protected final MoreLikeThis mlt;
+  protected final String[] textFieldNames;
+  protected final String classFieldName;
+  protected final IndexSearcher indexSearcher;
+  protected final int k;
+  protected final Query query;
 
   /**
    * Creates a {@link KNearestNeighborClassifier}.
@@ -159,7 +159,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
   }
 
   //ranking of classes must be taken in consideration
-  private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
+  protected List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
     Map<BytesRef, Integer> classCounts = new HashMap<>();
     Map<BytesRef, Double> classBoosts = new HashMap<>(); // this is a boost based on class ranking positions in topDocs
     float maxScore = topDocs.getMaxScore();
@@ -85,8 +85,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
    * @param analyzer an {@link Analyzer} used to analyze unseen text
    * @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
    *              if all the indexed docs should be used
-   * @param classFieldName the name of the field used as the output for the classifier
-   * @param textFieldNames the name of the fields used as the inputs for the classifier
+   * @param classFieldName the name of the field used as the output for the classifier NOTE: must not be havely analyzed
+   *                       as the returned class will be a token indexed for this field
+   * @param textFieldNames the name of the fields used as the inputs for the classifier, NO boosting supported per field
    */
   public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
     this.leafReader = leafReader;
@@ -102,16 +103,16 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
    */
   @Override
   public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
-    List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(inputDocument);
-    ClassificationResult<BytesRef> retval = null;
+    List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(inputDocument);
+    ClassificationResult<BytesRef> assignedClass = null;
     double maxscore = -Double.MAX_VALUE;
-    for (ClassificationResult<BytesRef> element : doclist) {
-      if (element.getScore() > maxscore) {
-        retval = element;
-        maxscore = element.getScore();
+    for (ClassificationResult<BytesRef> c : assignedClasses) {
+      if (c.getScore() > maxscore) {
+        assignedClass = c;
+        maxscore = c.getScore();
       }
     }
-    return retval;
+    return assignedClass;
   }
 
   /**
@@ -119,9 +120,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
    */
   @Override
   public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
-    List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text);
-    Collections.sort(doclist);
-    return doclist;
+    List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(text);
+    Collections.sort(assignedClasses);
+    return assignedClasses;
   }
 
   /**
@@ -129,9 +130,9 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
    */
   @Override
   public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
-    List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text);
-    Collections.sort(doclist);
-    return doclist.subList(0, max);
+    List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(text);
+    Collections.sort(assignedClasses);
+    return assignedClasses.subList(0, max);
  }
 
   /**
@@ -141,46 +142,26 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
    * @throws IOException if assigning probabilities fails
    */
   protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
-    List<ClassificationResult<BytesRef>> dataList = new ArrayList<>();
+    List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
 
-    Terms terms = MultiFields.getTerms(leafReader, classFieldName);
-    TermsEnum termsEnum = terms.iterator();
+    Terms classes = MultiFields.getTerms(leafReader, classFieldName);
+    TermsEnum classesEnum = classes.iterator();
     BytesRef next;
-    String[] tokenizedDoc = tokenizeDoc(inputDocument);
+    String[] tokenizedText = tokenize(inputDocument);
     int docsWithClassSize = countDocsWithClass();
-    while ((next = termsEnum.next()) != null) {
+    while ((next = classesEnum.next()) != null) {
       if (next.length > 0) {
         // We are passing the term to IndexSearcher so we need to make sure it will not change over time
         next = BytesRef.deepCopyOf(next);
-        double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedDoc, next, docsWithClassSize);
-        dataList.add(new ClassificationResult<>(next, clVal));
+        double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedText, next, docsWithClassSize);
+        assignedClasses.add(new ClassificationResult<>(next, clVal));
       }
     }
 
     // normalization; the values transforms to a 0-1 range
-    ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
-    if (!dataList.isEmpty()) {
-      Collections.sort(dataList);
-      // this is a negative number closest to 0 = a
-      double smax = dataList.get(0).getScore();
-
-      double sumLog = 0;
-      // log(sum(exp(x_n-a)))
-      for (ClassificationResult<BytesRef> cr : dataList) {
-        // getScore-smax <=0 (both negative, smax is the smallest abs()
-        sumLog += Math.exp(cr.getScore() - smax);
-      }
-      // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
-      double loga = smax;
-      loga += Math.log(sumLog);
-
-      // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
-      for (ClassificationResult<BytesRef> cr : dataList) {
-        returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - loga)));
-      }
-    }
-
-    return returnList;
+    ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = normClassificationResults(assignedClasses);
+    return assignedClassesNorm;
   }
 
   /**
@@ -192,15 +173,15 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
   protected int countDocsWithClass() throws IOException {
     int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount();
     if (docCount == -1) { // in case codec doesn't support getDocCount
-      TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
+      TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
       BooleanQuery.Builder q = new BooleanQuery.Builder();
       q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
       if (query != null) {
         q.add(query, BooleanClause.Occur.MUST);
       }
       indexSearcher.search(q.build(),
-          totalHitCountCollector);
-      docCount = totalHitCountCollector.getTotalHits();
+          classQueryCountCollector);
+      docCount = classQueryCountCollector.getTotalHits();
     }
     return docCount;
   }
@@ -208,14 +189,14 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
   /**
    * tokenize a <code>String</code> on this classifier's text fields and analyzer
    *
-   * @param doc the <code>String</code> representing an input text (to be classified)
+   * @param text the <code>String</code> representing an input text (to be classified)
    * @return a <code>String</code> array of the resulting tokens
    * @throws IOException if tokenization fails
    */
-  protected String[] tokenizeDoc(String doc) throws IOException {
+  protected String[] tokenize(String text) throws IOException {
     Collection<String> result = new LinkedList<>();
     for (String textFieldName : textFieldNames) {
-      try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) {
+      try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
         CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
         tokenStream.reset();
         while (tokenStream.incrementToken()) {
@@ -227,18 +208,18 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
     return result.toArray(new String[result.size()]);
   }
 
-  private double calculateLogLikelihood(String[] tokenizedDoc, BytesRef c, int docsWithClassSize) throws IOException {
+  private double calculateLogLikelihood(String[] tokenizedText, BytesRef c, int docsWithClass) throws IOException {
     // for each word
     double result = 0d;
-    for (String word : tokenizedDoc) {
+    for (String word : tokenizedText) {
       // search with text:word AND class:c
-      int hits = getWordFreqForClass(word, c);
+      int hits = getWordFreqForClass(word,c);
 
       // num : count the no of times the word appears in documents of class c (+1)
       double num = hits + 1; // +1 is added because of add 1 smoothing
 
       // den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
-      double den = getTextTermFreqForClass(c) + docsWithClassSize;
+      double den = getTextTermFreqForClass(c) + docsWithClass;
 
       // P(w|c) = num/den
       double wordProbability = num / den;
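Reading the arithmetic off the hunks above (this summary is an editorial aid, not part of the commit), the per-class score that assignClassNormalizedList accumulates is the usual add-one-smoothed naive Bayes log score: for the tokens $w$ of the input and a class $c$,

$$\mathrm{score}(c) \;=\; \log\frac{\mathrm{docCount}(c)}{N} \;+\; \sum_{w} \log\frac{\mathrm{hits}(w,c) + 1}{\mathrm{termFreq}(c) + N}$$

where $N$ is countDocsWithClass(), $\mathrm{hits}(w,c)$ is the number of docs matching text:w AND class:c, and $\mathrm{termFreq}(c)$ is the estimate returned by getTextTermFreqForClass.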
@@ -249,6 +230,12 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
     return result;
   }
 
+  /**
+   * Returns the average number of unique terms times the number of docs belonging to the input class
+   * @param c the class
+   * @return the average number of unique terms
+   * @throws IOException if a low level I/O problem happens
+   */
   private double getTextTermFreqForClass(BytesRef c) throws IOException {
     double avgNumberOfUniqueTerms = 0;
     for (String textFieldName : textFieldNames) {
@@ -260,6 +247,14 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
     return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
   }
 
+  /**
+   * Returns the number of documents of the input class ( from the whole index or from a subset)
+   * that contains the word ( in a specific field or in all the fields if no one selected)
+   * @param word the token produced by the analyzer
+   * @param c the class
+   * @return the number of documents of the input class
+   * @throws IOException if a low level I/O problem happens
+   */
   private int getWordFreqForClass(String word, BytesRef c) throws IOException {
     BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
     BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
@@ -283,4 +278,36 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
   private int docCount(BytesRef countedClass) throws IOException {
     return leafReader.docFreq(new Term(classFieldName, countedClass));
   }
+
+  /**
+   * Normalize the classification results based on the max score available
+   * @param assignedClasses the list of assigned classes
+   * @return the normalized results
+   */
+  protected ArrayList<ClassificationResult<BytesRef>> normClassificationResults(List<ClassificationResult<BytesRef>> assignedClasses) {
+    // normalization; the values transforms to a 0-1 range
+    ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
+    if (!assignedClasses.isEmpty()) {
+      Collections.sort(assignedClasses);
+      // this is a negative number closest to 0 = a
+      double smax = assignedClasses.get(0).getScore();
+
+      double sumLog = 0;
+      // log(sum(exp(x_n-a)))
+      for (ClassificationResult<BytesRef> cr : assignedClasses) {
+        // getScore-smax <=0 (both negative, smax is the smallest abs()
+        sumLog += Math.exp(cr.getScore() - smax);
+      }
+      // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
+      double loga = smax;
+      loga += Math.log(sumLog);
+
+      // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
+      for (ClassificationResult<BytesRef> cr : assignedClasses) {
+        double scoreDiff = cr.getScore() - loga;
+        returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(scoreDiff)));
+      }
+    }
+    return returnList;
+  }
 }
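The normClassificationResults helper introduced above is the standard log-sum-exp normalization that its comments spell out step by step: with sorted log scores $x_1 \ge \dots \ge x_n$ and $a = x_1$ (the negative value closest to zero), it returns

$$p_i \;=\; \frac{e^{x_i}}{\sum_{n} e^{x_n}} \;=\; \exp\!\Big(x_i - a - \log\sum_{n} e^{x_n - a}\Big)$$

so the scores land in the 0-1 range without Math.exp underflowing on large negative log-likelihoods. Both SimpleNaiveBayesClassifier and CachingNaiveBayesClassifier now delegate to this single implementation instead of duplicating the loop.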
@@ -0,0 +1,61 @@
+package org.apache.lucene.classification.document;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.lucene.classification.ClassificationResult;
+import org.apache.lucene.document.Document;
+
+/**
+ * A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which assign classes of type
+ * <code>T</code> to a {@link org.apache.lucene.document.Document}s
+ *
+ * @lucene.experimental
+ */
+public interface DocumentClassifier<T> {
+  /**
+   * Assign a class (with score) to the given {@link org.apache.lucene.document.Document}
+   *
+   * @param document a {@link org.apache.lucene.document.Document} to be classified. Fields are considered features for the classification.
+   * @return a {@link org.apache.lucene.classification.ClassificationResult} holding assigned class of type <code>T</code> and score
+   * @throws java.io.IOException If there is a low-level I/O error.
+   */
+  ClassificationResult<T> assignClass(Document document) throws IOException;
+
+  /**
+   * Get all the classes (sorted by score, descending) assigned to the given {@link org.apache.lucene.document.Document}.
+   *
+   * @param document a {@link org.apache.lucene.document.Document} to be classified. Fields are considered features for the classification.
+   * @return the whole list of {@link org.apache.lucene.classification.ClassificationResult}, the classes and scores. Returns <code>null</code> if the classifier can't make lists.
+   * @throws java.io.IOException If there is a low-level I/O error.
+   */
+  List<ClassificationResult<T>> getClasses(Document document) throws IOException;
+
+  /**
+   * Get the first <code>max</code> classes (sorted by score, descending) assigned to the given text String.
+   *
+   * @param document a {@link org.apache.lucene.document.Document} to be classified. Fields are considered features for the classification.
+   * @param max the number of return list elements
+   * @return the whole list of {@link org.apache.lucene.classification.ClassificationResult}, the classes and scores. Cut for "max" number of elements. Returns <code>null</code> if the classifier can't make lists.
+   * @throws java.io.IOException If there is a low-level I/O error.
+   */
+  List<ClassificationResult<T>> getClasses(Document document, int max) throws IOException;
+
+}
@@ -0,0 +1,146 @@
+package org.apache.lucene.classification.document;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.io.StringReader;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.classification.ClassificationResult;
+import org.apache.lucene.classification.KNearestNeighborClassifier;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.WildcardQuery;
+import org.apache.lucene.search.similarities.Similarity;
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * A k-Nearest Neighbor Document classifier (see <code>http://en.wikipedia.org/wiki/K-nearest_neighbors</code>) based
+ * on {@link org.apache.lucene.queries.mlt.MoreLikeThis} .
+ *
+ * @lucene.experimental
+ */
+public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifier implements DocumentClassifier<BytesRef> {
+
+  protected Map<String, Analyzer> field2analyzer;
+
+  /**
+   * Creates a {@link KNearestNeighborClassifier}.
+   *
+   * @param leafReader     the reader on the index to be used for classification
+   * @param similarity     the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
+   *                       (defaults to {@link org.apache.lucene.search.similarities.ClassicSimilarity})
+   * @param query          a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null}
+   *                       if all the indexed docs should be used
+   * @param k              the no. of docs to select in the MLT results to find the nearest neighbor
+   * @param minDocsFreq    {@link org.apache.lucene.queries.mlt.MoreLikeThis#minDocFreq} parameter
+   * @param minTermFreq    {@link org.apache.lucene.queries.mlt.MoreLikeThis#minTermFreq} parameter
+   * @param classFieldName the name of the field used as the output for the classifier
+   * @param field2analyzer map with key a field name and the related {org.apache.lucene.analysis.Analyzer}
+   * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
+   */
+  public KNearestNeighborDocumentClassifier(LeafReader leafReader, Similarity similarity, Query query, int k, int minDocsFreq,
+                                            int minTermFreq, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
+    super(leafReader,similarity, null, query, k, minDocsFreq, minTermFreq, classFieldName, textFieldNames);
+    this.field2analyzer = field2analyzer;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public ClassificationResult<BytesRef> assignClass(Document document) throws IOException {
+    TopDocs knnResults = knnSearch(document);
+    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
+    ClassificationResult<BytesRef> assignedClass = null;
+    double maxscore = -Double.MAX_VALUE;
+    for (ClassificationResult<BytesRef> cl : assignedClasses) {
+      if (cl.getScore() > maxscore) {
+        assignedClass = cl;
+        maxscore = cl.getScore();
+      }
+    }
+    return assignedClass;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException {
+    TopDocs knnResults = knnSearch(document);
+    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
+    Collections.sort(assignedClasses);
+    return assignedClasses;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException {
+    TopDocs knnResults = knnSearch(document);
+    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
+    Collections.sort(assignedClasses);
+    return assignedClasses.subList(0, max);
+  }
+
+  /**
+   * Returns the top k results from a More Like This query based on the input document
+   *
+   * @param document the document to use for More Like This search
+   * @return the top results for the MLT query
+   * @throws IOException If there is a low-level I/O error
+   */
+  private TopDocs knnSearch(Document document) throws IOException {
+    BooleanQuery.Builder mltQuery = new BooleanQuery.Builder();
+
+    for (String fieldName : textFieldNames) {
+      String boost = null;
+      if (fieldName.contains("^")) {
+        String[] field2boost = fieldName.split("\\^");
+        fieldName = field2boost[0];
+        boost = field2boost[1];
+      }
+      String[] fieldValues = document.getValues(fieldName);
+      if (boost != null) {
+        mlt.setBoost(true);
+        mlt.setBoostFactor(Float.parseFloat(boost));
+      }
+      mlt.setAnalyzer(field2analyzer.get(fieldName));
+      for (String fieldContent : fieldValues) {
+        mltQuery.add(new BooleanClause(mlt.like(fieldName, new StringReader(fieldContent)), BooleanClause.Occur.SHOULD));
+      }
+      mlt.setBoost(false);
+    }
+    Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*"));
+    mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
+    if (query != null) {
+      mltQuery.add(query, BooleanClause.Occur.MUST);
+    }
+    return indexSearcher.search(mltQuery.build(), k);
+  }
+}
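For orientation, here is a minimal usage sketch of the new classifier. It is not part of the commit: the index contents, the field names "text"/"title"/"category" and the EnglishAnalyzer choice are assumptions for illustration only; the constructor arguments follow the javadoc above.

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.en.EnglishAnalyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.document.DocumentClassifier;
import org.apache.lucene.classification.document.KNearestNeighborDocumentClassifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.util.BytesRef;

public class KnnDocumentClassificationExample {
  // leafReader is assumed to wrap a training index with analyzed "text" and
  // "title" fields and a "category" class field (all hypothetical names).
  public static ClassificationResult<BytesRef> classify(LeafReader leafReader, Document doc) throws IOException {
    Map<String, Analyzer> field2analyzer = new HashMap<>();
    field2analyzer.put("text", new EnglishAnalyzer());
    field2analyzer.put("title", new EnglishAnalyzer());
    DocumentClassifier<BytesRef> knn = new KNearestNeighborDocumentClassifier(
        leafReader,
        null,        // Similarity: null falls back to the default per the javadoc
        null,        // training filter Query: null uses all indexed docs
        3,           // k: neighbors to inspect in the MoreLikeThis results
        1, 1,        // MoreLikeThis minDocFreq / minTermFreq
        "category",  // class (output) field
        field2analyzer,
        "text", "title^10"); // input fields; a "^" suffix is parsed as a boost
    return knn.assignClass(doc); // or knn.getClasses(doc) / knn.getClasses(doc, max)
  }
}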
@@ -0,0 +1,289 @@
+package org.apache.lucene.classification.document;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.classification.ClassificationResult;
+import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.MultiFields;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TotalHitCountCollector;
+import org.apache.lucene.search.WildcardQuery;
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * A simplistic Lucene based NaiveBayes classifier, see {@code http://en.wikipedia.org/wiki/Naive_Bayes_classifier}
+ *
+ * @lucene.experimental
+ */
+public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifier implements DocumentClassifier<BytesRef> {
+  /**
+   * {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing document fields
+   */
+  protected Map<String, Analyzer> field2analyzer;
+
+  /**
+   * Creates a new NaiveBayes classifier.
+   *
+   * @param leafReader     the reader on the index to be used for classification
+   * @param query          a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null}
+   *                       if all the indexed docs should be used
+   * @param classFieldName the name of the field used as the output for the classifier NOTE: must not be havely analyzed
+   *                       as the returned class will be a token indexed for this field
+   * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
+   */
+  public SimpleNaiveBayesDocumentClassifier(LeafReader leafReader, Query query, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
+    super(leafReader, null, query, classFieldName, textFieldNames);
+    this.field2analyzer = field2analyzer;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public ClassificationResult<BytesRef> assignClass(Document document) throws IOException {
+    List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document);
+    ClassificationResult<BytesRef> assignedClass = null;
+    double maxscore = -Double.MAX_VALUE;
+    for (ClassificationResult<BytesRef> c : assignedClasses) {
+      if (c.getScore() > maxscore) {
+        assignedClass = c;
+        maxscore = c.getScore();
+      }
+    }
+    return assignedClass;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException {
+    List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document);
+    Collections.sort(assignedClasses);
+    return assignedClasses;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException {
+    List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document);
+    Collections.sort(assignedClasses);
+    return assignedClasses.subList(0, max);
+  }
+
+  private List<ClassificationResult<BytesRef>> assignNormClasses(Document inputDocument) throws IOException {
+    List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
+    Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>();
+    Map<String, Float> fieldName2boost = new LinkedHashMap<>();
+    Terms classes = MultiFields.getTerms(leafReader, classFieldName);
+    TermsEnum classesEnum = classes.iterator();
+    BytesRef c;
+
+    analyzeSeedDocument(inputDocument, fieldName2tokensArray, fieldName2boost);
+
+    int docsWithClassSize = countDocsWithClass();
+    while ((c = classesEnum.next()) != null) {
+      double classScore = 0;
+      for (String fieldName : textFieldNames) {
+        List<String[]> tokensArrays = fieldName2tokensArray.get(fieldName);
+        double fieldScore = 0;
+        for (String[] fieldTokensArray : tokensArrays) {
+          fieldScore += calculateLogPrior(c, docsWithClassSize) + calculateLogLikelihood(fieldTokensArray, fieldName, c, docsWithClassSize) * fieldName2boost.get(fieldName);
+        }
+        classScore += fieldScore;
+      }
+      assignedClasses.add(new ClassificationResult<>(BytesRef.deepCopyOf(c), classScore));
+    }
+    ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = normClassificationResults(assignedClasses);
+    return assignedClassesNorm;
+  }
+
+  /**
+   * This methods performs the analysis for the seed document and extract the boosts if present.
+   * This is done only one time for the Seed Document.
+   *
+   * @param inputDocument         the seed unseen document
+   * @param fieldName2tokensArray a map that associated to a field name the list of token arrays for all its values
+   * @param fieldName2boost       a map that associates the boost to the field
+   * @throws IOException If there is a low-level I/O error
+   */
+  private void analyzeSeedDocument(Document inputDocument, Map<String, List<String[]>> fieldName2tokensArray, Map<String, Float> fieldName2boost) throws IOException {
+    for (int i = 0; i < textFieldNames.length; i++) {
+      String fieldName = textFieldNames[i];
+      float boost = 1;
+      List<String[]> tokenizedValues = new LinkedList<>();
+      if (fieldName.contains("^")) {
+        String[] field2boost = fieldName.split("\\^");
+        fieldName = field2boost[0];
+        boost = Float.parseFloat(field2boost[1]);
+      }
+      Field[] fieldValues = inputDocument.getFields(fieldName);
+      for (Field fieldValue : fieldValues) {
+        TokenStream fieldTokens = fieldValue.tokenStream(field2analyzer.get(fieldName), null);
+        String[] fieldTokensArray = getTokenArray(fieldTokens);
+        tokenizedValues.add(fieldTokensArray);
+      }
+      fieldName2tokensArray.put(fieldName, tokenizedValues);
+      fieldName2boost.put(fieldName, boost);
+      textFieldNames[i] = fieldName;
+    }
+  }
+
+  /**
+   * Counts the number of documents in the index having at least a value for the 'class' field
+   *
+   * @return the no. of documents having a value for the 'class' field
+   * @throws java.io.IOException If accessing to term vectors or search fails
+   */
+  protected int countDocsWithClass() throws IOException {
+    int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount();
+    if (docCount == -1) { // in case codec doesn't support getDocCount
+      TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
+      BooleanQuery.Builder q = new BooleanQuery.Builder();
+      q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
+      if (query != null) {
+        q.add(query, BooleanClause.Occur.MUST);
+      }
+      indexSearcher.search(q.build(),
+          classQueryCountCollector);
+      docCount = classQueryCountCollector.getTotalHits();
+    }
+    return docCount;
+  }
+
+  /**
+   * Returns a token array from the {@link org.apache.lucene.analysis.TokenStream} in input
+   *
+   * @param tokenizedText the tokenized content of a field
+   * @return a {@code String} array of the resulting tokens
+   * @throws java.io.IOException If tokenization fails because there is a low-level I/O error
+   */
+  protected String[] getTokenArray(TokenStream tokenizedText) throws IOException {
+    Collection<String> tokens = new LinkedList<>();
+    CharTermAttribute charTermAttribute = tokenizedText.addAttribute(CharTermAttribute.class);
+    tokenizedText.reset();
+    while (tokenizedText.incrementToken()) {
+      tokens.add(charTermAttribute.toString());
+    }
+    tokenizedText.end();
+    tokenizedText.close();
+    return tokens.toArray(new String[tokens.size()]);
+  }
+
+  /**
+   * @param tokenizedText the tokenized content of a field
+   * @param fieldName     the input field name
+   * @param c             the class to calculate the score of
+   * @param docsWithClass the total number of docs that have a class
+   * @return a normalized score for the class
+   * @throws IOException If there is a low-level I/O error
+   */
+  private double calculateLogLikelihood(String[] tokenizedText, String fieldName, BytesRef c, int docsWithClass) throws IOException {
+    // for each word
+    double result = 0d;
+    for (String word : tokenizedText) {
+      // search with text:word AND class:c
+      int hits = getWordFreqForClass(word, fieldName, c);
+
+      // num : count the no of times the word appears in documents of class c (+1)
+      double num = hits + 1; // +1 is added because of add 1 smoothing
+
+      // den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
+      double den = getTextTermFreqForClass(c, fieldName) + docsWithClass;
+
+      // P(w|c) = num/den
+      double wordProbability = num / den;
+      result += Math.log(wordProbability);
+    }
+
+    // log(P(d|c)) = log(P(w1|c))+...+log(P(wn|c))
+    double normScore = result / (tokenizedText.length); // this is normalized because if not, long text fields will always be more important than short fields
+    return normScore;
+  }
+
+  /**
+   * Returns the average number of unique terms times the number of docs belonging to the input class
+   *
+   * @param c the class
+   * @return the average number of unique terms
+   * @throws java.io.IOException If there is a low-level I/O error
+   */
+  private double getTextTermFreqForClass(BytesRef c, String fieldName) throws IOException {
+    double avgNumberOfUniqueTerms;
+    Terms terms = MultiFields.getTerms(leafReader, fieldName);
+    long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
+    avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
+    int docsWithC = leafReader.docFreq(new Term(classFieldName, c));
+    return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
+  }
+
+  /**
+   * Returns the number of documents of the input class ( from the whole index or from a subset)
+   * that contains the word ( in a specific field or in all the fields if no one selected)
+   *
+   * @param word      the token produced by the analyzer
+   * @param fieldName the field the word is coming from
+   * @param c         the class
+   * @return number of documents of the input class
+   * @throws java.io.IOException If there is a low-level I/O error
+   */
+  private int getWordFreqForClass(String word, String fieldName, BytesRef c) throws IOException {
+    BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
+    BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
+    subQuery.add(new BooleanClause(new TermQuery(new Term(fieldName, word)), BooleanClause.Occur.SHOULD));
+    booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST));
+    booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
+    if (query != null) {
+      booleanQuery.add(query, BooleanClause.Occur.MUST);
+    }
+    TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
+    indexSearcher.search(booleanQuery.build(), totalHitCountCollector);
+    return totalHitCountCollector.getTotalHits();
+  }
+
+  private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException {
+    return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize);
+  }
+
+  private int docCount(BytesRef countedClass) throws IOException {
+    return leafReader.docFreq(new Term(classFieldName, countedClass));
+  }
+}
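A companion sketch for the naive Bayes variant, again an illustration rather than part of the commit: the field names and analyzer are assumptions, and the per-field "^" boost syntax comes from the javadoc above.

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.en.EnglishAnalyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.document.DocumentClassifier;
import org.apache.lucene.classification.document.SimpleNaiveBayesDocumentClassifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.util.BytesRef;

public class NaiveBayesDocumentClassificationExample {
  // Field names are hypothetical; per the javadoc above, the class field must
  // hold the raw class token (it "must not be havely analyzed").
  public static List<ClassificationResult<BytesRef>> classify(LeafReader leafReader, Document doc) throws IOException {
    Map<String, Analyzer> field2analyzer = new HashMap<>();
    field2analyzer.put("text", new EnglishAnalyzer());
    field2analyzer.put("title", new EnglishAnalyzer());
    DocumentClassifier<BytesRef> nb = new SimpleNaiveBayesDocumentClassifier(
        leafReader,
        null,        // training filter Query: null uses all indexed docs
        "category",  // class (output) field
        field2analyzer,
        "text", "title^2"); // a "^" boost weights that field's log-likelihood
    return nb.getClasses(doc); // all classes, sorted by normalized score
  }
}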
@@ -0,0 +1,7 @@
+/**
+ * Uses already seen data (the indexed documents) to classify new documents.
+ * <p>
+ * Currently contains a (simplistic) Naive Bayes classifier and a k-Nearest
+ * Neighbor classifier.
+ */
+package org.apache.lucene.classification.document;
@@ -16,7 +16,7 @@
  */
 
 /**
- * Uses already seen data (the indexed documents) to classify new documents.
+ * Uses already seen data (the indexed documents) to classify an input ( can be simple text or a structured document).
  * <p>
  * Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
  * Neighbor classifier and a Perceptron based classifier.
@@ -57,8 +57,8 @@ public abstract class ClassificationTestBase<T> extends LuceneTestCase {
   protected static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
 
   protected RandomIndexWriter indexWriter;
-  private Directory dir;
-  private FieldType ft;
+  protected Directory dir;
+  protected FieldType ft;
 
   protected String textFieldName;
   protected String categoryFieldName;
@ -0,0 +1,259 @@
|
||||||
|
package org.apache.lucene.classification.document;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
|
import org.apache.lucene.analysis.en.EnglishAnalyzer;
|
||||||
|
import org.apache.lucene.classification.ClassificationResult;
|
||||||
|
import org.apache.lucene.classification.ClassificationTestBase;
|
||||||
|
import org.apache.lucene.document.Document;
|
||||||
|
import org.apache.lucene.document.Field;
|
||||||
|
import org.apache.lucene.index.IndexWriterConfig;
|
||||||
|
import org.apache.lucene.index.LeafReader;
|
||||||
|
import org.apache.lucene.index.RandomIndexWriter;
|
||||||
|
import org.apache.lucene.index.SlowCompositeReaderWrapper;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
import org.junit.Before;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Base class for testing {@link org.apache.lucene.classification.Classifier}s
|
||||||
|
*/
|
||||||
|
public abstract class DocumentClassificationTestBase<T> extends ClassificationTestBase {
|
||||||
|
|
||||||
|
protected static final BytesRef VIDEOGAME_RESULT = new BytesRef("videogames");
|
||||||
|
protected static final BytesRef VIDEOGAME_ANALYZED_RESULT = new BytesRef("videogam");
|
||||||
|
protected static final BytesRef BATMAN_RESULT = new BytesRef("batman");
|
||||||
|
|
||||||
|
protected String titleFieldName = "title";
|
||||||
|
protected String authorFieldName = "author";
|
||||||
|
|
||||||
|
protected Analyzer analyzer;
|
||||||
|
protected Map<String, Analyzer> field2analyzer;
|
||||||
|
protected LeafReader leafReader;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void init() throws IOException {
|
||||||
|
analyzer = new EnglishAnalyzer();
|
||||||
|
field2analyzer = new LinkedHashMap<>();
|
||||||
|
field2analyzer.put(textFieldName, analyzer);
|
||||||
|
field2analyzer.put(titleFieldName, analyzer);
|
||||||
|
field2analyzer.put(authorFieldName, analyzer);
|
||||||
|
leafReader = populateDocumentClassificationIndex(analyzer);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected double checkCorrectDocumentClassification(DocumentClassifier<T> classifier, Document inputDoc, T expectedResult) throws Exception {
|
||||||
|
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
|
||||||
|
assertNotNull(classificationResult.getAssignedClass());
|
||||||
|
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
|
||||||
|
double score = classificationResult.getScore();
|
||||||
|
assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
|
||||||
|
return score;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected LeafReader populateDocumentClassificationIndex(Analyzer analyzer) throws IOException {
|
||||||
|
indexWriter.close();
|
||||||
|
indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
|
||||||
|
indexWriter.commit();
|
||||||
|
String text;
|
||||||
|
String title;
|
||||||
|
String author;
|
||||||
|
|
||||||
|
    Document doc = new Document();
    title = "Video games are an economic business";
    text = "Video games have become an art form and an industry. The video game industry is of increasing" +
        " commercial importance, with growth driven particularly by the emerging Asian markets and mobile games." +
        " As of 2015, video games generated sales of USD 74 billion annually worldwide, and were the third-largest" +
        " segment in the U.S. entertainment market, behind broadcast and cable TV.";
    author = "Ign";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(titleFieldName, title, ft));
    doc.add(new Field(authorFieldName, author, ft));
    doc.add(new Field(categoryFieldName, "videogames", ft));
    doc.add(new Field(booleanFieldName, "false", ft));
    indexWriter.addDocument(doc);

    doc = new Document();
    title = "Video games: the definition of fun on PC and consoles";
    text = "A video game is an electronic game that involves human interaction with a user interface to generate" +
        " visual feedback on a video device. The word video in video game traditionally referred to a raster display device," +
        "[1] but it now implies any type of display device that can produce two- or three-dimensional images." +
        " The electronic systems used to play video games are known as platforms; examples of these are personal" +
        " computers and video game consoles. These platforms range from large mainframe computers to small handheld devices." +
        " Specialized video games such as arcade games, while previously common, have gradually declined in use.";
    author = "Ign";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(titleFieldName, title, ft));
    doc.add(new Field(authorFieldName, author, ft));
    doc.add(new Field(categoryFieldName, "videogames", ft));
    doc.add(new Field(booleanFieldName, "false", ft));
    indexWriter.addDocument(doc);

    doc = new Document();
    title = "Video games: the history across PC, consoles and fun";
    text = "Early games used interactive electronic devices with various display formats. The earliest example is" +
        " from 1947—a device was filed for a patent on 25 January 1947, by Thomas T. Goldsmith Jr. and Estle Ray Mann," +
        " and issued on 14 December 1948, as U.S. Patent 2455992.[2]" +
        "Inspired by radar display tech, it consisted of an analog device that allowed a user to control a vector-drawn" +
        " dot on the screen to simulate a missile being fired at targets, which were drawings fixed to the screen.[3]";
    author = "Ign";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(titleFieldName, title, ft));
    doc.add(new Field(authorFieldName, author, ft));
    doc.add(new Field(categoryFieldName, "videogames", ft));
    doc.add(new Field(booleanFieldName, "false", ft));
    indexWriter.addDocument(doc);

    doc = new Document();
    title = "Video games: the history";
    text = "Early games used interactive electronic devices with various display formats. The earliest example is" +
        " from 1947—a device was filed for a patent on 25 January 1947, by Thomas T. Goldsmith Jr. and Estle Ray Mann," +
        " and issued on 14 December 1948, as U.S. Patent 2455992.[2]" +
        "Inspired by radar display tech, it consisted of an analog device that allowed a user to control a vector-drawn" +
        " dot on the screen to simulate a missile being fired at targets, which were drawings fixed to the screen.[3]";
    author = "Ign";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(titleFieldName, title, ft));
    doc.add(new Field(authorFieldName, author, ft));
    doc.add(new Field(categoryFieldName, "videogames", ft));
    doc.add(new Field(booleanFieldName, "false", ft));
    indexWriter.addDocument(doc);

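    // four documents in the "batman" category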
    doc = new Document();
    title = "Batman: Arkham Knight PC Benchmarks, For What They're Worth";
    text = "Although I didn’t spend much time playing Batman: Arkham Origins, I remember the game rather well after" +
        " testing it on no less than 30 graphics cards and 20 CPUs. Arkham Origins appeared to take full advantage of" +
        " Unreal Engine 3, it ran smoothly on affordable GPUs, though it’s worth remembering that Origins was developed " +
        "for last-gen consoles.This week marked the arrival of Batman: Arkham Knight, the fourth entry in WB’s Batman:" +
        " Arkham series and a direct sequel to 2013’s Arkham Origins 2011’s Arkham City." +
        "Arkham Knight is also powered by Unreal Engine 3, but you can expect noticeably improved graphics, in part because" +
        " the PlayStation 4 and Xbox One have replaced the PS3 and 360 as the lowest common denominator.";
    author = "Rocksteady Studios";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(titleFieldName, title, ft));
    doc.add(new Field(authorFieldName, author, ft));
    doc.add(new Field(categoryFieldName, "batman", ft));
    doc.add(new Field(booleanFieldName, "false", ft));
    indexWriter.addDocument(doc);

    doc = new Document();
    title = "Face-Off: Batman: Arkham Knight, the Dark Knight returns!";
    text = "Despite the drama surrounding the PC release leading to its subsequent withdrawal, there's a sense of success" +
        " in the console space as PlayStation 4 owners, and indeed those on Xbox One, get a superb rendition of Batman:" +
        " Arkham Knight. It's fair to say Rocksteady sized up each console's strengths well ahead of producing its first" +
        " current-gen title, and it's paid off in one of the best Batman games we've seen in years.";
    author = "Rocksteady Studios";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(titleFieldName, title, ft));
    doc.add(new Field(authorFieldName, author, ft));
    doc.add(new Field(categoryFieldName, "batman", ft));
    doc.add(new Field(booleanFieldName, "false", ft));
    indexWriter.addDocument(doc);

    doc = new Document();
    title = "Batman: Arkham Knight Having More Trouble, But This Time not in Gotham";
    text = "As news began to break about the numerous issues affecting the PC version of Batman: Arkham Knight, players" +
        " of the console version breathed a sigh of relief and got back to playing the game. Now players of the PlayStation" +
        " 4 version are having problems of their own, albeit much less severe ones." +
        "This time Batman will have a difficult time in Gotham.";
    author = "Rocksteady Studios";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(titleFieldName, title, ft));
    doc.add(new Field(authorFieldName, author, ft));
    doc.add(new Field(categoryFieldName, "batman", ft));
    doc.add(new Field(booleanFieldName, "false", ft));
    indexWriter.addDocument(doc);

    doc = new Document();
    title = "Batman: Arkham Knight the new legend of Gotham";
    text = "As news began to break about the numerous issues affecting the PC version of the game, players" +
        " of the console version breathed a sigh of relief and got back to play. Now players of the PlayStation" +
        " 4 version are having problems of their own, albeit much less severe ones.";
    author = "Rocksteady Studios";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(titleFieldName, title, ft));
    doc.add(new Field(authorFieldName, author, ft));
    doc.add(new Field(categoryFieldName, "batman", ft));
    doc.add(new Field(booleanFieldName, "false", ft));
    indexWriter.addDocument(doc);

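    // a single unlabeled document: it carries no category field, so it contributes no class information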
    doc = new Document();
    text = "unlabeled doc";
    doc.add(new Field(textFieldName, text, ft));
    indexWriter.addDocument(doc);

    indexWriter.commit();
    return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
  }

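  /**
   * @return a test document expected to classify as "videogames" when all fields are used;
   * its text alone is intentionally ambiguous and classifies as "batman"
   */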
  protected Document getVideoGameDocument() {
    Document doc = new Document();
    String title = "The new generation of PC and Console Video games";
    String text = "Recently a lot of games have been released for the latest generations of consoles and personal computers." +
        "One of them is Batman: Arkham Knight released recently on PS4, X-box and personal computer." +
        "Another important video game that will be released in November is Assassin's Creed, a classic series that sees its new installement on Halloween." +
        "Recently a lot of problems affected the Assassin's creed series but this time it should ran smoothly on affordable GPUs." +
        "Players are waiting for the versions of their favourite video games and so do we.";
    String author = "Ign";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(titleFieldName, title, ft));
    doc.add(new Field(authorFieldName, author, ft));
    doc.add(new Field(booleanFieldName, "false", ft));
    return doc;
  }

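  /**
   * @return a test document expected to classify as "batman" when all fields are used; it also
   * carries a second, unrelated title value
   */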
  protected Document getBatmanDocument() {
    Document doc = new Document();
    String title = "Batman: Arkham Knight new adventures for the super hero across Gotham, the Dark Knight has returned!";
    String title2 = "I am a second title !";
    String text = "This game is the electronic version of the famous super hero adventures.It involves the interaction with the open world" +
        " of the city of Gotham. Finally the player will be able to have fun on its personal device." +
        " The three-dimensional images of the game are stunning, because it uses the Unreal Engine 3." +
        " The systems available are PS4, X-Box and personal computer." +
        " Will the simulate missile that is going to be fired, success ?" +
        " Will this video game make the history" +
        " Help you favourite super hero to defeat all his enemies. The Dark Knight has returned !";
    String author = "Rocksteady Studios";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(titleFieldName, title, ft));
    doc.add(new Field(titleFieldName, title2, ft));
    doc.add(new Field(authorFieldName, author, ft));
    doc.add(new Field(booleanFieldName, "false", ft));
    return doc;
  }

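  /**
   * @return a document whose title is strongly about Batman while its text repeats the generic
   * video-game history passage; used to exercise per-field boosting
   */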
  protected Document getBatmanAmbiguosDocument() {
    Document doc = new Document();
    String title = "Batman: Arkham Knight new adventures for the super hero across Gotham, the Dark Knight has returned! Batman will win !";
    String text = "Early games used interactive electronic devices with various display formats. The earliest example is" +
        " from 1947—a device was filed for a patent on 25 January 1947, by Thomas T. Goldsmith Jr. and Estle Ray Mann," +
        " and issued on 14 December 1948, as U.S. Patent 2455992.[2]" +
        "Inspired by radar display tech, it consisted of an analog device that allowed a user to control a vector-drawn" +
        " dot on the screen to simulate a missile being fired at targets, which were drawings fixed to the screen.[3]";
    String author = "Ign";
    doc.add(new Field(textFieldName, text, ft));
    doc.add(new Field(titleFieldName, title, ft));
    doc.add(new Field(authorFieldName, author, ft));
    doc.add(new Field(booleanFieldName, "false", ft));
    return doc;
  }
}

@@ -0,0 +1,96 @@
package org.apache.lucene.classification.document;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import org.apache.lucene.document.Document;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
import org.junit.Test;

/**
 * Tests for {@link KNearestNeighborDocumentClassifier}
 */
public class KNearestNeighborDocumentClassifierTest extends DocumentClassificationTestBase<BytesRef> {

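  // In the calls below the arguments after the reader are: a similarity (null = default), a
  // filtering query, then k, minDocsFreq and minTermFreq for the underlying MoreLikeThis
  // query, the class field, the per-field analyzers, and the document fields to use.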
  @Test
  public void testBasicDocumentClassification() throws Exception {
    try {
      Document videoGameDocument = getVideoGameDocument();
      Document batmanDocument = getBatmanDocument();
      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
      // considering only the text yields the wrong classification, because the text is ambiguous on purpose
      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

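  // with k = 1 there is a single neighbour, so the winning class always receives the full
  // normalized score of 1.0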
  @Test
  public void testBasicDocumentClassificationScore() throws Exception {
    try {
      Document videoGameDocument = getVideoGameDocument();
      Document batmanDocument = getBatmanDocument();
      double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
      assertEquals(1.0, score1, 0);
      double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
      assertEquals(1.0, score2, 0);
      // considering only the text yields the wrong classification, because the text is ambiguous on purpose
      double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
      assertEquals(1.0, score3, 0);
      double score4 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
      assertEquals(1.0, score4, 0);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

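  // a field name may carry a "^<boost>" suffix (as in titleFieldName + "^100" below), which
  // weights that field's terms more heavily when the document is turned into a query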
  @Test
  public void testBoostedDocumentClassification() throws Exception {
    try {
      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName + "^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
      // without the boost, the wrong classification appears
      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_RESULT);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

  @Test
  public void testBasicDocumentClassificationWithQuery() throws Exception {
    try {
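      // restrict the training set to documents whose author field matches "ign": only the
      // videogames documents qualify, so both test documents are classified as video games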
      TermQuery query = new TermQuery(new Term(authorFieldName, "ign"));
      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_RESULT);
      checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader, null, query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), VIDEOGAME_RESULT);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

}

@@ -0,0 +1,76 @@
package org.apache.lucene.classification.document;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import org.apache.lucene.util.BytesRef;
import org.junit.Test;

/**
 * Tests for {@link SimpleNaiveBayesDocumentClassifier}
 */
public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificationTestBase<BytesRef> {

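  // Constructed as (reader, query, classFieldName, field2analyzer, fieldNames); a null query
  // means every labeled document in the index is used as training data.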
  @Test
  public void testBasicDocumentClassification() throws Exception {
    try {
      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);

      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

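  // unlike the k = 1 nearest-neighbour case, naive Bayes scores are normalized
  // log-likelihoods, so even confident classifications land below 1.0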
  @Test
  public void testBasicDocumentClassificationScore() throws Exception {
    try {
      double score1 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
      assertEquals(0.88, score1, 0.01);
      double score2 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
      assertEquals(0.89, score2, 0.01);
      // taking into consideration only the text
      double score3 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
      assertEquals(0.55, score3, 0.01);
      double score4 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
      assertEquals(0.52, score4, 0.01);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

  @Test
  public void testBoostedDocumentClassification() throws Exception {
    try {
      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName + "^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
      // without the boost, the wrong classification appears
      checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_ANALYZED_RESULT);
    } finally {
      if (leafReader != null) {
        leafReader.close();
      }
    }
  }

}